"""
Tests for Slot Backfill Service.

[AC-MRS-SLOT-BACKFILL-01] Slot value backfill confirmation tests.
"""

from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from app.models.mid.schemas import SlotSource
from app.services.mid.slot_backfill_service import (
    BackfillResult,
    BackfillStatus,
    BatchBackfillResult,
    SlotBackfillService,
    create_slot_backfill_service,
)
from app.services.mid.slot_manager import SlotWriteResult
from app.services.mid.slot_strategy_executor import (
    StrategyChainResult,
    StrategyStepResult,
)


def _patch_state_aggregator(service, stubbed_method):
    """Build a patcher replacing ``service._get_state_aggregator`` with a mock.

    The mocked getter returns an ``AsyncMock`` aggregator on which
    ``stubbed_method`` is explicitly set to its own ``AsyncMock``.
    Use the returned patcher as a context manager.
    """
    aggregator = AsyncMock()
    setattr(aggregator, stubbed_method, AsyncMock())
    return patch.object(
        service,
        "_get_state_aggregator",
        return_value=aggregator,
    )


def _make_region_slot_definition():
    """Build a mock slot definition: string type, no validation rule."""
    slot_def = MagicMock()
    slot_def.type = "string"
    slot_def.validation_rule = None
    slot_def.ask_back_prompt = "请提供地区"
    return slot_def


class TestBackfillResult:
    """Tests for BackfillResult."""

    def test_is_success(self):
        """is_success() is True for SUCCESS and False for VALIDATION_FAILED."""
        expectations = (
            (BackfillStatus.SUCCESS, True),
            (BackfillStatus.VALIDATION_FAILED, False),
        )
        for status, expected in expectations:
            assert BackfillResult(status=status).is_success() is expected

    def test_needs_ask_back(self):
        """needs_ask_back() is True for both failure statuses, not SUCCESS."""
        expectations = (
            (BackfillStatus.VALIDATION_FAILED, True),
            (BackfillStatus.EXTRACTION_FAILED, True),
            (BackfillStatus.SUCCESS, False),
        )
        for status, expected in expectations:
            assert BackfillResult(status=status).needs_ask_back() is expected

    def test_needs_confirmation(self):
        """needs_confirmation() is True only for NEEDS_CONFIRMATION here."""
        expectations = (
            (BackfillStatus.NEEDS_CONFIRMATION, True),
            (BackfillStatus.SUCCESS, False),
        )
        for status, expected in expectations:
            outcome = BackfillResult(status=status)
            assert outcome.needs_confirmation() is expected

    def test_to_dict(self):
        """to_dict() exposes status, slot_key, value and source."""
        as_dict = BackfillResult(
            status=BackfillStatus.SUCCESS,
            slot_key="region",
            value="北京",
            normalized_value="北京",
            source="user_confirmed",
            confidence=1.0,
        ).to_dict()

        assert as_dict["status"] == "success"
        assert as_dict["slot_key"] == "region"
        assert as_dict["value"] == "北京"
        assert as_dict["source"] == "user_confirmed"


class TestBatchBackfillResult:
    """Tests for BatchBackfillResult."""

    def test_add_result(self):
        """add_result() updates the per-status counters."""
        entries = (
            (BackfillStatus.SUCCESS, "region"),
            (BackfillStatus.VALIDATION_FAILED, "product"),
            (BackfillStatus.NEEDS_CONFIRMATION, "grade"),
        )
        batch = BatchBackfillResult()
        for status, slot_key in entries:
            batch.add_result(BackfillResult(status=status, slot_key=slot_key))

        assert batch.success_count == 1
        assert batch.failed_count == 1
        assert batch.confirmation_needed_count == 1

    def test_get_ask_back_prompts(self):
        """get_ask_back_prompts() returns the prompts of the failed results."""
        queued = [
            BackfillResult(
                status=BackfillStatus.VALIDATION_FAILED,
                ask_back_prompt="请重新输入",
            ),
            BackfillResult(status=BackfillStatus.SUCCESS),
            BackfillResult(
                status=BackfillStatus.EXTRACTION_FAILED,
                ask_back_prompt="无法识别,请重试",
            ),
        ]
        batch = BatchBackfillResult()
        for item in queued:
            batch.add_result(item)

        prompts = batch.get_ask_back_prompts()

        assert len(prompts) == 2
        assert "请重新输入" in prompts
        assert "无法识别,请重试" in prompts

    def test_get_confirmation_prompts(self):
        """get_confirmation_prompts() returns the pending confirmation prompt."""
        queued = [
            BackfillResult(
                status=BackfillStatus.NEEDS_CONFIRMATION,
                confirmation_prompt="我理解您说的是「北京」,对吗?",
            ),
            BackfillResult(status=BackfillStatus.SUCCESS),
        ]
        batch = BatchBackfillResult()
        for item in queued:
            batch.add_result(item)

        prompts = batch.get_confirmation_prompts()

        assert len(prompts) == 1
        assert "北京" in prompts[0]


