# ai-robot-core/ai-service/tests/test_metadata_auto_inferenc...
#
# 444 lines
# 14 KiB
# Python

"""
Unit tests for MetadataAutoInferenceService.
"""
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from dataclasses import dataclass
from app.services.metadata_auto_inference_service import (
AutoInferenceResult,
InferenceFieldContext,
MetadataAutoInferenceService,
)
@dataclass
class MockFieldDefinition:
"""Mock field definition for testing"""
field_key: str
label: str
type: str
required: bool
options: list[str] | None = None
class TestInferenceFieldContext:
    """Tests for constructing ``InferenceFieldContext`` instances."""

    def test_creation(self):
        """Every explicitly supplied attribute is stored unchanged."""
        ctx = InferenceFieldContext(
            field_key="grade",
            label="年级",
            type="enum",
            required=True,
            options=["初一", "初二", "初三"],
        )

        assert ctx.field_key == "grade"
        assert ctx.label == "年级"
        assert ctx.type == "enum"
        assert ctx.required is True
        assert ctx.options == ["初一", "初二", "初三"]

    def test_default_values(self):
        """``options`` and ``description`` are ``None`` when omitted."""
        ctx = InferenceFieldContext(
            field_key="test",
            label="Test",
            type="text",
            required=False,
        )

        assert ctx.options is None
        assert ctx.description is None
class TestAutoInferenceResult:
    """Tests for constructing ``AutoInferenceResult`` instances."""

    def test_success_result(self):
        """A successful result keeps its metadata and has no error message."""
        result = AutoInferenceResult(
            inferred_metadata={"grade": "初一", "subject": "数学"},
            confidence_scores={"grade": 0.95, "subject": 0.85},
            raw_response='{"inferred_metadata": {...}}',
            success=True,
        )

        assert result.success is True
        # error_message was not passed, so its default must be None.
        assert result.error_message is None
        assert result.inferred_metadata["grade"] == "初一"

    def test_failure_result(self):
        """A failed result exposes the supplied error message."""
        result = AutoInferenceResult(
            inferred_metadata={},
            confidence_scores={},
            raw_response="",
            success=False,
            error_message="JSON parse error",
        )

        assert result.success is False
        assert result.error_message == "JSON parse error"
