"""
Unit tests for MetadataAutoInferenceService.
"""
|
||
|
|
|
||
|
|
import json
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
|
from dataclasses import dataclass
|
||
|
|
|
||
|
|
from app.services.metadata_auto_inference_service import (
|
||
|
|
AutoInferenceResult,
|
||
|
|
InferenceFieldContext,
|
||
|
|
MetadataAutoInferenceService,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
|
||
|
|
class MockFieldDefinition:
|
||
|
|
"""Mock field definition for testing"""
|
||
|
|
field_key: str
|
||
|
|
label: str
|
||
|
|
type: str
|
||
|
|
required: bool
|
||
|
|
options: list[str] | None = None
|
||
|
|
|
||
|
|
|
||
|
|
class TestInferenceFieldContext:
    """Tests for the InferenceFieldContext dataclass."""

    def test_creation(self):
        """Every constructor argument is stored on the instance unchanged."""
        context = InferenceFieldContext(
            field_key="grade",
            label="年级",
            type="enum",
            required=True,
            options=["初一", "初二", "初三"],
        )

        expected_attrs = {
            "field_key": "grade",
            "label": "年级",
            "type": "enum",
        }
        for attr_name, expected in expected_attrs.items():
            assert getattr(context, attr_name) == expected
        assert context.required is True
        assert context.options == ["初一", "初二", "初三"]

    def test_default_values(self):
        """Optional attributes fall back to None when they are omitted."""
        context = InferenceFieldContext(
            field_key="test", label="Test", type="text", required=False
        )

        for optional_attr in ("options", "description"):
            assert getattr(context, optional_attr) is None
|
||
|
|
|
||
|
|
|
||
|
|
class TestAutoInferenceResult:
    """Tests for the AutoInferenceResult dataclass."""

    def test_success_result(self):
        """A successful result carries metadata and leaves error_message unset."""
        metadata = {"grade": "初一", "subject": "数学"}
        scores = {"grade": 0.95, "subject": 0.85}
        outcome = AutoInferenceResult(
            inferred_metadata=metadata,
            confidence_scores=scores,
            raw_response='{"inferred_metadata": {...}}',
            success=True,
        )

        assert outcome.success is True
        assert outcome.error_message is None
        assert outcome.inferred_metadata["grade"] == "初一"

    def test_failure_result(self):
        """A failed result reports success=False and keeps the error text."""
        outcome = AutoInferenceResult(
            success=False,
            error_message="JSON parse error",
            inferred_metadata={},
            confidence_scores={},
            raw_response="",
        )

        assert outcome.success is False
        assert outcome.error_message == "JSON parse error"
|
||
|
|
|
||
|
|
|
||
|
|
class TestMetadataAutoInferenceService:
    """Test MetadataAutoInferenceService functionality.

    Exercises the service's private helpers directly: field-context building,
    prompt construction, JSON extraction, LLM-response parsing and per-type
    value validation.
    """

    @pytest.fixture
    def mock_session(self):
        """Create mock session."""
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_session):
        """Create service instance."""
        return MetadataAutoInferenceService(mock_session)

    def test_build_field_contexts(self, service):
        """Test building field contexts from definitions."""
        fields = [
            MockFieldDefinition(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            MockFieldDefinition(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]

        contexts = service._build_field_contexts(fields)

        assert len(contexts) == 2
        assert contexts[0].field_key == "grade"
        assert contexts[0].options == ["初一", "初二", "初三"]
        assert contexts[1].field_key == "subject"

    def test_build_user_prompt(self, service):
        """Test building user prompt."""
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]

        prompt = service._build_user_prompt(
            content="这是一道初一数学题",
            field_contexts=field_contexts,
        )

        assert "年级" in prompt
        assert "初一, 初二, 初三" in prompt
        assert "这是一道初一数学题" in prompt

    def test_build_user_prompt_with_existing_metadata(self, service):
        """Test building user prompt with existing metadata."""
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]

        prompt = service._build_user_prompt(
            content="这是一道数学题",
            field_contexts=field_contexts,
            existing_metadata={"grade": "初二"},
        )

        assert "已有值: 初二" in prompt

    def test_extract_json_from_plain_json(self, service):
        """Test extracting JSON from plain JSON response."""
        json_str = '{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}'

        result = service._extract_json(json_str)

        assert result == json_str

    def test_extract_json_from_markdown(self, service):
        """Test extracting JSON from markdown code block.

        Substring checks alone would also pass if the surrounding prose and
        the ``` fences were returned unchanged, so the extracted text is
        additionally parsed and compared structurally.
        """
        markdown = """Here is the result:

```json
{"inferred_metadata": {"grade": "初一"}, "confidence_scores": {"grade": 0.95}}
```
"""

        result = service._extract_json(markdown)

        assert "inferred_metadata" in result
        assert "grade" in result
        # json.loads raises if any prose or fence markers leaked into result.
        assert json.loads(result) == {
            "inferred_metadata": {"grade": "初一"},
            "confidence_scores": {"grade": 0.95},
        }

    def test_parse_llm_response_valid(self, service):
        """Test parsing valid LLM response."""
        response = json.dumps({
            "inferred_metadata": {
                "grade": "初一",
                "subject": "数学",
            },
            "confidence_scores": {
                "grade": 0.95,
                "subject": 0.85,
            },
        })

        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            InferenceFieldContext(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        assert result.success is True
        assert result.inferred_metadata["grade"] == "初一"
        assert result.inferred_metadata["subject"] == "数学"
        assert result.confidence_scores["grade"] == 0.95

    def test_parse_llm_response_invalid_option(self, service):
        """Test parsing response with invalid enum option."""
        response = json.dumps({
            "inferred_metadata": {
                "grade": "高一",  # Not in options
            },
            "confidence_scores": {
                "grade": 0.90,
            },
        })

        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        # Out-of-range enum values are dropped, not treated as a parse failure.
        assert result.success is True
        assert "grade" not in result.inferred_metadata

    def test_parse_llm_response_invalid_json(self, service):
        """Test parsing invalid JSON response."""
        response = "This is not valid JSON"

        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="text",
                required=False,
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        assert result.success is False
        # Check for None first so a missing message fails clearly rather than
        # raising TypeError from the `in` test below.
        assert result.error_message is not None
        assert "JSON parse error" in result.error_message

    def test_validate_field_value_text(self, service):
        """Test validating text field value."""
        ctx = InferenceFieldContext(
            field_key="title",
            label="标题",
            type="text",
            required=False,
        )

        result = service._validate_field_value(ctx, "测试标题")

        assert result == "测试标题"

    def test_validate_field_value_number(self, service):
        """Test validating number field value."""
        ctx = InferenceFieldContext(
            field_key="count",
            label="数量",
            type="number",
            required=False,
        )

        assert service._validate_field_value(ctx, 42) == 42
        assert service._validate_field_value(ctx, "3.14") == 3.14
        assert service._validate_field_value(ctx, "invalid") is None

    def test_validate_field_value_boolean(self, service):
        """Test validating boolean field value."""
        ctx = InferenceFieldContext(
            field_key="active",
            label="是否激活",
            type="boolean",
            required=False,
        )

        assert service._validate_field_value(ctx, True) is True
        assert service._validate_field_value(ctx, "true") is True
        assert service._validate_field_value(ctx, "false") is False
        assert service._validate_field_value(ctx, 1) is True

    def test_validate_field_value_enum(self, service):
        """Test validating enum field value."""
        ctx = InferenceFieldContext(
            field_key="grade",
            label="年级",
            type="enum",
            required=False,
            options=["初一", "初二", "初三"],
        )

        assert service._validate_field_value(ctx, "初一") == "初一"
        assert service._validate_field_value(ctx, "高一") is None

    def test_validate_field_value_array_enum(self, service):
        """Test validating array_enum field value."""
        ctx = InferenceFieldContext(
            field_key="tags",
            label="标签",
            type="array_enum",
            required=False,
            options=["重点", "难点", "易错"],
        )

        result = service._validate_field_value(ctx, ["重点", "难点"])
        assert result == ["重点", "难点"]

        # Unknown members are filtered out rather than rejecting the list.
        result = service._validate_field_value(ctx, ["重点", "不存在"])
        assert result == ["重点"]

        # A bare scalar is wrapped into a single-element list.
        result = service._validate_field_value(ctx, "重点")
        assert result == ["重点"]
|
||
|
|
|
||
|
|
|
||
|
|
class TestIntegrationScenarios:
    """End-to-end parsing scenarios built from realistic LLM replies."""

    @pytest.fixture
    def mock_session(self):
        """Provide an async mock standing in for the database session."""
        return AsyncMock()

    @pytest.fixture
    def service(self, mock_session):
        """Build the service under test on top of the mocked session."""
        return MetadataAutoInferenceService(mock_session)

    @staticmethod
    def _llm_payload(metadata, scores):
        """Serialize metadata and confidence scores as the LLM would reply."""
        return json.dumps(
            {"inferred_metadata": metadata, "confidence_scores": scores}
        )

    @staticmethod
    def _enum_field(field_key, label, required, options):
        """Shorthand for an enum-typed InferenceFieldContext."""
        return InferenceFieldContext(
            field_key=field_key,
            label=label,
            type="enum",
            required=required,
            options=options,
        )

    def test_education_scenario(self, service):
        """Grade, subject and type are all inferred from a full reply."""
        reply = self._llm_payload(
            {"grade": "初二", "subject": "物理", "type": "痛点"},
            {"grade": 0.95, "subject": 0.90, "type": 0.85},
        )
        contexts = [
            self._enum_field("grade", "年级", True, ["初一", "初二", "初三"]),
            self._enum_field(
                "subject", "学科", True, ["语文", "数学", "英语", "物理", "化学"]
            ),
            self._enum_field("type", "类型", False, ["痛点", "重点", "难点"]),
        ]

        parsed = service._parse_llm_response(reply, contexts)

        assert parsed.success is True
        assert parsed.inferred_metadata == {
            "grade": "初二",
            "subject": "物理",
            "type": "痛点",
        }
        assert parsed.confidence_scores["grade"] == 0.95

    def test_partial_inference(self, service):
        """Fields the LLM omits are simply absent from the parsed metadata."""
        reply = self._llm_payload({"grade": "初一"}, {"grade": 0.90})
        contexts = [
            self._enum_field("grade", "年级", True, ["初一", "初二", "初三"]),
            self._enum_field("subject", "学科", True, ["语文", "数学", "英语"]),
        ]

        parsed = service._parse_llm_response(reply, contexts)

        assert parsed.success is True
        assert "grade" in parsed.inferred_metadata
        assert "subject" not in parsed.inferred_metadata
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
    # pytest.main returns the session exit code; raise it so running this
    # file directly exits non-zero when any test fails.
    raise SystemExit(pytest.main([__file__, "-v"]))
|