"""
Tests for Slot Extraction Integration.
[AC-MRS-SLOT-EXTRACT-01] Slot extraction integration tests.
"""

from contextlib import contextmanager
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.models.mid.schemas import SlotSource
from app.services.mid.slot_extraction_integration import (
    ExtractionResult,
    ExtractionTrace,
    SlotExtractionIntegration,
    integrate_slot_extraction,
)
from app.services.mid.slot_strategy_executor import (
    StrategyChainResult,
    StrategyStepResult,
)


@pytest.fixture
def mock_session():
    """Provide an AsyncMock standing in for the database session."""
    return AsyncMock()


@pytest.fixture
def integration(mock_session):
    """Provide a SlotExtractionIntegration bound to a fixed tenant and session."""
    return SlotExtractionIntegration(
        session=mock_session,
        tenant_id="tenant_1",
        session_id="session_1",
    )


def _make_slot_def(ask_back_prompt=None):
    """Build a string-typed slot-definition mock with no validation rule."""
    slot_def = MagicMock()
    slot_def.type = "string"
    slot_def.validation_rule = None
    slot_def.ask_back_prompt = ask_back_prompt
    return slot_def


@contextmanager
def _patch_successful_extraction(integration, value, strategy="rule", slot_key="region"):
    """Patch the collaborators of ``integration`` so extraction of ``slot_key`` succeeds.

    - The strategy chain reports ``value`` as extracted by ``strategy``.
    - The backfill service accepts ``value`` unchanged.
    - ``_save_extracted_slots`` is replaced by an AsyncMock so nothing is persisted.
    """
    backfill = AsyncMock()
    backfill.backfill_single_slot = AsyncMock(
        return_value=MagicMock(
            is_success=lambda: True,
            normalized_value=value,
            error_message=None,
        )
    )
    with patch.object(integration._strategy_executor, "execute_chain") as mock_chain:
        mock_chain.return_value = StrategyChainResult(
            slot_key=slot_key,
            success=True,
            final_value=value,
            final_strategy=strategy,
        )
        with patch.object(integration, "_get_backfill_service") as mock_backfill_svc:
            mock_backfill_svc.return_value = backfill
            with patch.object(integration, "_save_extracted_slots", new_callable=AsyncMock):
                yield


class TestExtractionTrace:
    """Tests for ExtractionTrace."""

    def test_init(self):
        """A fresh trace has no strategy and has not passed validation."""
        trace = ExtractionTrace(slot_key="region")

        assert trace.slot_key == "region"
        assert trace.strategy is None
        assert trace.validation_passed is False

    def test_to_dict(self):
        """to_dict exposes the trace fields under their attribute names."""
        trace = ExtractionTrace(
            slot_key="region",
            strategy="rule",
            extracted_value="北京",
            validation_passed=True,
            final_value="北京",
            execution_time_ms=10.5,
        )

        d = trace.to_dict()

        assert d["slot_key"] == "region"
        assert d["strategy"] == "rule"
        assert d["extracted_value"] == "北京"
        assert d["validation_passed"] is True


class TestExtractionResult:
    """Tests for ExtractionResult."""

    def test_init(self):
        """A default result is unsuccessful with no extracted or failed slots."""
        result = ExtractionResult()

        assert result.success is False
        assert result.extracted_slots == {}
        assert result.failed_slots == []

    def test_to_dict(self):
        """to_dict exposes the result fields under their attribute names."""
        result = ExtractionResult(
            success=True,
            extracted_slots={"region": "北京"},
            failed_slots=["product"],
            traces=[ExtractionTrace(slot_key="region")],
            total_execution_time_ms=50.0,
            ask_back_triggered=True,
            ask_back_prompts=["请提供产品信息"],
        )

        d = result.to_dict()

        assert d["success"] is True
        assert d["extracted_slots"] == {"region": "北京"}
        assert d["failed_slots"] == ["product"]
        assert d["ask_back_triggered"] is True