class TestMetadataAutoInferenceService:
    """Unit tests for the private helpers of ``MetadataAutoInferenceService``.

    Each test calls one helper directly (context building, prompt building,
    JSON extraction, response parsing, value validation) on a service that
    was constructed with a mocked session.
    """

    @pytest.fixture
    def mock_session(self):
        """Return an ``AsyncMock`` used in place of the service's session."""
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_session):
        """Return a service instance constructed with the mocked session."""
        return MetadataAutoInferenceService(mock_session)

    def test_build_field_contexts(self, service):
        """Field definitions are converted to contexts in the same order."""
        fields = [
            MockFieldDefinition(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            MockFieldDefinition(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]

        contexts = service._build_field_contexts(fields)

        assert len(contexts) == 2
        assert contexts[0].field_key == "grade"
        assert contexts[0].options == ["初一", "初二", "初三"]
        assert contexts[1].field_key == "subject"

    def test_build_user_prompt(self, service):
        """The prompt contains the label, the joined options and the content."""
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]

        prompt = service._build_user_prompt(
            content="这是一道初一数学题",
            field_contexts=field_contexts,
        )

        assert "年级" in prompt
        # Options are expected to be rendered comma-and-space separated.
        assert "初一, 初二, 初三" in prompt
        assert "这是一道初一数学题" in prompt

    def test_build_user_prompt_with_existing_metadata(self, service):
        """An existing metadata value is surfaced in the prompt."""
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]

        prompt = service._build_user_prompt(
            content="这是一道数学题",
            field_contexts=field_contexts,
            existing_metadata={"grade": "初二"},
        )

        assert "已有值: 初二" in prompt

    def test_extract_json_from_plain_json(self, service):
        """A response that is already plain JSON is returned unchanged."""
        json_str = '{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}'

        result = service._extract_json(json_str)

        assert result == json_str

    def test_extract_json_from_markdown(self, service):
        """JSON wrapped in a fenced markdown code block can be extracted."""
        # Spelled out with explicit newlines so that re-indenting this test
        # can never change the fixture's content.
        markdown = (
            "Here is the result:\n"
            "```json\n"
            '{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}\n'
            "```\n"
        )

        result = service._extract_json(markdown)

        assert "inferred_metadata" in result
        assert "grade" in result

    def test_parse_llm_response_valid(self, service):
        """Valid values and their confidence scores are kept."""
        response = json.dumps({
            "inferred_metadata": {
                "grade": "初一",
                "subject": "数学",
            },
            "confidence_scores": {
                "grade": 0.95,
                "subject": 0.85,
            }
        })
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        assert result.success is True
        assert result.inferred_metadata["grade"] == "初一"
        assert result.inferred_metadata["subject"] == "数学"
        assert result.confidence_scores["grade"] == 0.95

    def test_parse_llm_response_invalid_option(self, service):
        """An enum value outside the options is dropped; parsing still succeeds."""
        response = json.dumps({
            "inferred_metadata": {
                "grade": "高一",  # Not in options
            },
            "confidence_scores": {
                "grade": 0.90,
            }
        })
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        assert result.success is True
        assert "grade" not in result.inferred_metadata

    def test_parse_llm_response_invalid_json(self, service):
        """A non-JSON response yields a failed result with a parse error."""
        response = "This is not valid JSON"
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="text",
                required=False,
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        assert result.success is False
        # Check for None first so a missing message fails as an assertion
        # instead of a TypeError from the ``in`` operator.
        assert result.error_message is not None
        assert "JSON parse error" in result.error_message

    def test_validate_field_value_text(self, service):
        """A text value is returned unchanged."""
        ctx = InferenceFieldContext(
            field_key="title",
            label="标题",
            type="text",
            required=False,
        )

        result = service._validate_field_value(ctx, "测试标题")

        assert result == "测试标题"

    def test_validate_field_value_number(self, service):
        """Numbers pass, numeric strings are converted, others give None."""
        ctx = InferenceFieldContext(
            field_key="count",
            label="数量",
            type="number",
            required=False,
        )

        assert service._validate_field_value(ctx, 42) == 42
        assert service._validate_field_value(ctx, "3.14") == 3.14
        assert service._validate_field_value(ctx, "invalid") is None

    def test_validate_field_value_boolean(self, service):
        """Booleans, "true"/"false" strings and 1 map to real booleans."""
        ctx = InferenceFieldContext(
            field_key="active",
            label="是否激活",
            type="boolean",
            required=False,
        )

        assert service._validate_field_value(ctx, True) is True
        assert service._validate_field_value(ctx, "true") is True
        assert service._validate_field_value(ctx, "false") is False
        assert service._validate_field_value(ctx, 1) is True

    def test_validate_field_value_enum(self, service):
        """An enum value is kept only when it is one of the options."""
        ctx = InferenceFieldContext(
            field_key="grade",
            label="年级",
            type="enum",
            required=False,
            options=["初一", "初二", "初三"],
        )

        assert service._validate_field_value(ctx, "初一") == "初一"
        assert service._validate_field_value(ctx, "高一") is None

    def test_validate_field_value_array_enum(self, service):
        """Array-enum values are filtered to options; a scalar is wrapped."""
        ctx = InferenceFieldContext(
            field_key="tags",
            label="标签",
            type="array_enum",
            required=False,
            options=["重点", "难点", "易错"],
        )

        # All members valid: list is returned as is.
        result = service._validate_field_value(ctx, ["重点", "难点"])
        assert result == ["重点", "难点"]

        # Unknown members are removed, valid ones are kept.
        result = service._validate_field_value(ctx, ["重点", "不存在"])
        assert result == ["重点"]

        # A single string is expected back as a one-element list.
        result = service._validate_field_value(ctx, "重点")
        assert result == ["重点"]
class TestIntegrationScenarios:
    """Scenario-style tests of ``_parse_llm_response`` with several fields."""

    @pytest.fixture
    def mock_session(self):
        """Return an ``AsyncMock`` used in place of the service's session."""
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_session):
        """Return a service instance constructed with the mocked session."""
        return MetadataAutoInferenceService(mock_session)

    def test_education_scenario(self, service):
        """Grade, subject and type are all accepted when each is a valid option."""
        response = json.dumps({
            "inferred_metadata": {
                "grade": "初二",
                "subject": "物理",
                "type": "痛点",
            },
            "confidence_scores": {
                "grade": 0.95,
                "subject": 0.90,
                "type": 0.85,
            }
        })
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语", "物理", "化学"],
            ),
            InferenceFieldContext(
                field_key="type",
                label="类型",
                type="enum",
                required=False,
                options=["痛点", "重点", "难点"],
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        assert result.success is True
        assert result.inferred_metadata == {
            "grade": "初二",
            "subject": "物理",
            "type": "痛点",
        }
        assert result.confidence_scores["grade"] == 0.95

    def test_partial_inference(self, service):
        """Fields missing from the response are absent from the result."""
        response = json.dumps({
            "inferred_metadata": {
                "grade": "初一",
            },
            "confidence_scores": {
                "grade": 0.90,
            }
        })
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        # Parsing succeeds even though the required "subject" was not inferred.
        assert result.success is True
        assert "grade" in result.inferred_metadata
        assert "subject" not in result.inferred_metadata
if __name__ == "__main__":
    # Allow running this file directly. pytest.main returns the run's exit
    # code; forward it so a failing run is visible to the shell or CI
    # instead of the process always exiting with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))