# Source listing: ai-robot-core/ai-service/tests/test_metadata_auto_inferenc...
# (path truncated by the code viewer this file was copied from)
# Listing metadata: 444 lines, 14 KiB, Python

"""
Unit tests for MetadataAutoInferenceService.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from dataclasses import dataclass
from app.services.metadata_auto_inference_service import (
AutoInferenceResult,
InferenceFieldContext,
MetadataAutoInferenceService,
)
@dataclass
class MockFieldDefinition:
"""Mock field definition for testing"""
field_key: str
label: str
type: str
required: bool
options: list[str] | None = None
class TestInferenceFieldContext:
    """Test InferenceFieldContext dataclass."""

    def test_creation(self):
        """Every value passed to the constructor is readable back unchanged."""
        ctx = InferenceFieldContext(
            field_key="grade",
            label="年级",
            type="enum",
            required=True,
            options=["初一", "初二", "初三"],
        )
        assert ctx.field_key == "grade"
        assert ctx.label == "年级"
        assert ctx.type == "enum"
        assert ctx.required is True
        assert ctx.options == ["初一", "初二", "初三"]

    def test_default_values(self):
        """``options`` and ``description`` are None when not supplied."""
        ctx = InferenceFieldContext(
            field_key="test",
            label="Test",
            type="text",
            required=False,
        )
        assert ctx.options is None
        assert ctx.description is None
class TestAutoInferenceResult:
    """Test AutoInferenceResult dataclass."""

    def test_success_result(self):
        """A successful result carries its metadata and no error message."""
        result = AutoInferenceResult(
            inferred_metadata={"grade": "初一", "subject": "数学"},
            confidence_scores={"grade": 0.95, "subject": 0.85},
            raw_response='{"inferred_metadata": {...}}',
            success=True,
        )
        assert result.success is True
        # error_message was not passed, so the default is expected to be None.
        assert result.error_message is None
        assert result.inferred_metadata["grade"] == "初一"

    def test_failure_result(self):
        """A failed result exposes the error message it was built with."""
        result = AutoInferenceResult(
            inferred_metadata={},
            confidence_scores={},
            raw_response="",
            success=False,
            error_message="JSON parse error",
        )
        assert result.success is False
        assert result.error_message == "JSON parse error"
class TestMetadataAutoInferenceService:
    """Test MetadataAutoInferenceService functionality.

    Only the service's private helpers are exercised here (context building,
    prompt building, JSON extraction, response parsing, value validation);
    the session handed to the service is an ``AsyncMock``.
    """

    @pytest.fixture
    def mock_session(self):
        """Create mock session."""
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_session):
        """Build the service under test around the mocked session."""
        return MetadataAutoInferenceService(mock_session)

    def test_build_field_contexts(self, service):
        """Field definitions map to contexts one-to-one, in the same order."""
        fields = [
            MockFieldDefinition(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            MockFieldDefinition(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]
        contexts = service._build_field_contexts(fields)
        assert len(contexts) == 2
        assert contexts[0].field_key == "grade"
        assert contexts[0].options == ["初一", "初二", "初三"]
        assert contexts[1].field_key == "subject"

    def test_build_user_prompt(self, service):
        """The prompt contains the label, the joined options and the content."""
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]
        prompt = service._build_user_prompt(
            content="这是一道初一数学题",
            field_contexts=field_contexts,
        )
        assert "年级" in prompt
        # Options are expected to be rendered as a ", "-separated list.
        assert "初一, 初二, 初三" in prompt
        assert "这是一道初一数学题" in prompt

    def test_build_user_prompt_with_existing_metadata(self, service):
        """An existing metadata value is echoed in the prompt."""
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]
        prompt = service._build_user_prompt(
            content="这是一道数学题",
            field_contexts=field_contexts,
            existing_metadata={"grade": "初二"},
        )
        assert "已有值: 初二" in prompt

    def test_extract_json_from_plain_json(self, service):
        """A response that is already plain JSON is returned unchanged."""
        json_str = '{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}'
        result = service._extract_json(json_str)
        assert result == json_str

    def test_extract_json_from_markdown(self, service):
        """JSON wrapped in a fenced markdown code block is still extracted."""
        # The literal's content intentionally starts at column 0: any leading
        # whitespace would become part of the string handed to the service.
        markdown = """Here is the result:
```json
{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}
```
"""
        result = service._extract_json(markdown)
        assert "inferred_metadata" in result
        assert "grade" in result

    def test_parse_llm_response_valid(self, service):
        """A well-formed response yields the metadata and confidence scores."""
        response = json.dumps({
            "inferred_metadata": {
                "grade": "初一",
                "subject": "数学",
            },
            "confidence_scores": {
                "grade": 0.95,
                "subject": 0.85,
            }
        })
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]
        result = service._parse_llm_response(response, field_contexts)
        assert result.success is True
        assert result.inferred_metadata["grade"] == "初一"
        assert result.inferred_metadata["subject"] == "数学"
        assert result.confidence_scores["grade"] == 0.95

    def test_parse_llm_response_invalid_option(self, service):
        """A value outside the enum options is dropped; parsing still succeeds."""
        response = json.dumps({
            "inferred_metadata": {
                "grade": "高一",  # Not in options
            },
            "confidence_scores": {
                "grade": 0.90,
            }
        })
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]
        result = service._parse_llm_response(response, field_contexts)
        assert result.success is True
        assert "grade" not in result.inferred_metadata

    def test_parse_llm_response_invalid_json(self, service):
        """Non-JSON text produces a failed result with a parse error message."""
        response = "This is not valid JSON"
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="text",
                required=False,
            ),
        ]
        result = service._parse_llm_response(response, field_contexts)
        assert result.success is False
        assert "JSON parse error" in result.error_message

    def test_validate_field_value_text(self, service):
        """A text value passes validation unchanged."""
        ctx = InferenceFieldContext(
            field_key="title",
            label="标题",
            type="text",
            required=False,
        )
        result = service._validate_field_value(ctx, "测试标题")
        assert result == "测试标题"

    def test_validate_field_value_number(self, service):
        """Numbers pass, numeric strings are converted, other strings give None."""
        ctx = InferenceFieldContext(
            field_key="count",
            label="数量",
            type="number",
            required=False,
        )
        assert service._validate_field_value(ctx, 42) == 42
        assert service._validate_field_value(ctx, "3.14") == 3.14
        assert service._validate_field_value(ctx, "invalid") is None

    def test_validate_field_value_boolean(self, service):
        """Booleans, "true"/"false" strings and 1 are accepted for booleans."""
        ctx = InferenceFieldContext(
            field_key="active",
            label="是否激活",
            type="boolean",
            required=False,
        )
        assert service._validate_field_value(ctx, True) is True
        assert service._validate_field_value(ctx, "true") is True
        assert service._validate_field_value(ctx, "false") is False
        assert service._validate_field_value(ctx, 1) is True

    def test_validate_field_value_enum(self, service):
        """An enum value is kept when listed in options and None otherwise."""
        ctx = InferenceFieldContext(
            field_key="grade",
            label="年级",
            type="enum",
            required=False,
            options=["初一", "初二", "初三"],
        )
        assert service._validate_field_value(ctx, "初一") == "初一"
        assert service._validate_field_value(ctx, "高一") is None

    def test_validate_field_value_array_enum(self, service):
        """array_enum filters unknown entries and wraps a bare string in a list."""
        ctx = InferenceFieldContext(
            field_key="tags",
            label="标签",
            type="array_enum",
            required=False,
            options=["重点", "难点", "易错"],
        )
        # All entries valid: returned as-is.
        result = service._validate_field_value(ctx, ["重点", "难点"])
        assert result == ["重点", "难点"]
        # Entries missing from options are filtered out.
        result = service._validate_field_value(ctx, ["重点", "不存在"])
        assert result == ["重点"]
        # A single string is returned as a one-element list.
        result = service._validate_field_value(ctx, "重点")
        assert result == ["重点"]
class TestIntegrationScenarios:
    """Test integration scenarios.

    Each scenario feeds a complete JSON response through
    ``_parse_llm_response`` together with several field contexts.
    """

    @pytest.fixture
    def mock_session(self):
        """Create mock session."""
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_session):
        """Build the service under test around the mocked session."""
        return MetadataAutoInferenceService(mock_session)

    def test_education_scenario(self, service):
        """Grade, subject and type are all inferred when each value is valid."""
        response = json.dumps({
            "inferred_metadata": {
                "grade": "初二",
                "subject": "物理",
                "type": "痛点",
            },
            "confidence_scores": {
                "grade": 0.95,
                "subject": 0.90,
                "type": 0.85,
            }
        })
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语", "物理", "化学"],
            ),
            InferenceFieldContext(
                field_key="type",
                label="类型",
                type="enum",
                required=False,
                options=["痛点", "重点", "难点"],
            ),
        ]
        result = service._parse_llm_response(response, field_contexts)
        assert result.success is True
        assert result.inferred_metadata == {
            "grade": "初二",
            "subject": "物理",
            "type": "痛点",
        }
        assert result.confidence_scores["grade"] == 0.95

    def test_partial_inference(self, service):
        """A field absent from the response is absent from the result too."""
        response = json.dumps({
            "inferred_metadata": {
                "grade": "初一",
            },
            "confidence_scores": {
                "grade": 0.90,
            }
        })
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]
        result = service._parse_llm_response(response, field_contexts)
        # Parsing succeeds even though the required "subject" was not inferred.
        assert result.success is True
        assert "grade" in result.inferred_metadata
        assert "subject" not in result.inferred_metadata
if __name__ == "__main__":
    # Allow running this file directly. pytest.main() returns the run's exit
    # code, so raise SystemExit with it to make a failing run exit non-zero.
    raise SystemExit(pytest.main([__file__, "-v"]))