class TestSlotExtractionIntegration:
    """Tests for SlotExtractionIntegration."""

    def test_default_strategies(self, integration):
        """The default strategy chain is rule first, then LLM."""
        assert integration.DEFAULT_STRATEGIES == ["rule", "llm"]

    def test_get_source_for_strategy(self, integration):
        """Each strategy name maps to its SlotSource value."""
        assert integration._get_source_for_strategy("rule") == SlotSource.RULE_EXTRACTED.value
        assert integration._get_source_for_strategy("llm") == SlotSource.LLM_INFERRED.value
        assert integration._get_source_for_strategy("user_input") == SlotSource.USER_CONFIRMED.value

    def test_get_confidence_for_source(self, integration):
        """Each slot source maps to its confidence score."""
        assert integration._get_confidence_for_source(SlotSource.USER_CONFIRMED.value) == 1.0
        assert integration._get_confidence_for_source(SlotSource.RULE_EXTRACTED.value) == 0.9
        assert integration._get_confidence_for_source(SlotSource.LLM_INFERRED.value) == 0.7

    @pytest.mark.asyncio
    async def test_extract_and_fill_no_target_slots(self, integration):
        """With no target slots, extraction succeeds and extracts nothing."""
        result = await integration.extract_and_fill(
            user_input="测试输入",
            target_slots=[],
        )

        assert result.success is True
        assert result.extracted_slots == {}

    @pytest.mark.asyncio
    async def test_extract_and_fill_slot_not_found(self, integration):
        """A missing slot definition marks the slot as failed with a clear reason."""
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=None
        )
        integration._slot_manager.get_ask_back_prompt = AsyncMock(return_value=None)

        result = await integration.extract_and_fill(
            user_input="测试输入",
            target_slots=["unknown_slot"],
        )

        assert result.success is False
        assert "unknown_slot" in result.failed_slots
        assert result.traces[0].failure_reason == "Slot definition not found"

    @pytest.mark.asyncio
    async def test_extract_and_fill_extraction_success(self, integration):
        """A successful strategy chain plus backfill yields the extracted value."""
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=_make_slot_def(ask_back_prompt="请提供地区")
        )

        with _patch_successful_extraction(integration, "北京"):
            result = await integration.extract_and_fill(
                user_input="我想查询北京的产品",
                target_slots=["region"],
            )

        assert result.success is True
        assert "region" in result.extracted_slots
        assert result.extracted_slots["region"] == "北京"

    @pytest.mark.asyncio
    async def test_extract_and_fill_extraction_failed(self, integration):
        """A failed strategy chain fails the slot and triggers an ask-back prompt."""
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=_make_slot_def(ask_back_prompt="请提供地区")
        )

        with patch.object(integration._strategy_executor, "execute_chain") as mock_chain:
            mock_chain.return_value = StrategyChainResult(
                slot_key="region",
                success=False,
                steps=[
                    StrategyStepResult(
                        strategy="rule",
                        success=False,
                        failure_reason="无法提取",
                    )
                ],
            )
            integration._slot_manager.get_ask_back_prompt = AsyncMock(
                return_value="请提供地区"
            )

            result = await integration.extract_and_fill(
                user_input="测试输入",
                target_slots=["region"],
            )

        assert result.success is False
        assert "region" in result.failed_slots
        assert result.ask_back_triggered is True
        assert "请提供地区" in result.ask_back_prompts

    @pytest.mark.asyncio
    async def test_get_missing_required_slots_from_state(self, integration):
        """Missing required slots are read from the provided slot state."""
        from app.services.mid.slot_state_aggregator import SlotState

        slot_state = SlotState()
        slot_state.missing_required_slots = [
            {"slot_key": "region"},
            {"slot_key": "product"},
        ]

        result = await integration._get_missing_required_slots(slot_state)

        assert "region" in result
        assert "product" in result

    @pytest.mark.asyncio
    async def test_get_missing_required_slots_from_db(self, integration):
        """Without a slot state, missing slots come from the slot definitions service."""
        mock_defs = [
            MagicMock(slot_key="region"),
            MagicMock(slot_key="product"),
        ]
        integration._slot_def_service.list_slot_definitions = AsyncMock(
            return_value=mock_defs
        )

        result = await integration._get_missing_required_slots(None)

        assert "region" in result
        assert "product" in result


class TestIntegrateSlotExtraction:
    """Tests for the integrate_slot_extraction convenience function."""

    @pytest.mark.asyncio
    async def test_integrate(self):
        """The convenience function builds an integration and delegates to extract_and_fill."""
        mock_session = AsyncMock()

        with patch(
            "app.services.mid.slot_extraction_integration.SlotExtractionIntegration"
        ) as mock_cls:
            mock_instance = AsyncMock()
            mock_instance.extract_and_fill = AsyncMock(
                return_value=ExtractionResult(success=True, extracted_slots={"region": "北京"})
            )
            mock_cls.return_value = mock_instance

            result = await integrate_slot_extraction(
                session=mock_session,
                tenant_id="tenant_1",
                session_id="session_1",
                user_input="我想查询北京的产品",
            )

        # Verify delegation explicitly, not just the returned value.
        mock_cls.assert_called_once()
        mock_instance.extract_and_fill.assert_awaited_once()
        assert result.success is True
        assert "region" in result.extracted_slots


class TestExtractionTraceFlow:
    """Tests for the trace recorded across a full extraction."""

    @pytest.mark.asyncio
    async def test_full_extraction_flow(self, integration):
        """A successful extraction records one trace carrying strategy, values and validation."""
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=_make_slot_def()
        )

        with _patch_successful_extraction(integration, "北京"):
            result = await integration.extract_and_fill(
                user_input="北京",
                target_slots=["region"],
            )

        assert len(result.traces) == 1
        trace = result.traces[0]
        assert trace.slot_key == "region"
        assert trace.strategy == "rule"
        assert trace.extracted_value == "北京"
        assert trace.validation_passed is True
        assert trace.final_value == "北京"