class TestSlotBackfillService:
    """Tests for SlotBackfillService."""

    @pytest.fixture
    def mock_session(self):
        """Provide a mock session."""
        return AsyncMock()

    @pytest.fixture
    def mock_slot_manager(self):
        """Provide a mock slot manager with async write/ask-back methods."""
        return MagicMock(
            write_slot=AsyncMock(),
            get_ask_back_prompt=AsyncMock(return_value="请提供信息"),
        )

    @pytest.fixture
    def service(self, mock_session, mock_slot_manager):
        """Provide a service instance wired to the mock collaborators."""
        return SlotBackfillService(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
            slot_manager=mock_slot_manager,
        )

    def test_confidence_thresholds(self, service):
        """The low/high confidence thresholds are 0.5 and 0.8."""
        low = service.CONFIDENCE_THRESHOLD_LOW
        high = service.CONFIDENCE_THRESHOLD_HIGH
        assert low == 0.5
        assert high == 0.8

    def test_get_source_for_strategy(self, service):
        """Strategy names map to the expected slot source values."""
        expected_sources = (
            ("rule", SlotSource.RULE_EXTRACTED.value),
            ("llm", SlotSource.LLM_INFERRED.value),
            ("user_input", SlotSource.USER_CONFIRMED.value),
            ("unknown", "unknown"),
        )
        for strategy, source in expected_sources:
            assert service._get_source_for_strategy(strategy) == source

    def test_get_confidence_for_strategy(self, service):
        """Slot sources map to the expected confidence values."""
        expected_confidences = (
            (SlotSource.USER_CONFIRMED.value, 1.0),
            (SlotSource.RULE_EXTRACTED.value, 0.9),
            (SlotSource.LLM_INFERRED.value, 0.7),
            ("context", 0.5),
            (SlotSource.DEFAULT.value, 0.3),
        )
        for source, confidence in expected_confidences:
            assert service._get_confidence_for_strategy(source) == confidence

    def test_generate_confirmation_prompt(self, service):
        """The confirmation prompt contains the value and the question."""
        prompt = service._generate_confirmation_prompt("region", "北京")
        for fragment in ("北京", "对吗"):
            assert fragment in prompt

    @pytest.mark.asyncio
    async def test_backfill_single_slot_success(self, service, mock_slot_manager):
        """A successful slot write yields a SUCCESS backfill result."""
        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=True,
            slot_key="region",
            value="北京",
        )

        with _patch_state_aggregator(service, "update_slot"):
            outcome = await service.backfill_single_slot(
                slot_key="region",
                candidate_value="北京",
                source="user_confirmed",
                confidence=1.0,
            )

        assert outcome.status == BackfillStatus.SUCCESS
        assert outcome.slot_key == "region"
        assert outcome.normalized_value == "北京"

    @pytest.mark.asyncio
    async def test_backfill_single_slot_validation_failed(self, service, mock_slot_manager):
        """A failed slot write yields VALIDATION_FAILED with its prompt."""
        from app.services.mid.slot_validation_service import SlotValidationError

        validation_error = SlotValidationError(
            slot_key="region",
            error_code="INVALID_VALUE",
            error_message="无效的地区",
        )
        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=False,
            slot_key="region",
            error=validation_error,
            ask_back_prompt="请提供有效的地区",
        )

        outcome = await service.backfill_single_slot(
            slot_key="region",
            candidate_value="无效地区",
            source="user_confirmed",
            confidence=1.0,
        )

        assert outcome.status == BackfillStatus.VALIDATION_FAILED
        assert outcome.ask_back_prompt == "请提供有效的地区"

    @pytest.mark.asyncio
    async def test_backfill_single_slot_low_confidence(self, service, mock_slot_manager):
        """A confidence of 0.4 yields NEEDS_CONFIRMATION with a prompt."""
        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=True,
            slot_key="region",
            value="北京",
        )

        with _patch_state_aggregator(service, "update_slot"):
            outcome = await service.backfill_single_slot(
                slot_key="region",
                candidate_value="北京",
                source="llm_inferred",
                confidence=0.4,
            )

        assert outcome.status == BackfillStatus.NEEDS_CONFIRMATION
        assert outcome.confirmation_prompt is not None
        assert "北京" in outcome.confirmation_prompt

    @pytest.mark.asyncio
    async def test_backfill_multiple_slots(self, service, mock_slot_manager):
        """Batch backfill counts two successes and one failure."""
        # One write result per candidate, consumed in call order.
        mock_slot_manager.write_slot.side_effect = [
            SlotWriteResult(success=True, slot_key="region", value="北京"),
            SlotWriteResult(success=True, slot_key="product", value="手机"),
            SlotWriteResult(success=False, slot_key="grade", error=MagicMock()),
        ]
        candidates = {
            "region": "北京",
            "product": "手机",
            "grade": "无效等级",
        }

        with _patch_state_aggregator(service, "update_slot"):
            batch = await service.backfill_multiple_slots(
                candidates=candidates,
                source="user_confirmed",
            )

        assert batch.success_count == 2
        assert batch.failed_count == 1

    @pytest.mark.asyncio
    async def test_confirm_low_confidence_slot_confirmed(self, service):
        """User confirmation yields SUCCESS, USER_CONFIRMED, confidence 1.0."""
        with _patch_state_aggregator(service, "update_slot"):
            outcome = await service.confirm_low_confidence_slot(
                slot_key="region",
                confirmed=True,
            )

        assert outcome.status == BackfillStatus.SUCCESS
        assert outcome.source == SlotSource.USER_CONFIRMED.value
        assert outcome.confidence == 1.0

    @pytest.mark.asyncio
    async def test_confirm_low_confidence_slot_rejected(self, service, mock_slot_manager):
        """User rejection yields VALIDATION_FAILED with an ask-back prompt."""
        with _patch_state_aggregator(service, "clear_slot"):
            outcome = await service.confirm_low_confidence_slot(
                slot_key="region",
                confirmed=False,
            )

        assert outcome.status == BackfillStatus.VALIDATION_FAILED
        assert outcome.ask_back_prompt is not None


