"""
Unit tests for MetadataAutoInferenceService.

The tests exercise the service's private helpers directly
(``_build_field_contexts``, ``_build_user_prompt``, ``_extract_json``,
``_parse_llm_response`` and ``_validate_field_value``) and never call an LLM;
the database session handed to the service is an ``AsyncMock``.
"""

import json
from dataclasses import dataclass
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.services.metadata_auto_inference_service import (
    AutoInferenceResult,
    InferenceFieldContext,
    MetadataAutoInferenceService,
)


@dataclass
class MockFieldDefinition:
    """Minimal stand-in for a field definition passed to the service."""

    field_key: str
    label: str
    type: str
    required: bool
    options: list[str] | None = None


# ---------------------------------------------------------------------------
# Shared fixtures and builders
# ---------------------------------------------------------------------------


@pytest.fixture
def mock_session():
    """Return a mocked session object for constructing the service."""
    return AsyncMock()


@pytest.fixture
def service(mock_session):
    """Return a service instance bound to the mocked session."""
    return MetadataAutoInferenceService(mock_session)


def _grade_field(*, required: bool = True) -> InferenceFieldContext:
    """Build the "grade" enum field context used by most tests.

    A fresh options list is created on every call so tests cannot leak
    mutations into each other.
    """
    return InferenceFieldContext(
        field_key="grade",
        label="年级",
        type="enum",
        required=required,
        options=["初一", "初二", "初三"],
    )


def _subject_field(*extra_options: str) -> InferenceFieldContext:
    """Build the "subject" enum field context.

    Args:
        *extra_options: Additional allowed values appended after the three
            default subjects.
    """
    return InferenceFieldContext(
        field_key="subject",
        label="学科",
        type="enum",
        required=True,
        options=["语文", "数学", "英语", *extra_options],
    )


def _llm_response(inferred_metadata: dict, confidence_scores: dict) -> str:
    """Serialize a fake LLM answer in the JSON shape the tests feed the parser."""
    return json.dumps(
        {
            "inferred_metadata": inferred_metadata,
            "confidence_scores": confidence_scores,
        }
    )


class TestInferenceFieldContext:
    """Tests for the InferenceFieldContext dataclass."""

    def test_creation(self):
        """All explicitly supplied attributes are stored as given."""
        ctx = InferenceFieldContext(
            field_key="grade",
            label="年级",
            type="enum",
            required=True,
            options=["初一", "初二", "初三"],
        )

        assert ctx.field_key == "grade"
        assert ctx.label == "年级"
        assert ctx.type == "enum"
        assert ctx.required is True
        assert ctx.options == ["初一", "初二", "初三"]

    def test_default_values(self):
        """``options`` and ``description`` default to None when omitted."""
        ctx = InferenceFieldContext(
            field_key="test",
            label="Test",
            type="text",
            required=False,
        )

        assert ctx.options is None
        assert ctx.description is None


class TestAutoInferenceResult:
    """Tests for the AutoInferenceResult dataclass."""

    def test_success_result(self):
        """A successful result carries its metadata and no error message."""
        result = AutoInferenceResult(
            inferred_metadata={"grade": "初一", "subject": "数学"},
            confidence_scores={"grade": 0.95, "subject": 0.85},
            raw_response='{"inferred_metadata": {...}}',
            success=True,
        )

        assert result.success is True
        assert result.error_message is None
        assert result.inferred_metadata["grade"] == "初一"

    def test_failure_result(self):
        """A failed result exposes the supplied error message."""
        result = AutoInferenceResult(
            inferred_metadata={},
            confidence_scores={},
            raw_response="",
            success=False,
            error_message="JSON parse error",
        )

        assert result.success is False
        assert result.error_message == "JSON parse error"


class TestMetadataAutoInferenceService:
    """Tests for MetadataAutoInferenceService helper methods."""

    def test_build_field_contexts(self, service):
        """Field definitions are converted to contexts in the same order."""
        fields = [
            MockFieldDefinition(
                field_key="grade",
                label="年级",
                type="enum",
                required=True,
                options=["初一", "初二", "初三"],
            ),
            MockFieldDefinition(
                field_key="subject",
                label="学科",
                type="enum",
                required=True,
                options=["语文", "数学", "英语"],
            ),
        ]

        contexts = service._build_field_contexts(fields)

        assert len(contexts) == 2
        assert contexts[0].field_key == "grade"
        assert contexts[0].options == ["初一", "初二", "初三"]
        assert contexts[1].field_key == "subject"

    def test_build_user_prompt(self, service):
        """The prompt contains the field label, its options and the content."""
        prompt = service._build_user_prompt(
            content="这是一道初一数学题",
            field_contexts=[_grade_field()],
        )

        assert "年级" in prompt
        assert "初一, 初二, 初三" in prompt
        assert "这是一道初一数学题" in prompt

    def test_build_user_prompt_with_existing_metadata(self, service):
        """Existing metadata values are surfaced in the prompt."""
        prompt = service._build_user_prompt(
            content="这是一道数学题",
            field_contexts=[_grade_field()],
            existing_metadata={"grade": "初二"},
        )

        assert "已有值: 初二" in prompt

    def test_extract_json_from_plain_json(self, service):
        """A bare JSON string is returned unchanged."""
        json_str = (
            '{"inferred_metadata": {"grade": "初一"}, '
            '"confidence_scores": {"grade": 0.95}}'
        )

        result = service._extract_json(json_str)

        assert result == json_str

    def test_extract_json_from_markdown(self, service):
        """JSON wrapped in a fenced markdown code block is still found."""
        # NOTE(review): the fence is laid out on its own lines, as a markdown
        # code block requires; confirm against the original fixture text.
        markdown = (
            "Here is the result:\n"
            "```json\n"
            '{"inferred_metadata": {"grade": "初一"}, '
            '"confidence_scores": {"grade": 0.95}}\n'
            "```\n"
        )

        result = service._extract_json(markdown)

        assert "inferred_metadata" in result
        assert "grade" in result

    def test_parse_llm_response_valid(self, service):
        """A well-formed response yields the metadata and confidence scores."""
        response = _llm_response(
            {"grade": "初一", "subject": "数学"},
            {"grade": 0.95, "subject": 0.85},
        )
        field_contexts = [_grade_field(), _subject_field()]

        result = service._parse_llm_response(response, field_contexts)

        assert result.success is True
        assert result.inferred_metadata["grade"] == "初一"
        assert result.inferred_metadata["subject"] == "数学"
        assert result.confidence_scores["grade"] == pytest.approx(0.95)

    def test_parse_llm_response_invalid_option(self, service):
        """A value outside the enum options is dropped; parsing still succeeds."""
        response = _llm_response(
            {"grade": "高一"},  # Not in options
            {"grade": 0.90},
        )

        result = service._parse_llm_response(response, [_grade_field()])

        assert result.success is True
        assert "grade" not in result.inferred_metadata

    def test_parse_llm_response_invalid_json(self, service):
        """Non-JSON text produces a failed result with a parse error message."""
        response = "This is not valid JSON"
        field_contexts = [
            InferenceFieldContext(
                field_key="grade",
                label="年级",
                type="text",
                required=False,
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        assert result.success is False
        # Check for None first so a missing message fails as an assertion
        # instead of a TypeError from the ``in`` operator.
        assert result.error_message is not None
        assert "JSON parse error" in result.error_message

    def test_validate_field_value_text(self, service):
        """Text values pass through validation unchanged."""
        ctx = InferenceFieldContext(
            field_key="title",
            label="标题",
            type="text",
            required=False,
        )

        assert service._validate_field_value(ctx, "测试标题") == "测试标题"

    def test_validate_field_value_number(self, service):
        """Numbers are kept, numeric strings converted, garbage rejected."""
        ctx = InferenceFieldContext(
            field_key="count",
            label="数量",
            type="number",
            required=False,
        )

        assert service._validate_field_value(ctx, 42) == 42
        assert service._validate_field_value(ctx, "3.14") == pytest.approx(3.14)
        assert service._validate_field_value(ctx, "invalid") is None

    def test_validate_field_value_boolean(self, service):
        """Booleans, "true"/"false" strings and 1 map to real booleans."""
        ctx = InferenceFieldContext(
            field_key="active",
            label="是否激活",
            type="boolean",
            required=False,
        )

        assert service._validate_field_value(ctx, True) is True
        assert service._validate_field_value(ctx, "true") is True
        assert service._validate_field_value(ctx, "false") is False
        assert service._validate_field_value(ctx, 1) is True

    def test_validate_field_value_enum(self, service):
        """Enum values must be one of the declared options."""
        ctx = _grade_field(required=False)

        assert service._validate_field_value(ctx, "初一") == "初一"
        assert service._validate_field_value(ctx, "高一") is None

    def test_validate_field_value_array_enum(self, service):
        """Array enums keep valid items, drop unknown ones, wrap scalars."""
        ctx = InferenceFieldContext(
            field_key="tags",
            label="标签",
            type="array_enum",
            required=False,
            options=["重点", "难点", "易错"],
        )

        assert service._validate_field_value(ctx, ["重点", "难点"]) == ["重点", "难点"]
        assert service._validate_field_value(ctx, ["重点", "不存在"]) == ["重点"]
        assert service._validate_field_value(ctx, "重点") == ["重点"]


class TestIntegrationScenarios:
    """Scenario-style tests combining several fields in one parsed response."""

    def test_education_scenario(self, service):
        """Grade, subject and type are all inferred from one response."""
        response = _llm_response(
            {"grade": "初二", "subject": "物理", "type": "痛点"},
            {"grade": 0.95, "subject": 0.90, "type": 0.85},
        )
        field_contexts = [
            _grade_field(),
            _subject_field("物理", "化学"),
            InferenceFieldContext(
                field_key="type",
                label="类型",
                type="enum",
                required=False,
                options=["痛点", "重点", "难点"],
            ),
        ]

        result = service._parse_llm_response(response, field_contexts)

        assert result.success is True
        assert result.inferred_metadata == {
            "grade": "初二",
            "subject": "物理",
            "type": "痛点",
        }
        assert result.confidence_scores["grade"] == pytest.approx(0.95)

    def test_partial_inference(self, service):
        """Fields missing from the response are absent from the result."""
        response = _llm_response({"grade": "初一"}, {"grade": 0.90})
        field_contexts = [_grade_field(), _subject_field()]

        result = service._parse_llm_response(response, field_contexts)

        assert result.success is True
        assert "grade" in result.inferred_metadata
        assert "subject" not in result.inferred_metadata


if __name__ == "__main__":
    # Propagate pytest's exit code so a direct run fails loudly in scripts/CI.
    raise SystemExit(pytest.main([__file__, "-v"]))