class TestCreateSlotBackfillService:
    """Tests for the create_slot_backfill_service factory function."""

    def test_create(self):
        """The factory returns a service carrying the given identifiers."""
        created = create_slot_backfill_service(
            session=AsyncMock(),
            tenant_id="tenant_1",
            session_id="session_1",
        )

        assert isinstance(created, SlotBackfillService)
        assert created._tenant_id == "tenant_1"
        assert created._session_id == "session_1"


class TestBackfillFromUserResponse:
    """Tests for backfilling slots from a user response."""

    @pytest.fixture
    def service(self):
        """Provide a service whose slot definition service is a mock."""
        instance = SlotBackfillService(
            session=AsyncMock(),
            tenant_id="tenant_1",
            session_id="session_1",
        )
        instance._slot_def_service = AsyncMock()
        return instance

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_success(self, service):
        """A successful extraction and backfill counts one success."""
        service._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=_make_region_slot_definition()
        )
        extraction = StrategyChainResult(
            slot_key="region",
            success=True,
            final_value="北京",
            final_strategy="rule",
        )
        backfilled = BackfillResult(
            status=BackfillStatus.SUCCESS,
            slot_key="region",
            value="北京",
        )

        with patch.object(service, "_extract_value") as extract_stub, \
                patch.object(service, "backfill_single_slot") as backfill_stub:
            extract_stub.return_value = extraction
            backfill_stub.return_value = backfilled

            batch = await service.backfill_from_user_response(
                user_response="我想查询北京的产品",
                expected_slots=["region"],
            )

        assert batch.success_count == 1

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_no_definition(self, service):
        """A missing slot definition produces no success and no failure."""
        service._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=None
        )

        batch = await service.backfill_from_user_response(
            user_response="我想查询北京的产品",
            expected_slots=["unknown_slot"],
        )

        assert batch.success_count == 0
        assert batch.failed_count == 0

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_extraction_failed(self, service):
        """A failed extraction is recorded as EXTRACTION_FAILED."""
        service._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=_make_region_slot_definition()
        )
        failed_extraction = StrategyChainResult(
            slot_key="region",
            success=False,
        )

        with patch.object(service, "_extract_value") as extract_stub:
            extract_stub.return_value = failed_extraction

            batch = await service.backfill_from_user_response(
                user_response="我想查询产品",
                expected_slots=["region"],
            )

        assert batch.failed_count == 1
        assert batch.results[0].status == BackfillStatus.EXTRACTION_FAILED