ai-robot-core/ai-service/tests/test_slot_extraction_integr...

336 lines
12 KiB
Python

"""
Tests for Slot Extraction Integration.
[AC-MRS-SLOT-EXTRACT-01] slot extraction 集成测试
"""
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.models.mid.schemas import SlotSource
from app.services.mid.slot_extraction_integration import (
ExtractionResult,
ExtractionTrace,
SlotExtractionIntegration,
integrate_slot_extraction,
)
from app.services.mid.slot_strategy_executor import (
StrategyChainResult,
StrategyStepResult,
)
class TestExtractionTrace:
    """Unit tests for the ExtractionTrace record."""

    def test_init(self):
        """A trace built from only a slot key keeps the key and uses defaults."""
        fresh = ExtractionTrace(slot_key="region")

        assert fresh.slot_key == "region"
        # Nothing has run yet: no strategy recorded, validation not passed.
        assert fresh.strategy is None
        assert fresh.validation_passed is False

    def test_to_dict(self):
        """to_dict() exposes the populated fields under their attribute names."""
        fields = dict(
            slot_key="region",
            strategy="rule",
            extracted_value="北京",
            validation_passed=True,
            final_value="北京",
            execution_time_ms=10.5,
        )
        as_dict = ExtractionTrace(**fields).to_dict()

        for key in ("slot_key", "strategy", "extracted_value"):
            assert as_dict[key] == fields[key]
        # Identity check on purpose: the flag must be the bool True, not merely truthy.
        assert as_dict["validation_passed"] is True
class TestExtractionResult:
    """Unit tests for the ExtractionResult aggregate."""

    def test_init(self):
        """A default-constructed result is unsuccessful and holds no slots."""
        empty = ExtractionResult()

        assert empty.success is False
        assert empty.extracted_slots == {}
        assert empty.failed_slots == []

    def test_to_dict(self):
        """to_dict() carries the success flag, slot collections and ask-back flag."""
        populated = ExtractionResult(
            success=True,
            extracted_slots={"region": "北京"},
            failed_slots=["product"],
            traces=[ExtractionTrace(slot_key="region")],
            total_execution_time_ms=50.0,
            ask_back_triggered=True,
            ask_back_prompts=["请提供产品信息"],
        )

        serialized = populated.to_dict()

        # Boolean flags are compared by identity so truthy non-bools would fail.
        assert serialized["success"] is True
        assert serialized["ask_back_triggered"] is True
        assert serialized["extracted_slots"] == {"region": "北京"}
        assert serialized["failed_slots"] == ["product"]
class TestSlotExtractionIntegration:
    """Tests for SlotExtractionIntegration with its collaborators mocked."""

    @pytest.fixture
    def mock_session(self):
        """Provide an AsyncMock standing in for the database session."""
        return AsyncMock()

    @pytest.fixture
    def integration(self, mock_session):
        """Build the integration under test for tenant_1 / session_1."""
        return SlotExtractionIntegration(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
        )

    @staticmethod
    def _stub_slot_definition(integration, ask_back_prompt):
        """Make the slot-definition lookup resolve to a plain string slot.

        The stub definition has type "string", no validation rule, and the
        given *ask_back_prompt*. It is returned for further tweaking.
        """
        slot_def = MagicMock()
        slot_def.type = "string"
        slot_def.validation_rule = None
        slot_def.ask_back_prompt = ask_back_prompt
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=slot_def
        )
        return slot_def

    def test_default_strategies(self, integration):
        """The default chain tries rule extraction first, then the LLM."""
        assert integration.DEFAULT_STRATEGIES == ["rule", "llm"]

    def test_get_source_for_strategy(self, integration):
        """Each strategy name maps onto its SlotSource value."""
        expected_sources = {
            "rule": SlotSource.RULE_EXTRACTED.value,
            "llm": SlotSource.LLM_INFERRED.value,
            "user_input": SlotSource.USER_CONFIRMED.value,
        }
        for strategy, source in expected_sources.items():
            assert integration._get_source_for_strategy(strategy) == source

    def test_get_confidence_for_source(self, integration):
        """Confidence is 1.0 / 0.9 / 0.7 for user / rule / LLM sources."""
        expected_confidence = [
            (SlotSource.USER_CONFIRMED.value, 1.0),
            (SlotSource.RULE_EXTRACTED.value, 0.9),
            (SlotSource.LLM_INFERRED.value, 0.7),
        ]
        for source, confidence in expected_confidence:
            assert integration._get_confidence_for_source(source) == confidence

    @pytest.mark.asyncio
    async def test_extract_and_fill_no_target_slots(self, integration):
        """An empty target list succeeds with nothing extracted."""
        outcome = await integration.extract_and_fill(
            user_input="测试输入",
            target_slots=[],
        )

        assert outcome.success is True
        assert outcome.extracted_slots == {}

    @pytest.mark.asyncio
    async def test_extract_and_fill_slot_not_found(self, integration):
        """A slot with no definition is reported as failed, with a reason in its trace."""
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=None
        )
        integration._slot_manager.get_ask_back_prompt = AsyncMock(return_value=None)

        outcome = await integration.extract_and_fill(
            user_input="测试输入",
            target_slots=["unknown_slot"],
        )

        assert outcome.success is False
        assert "unknown_slot" in outcome.failed_slots
        assert outcome.traces[0].failure_reason == "Slot definition not found"

    @pytest.mark.asyncio
    async def test_extract_and_fill_extraction_success(self, integration):
        """A successful strategy chain plus backfill lands the value in extracted_slots."""
        self._stub_slot_definition(integration, "请提供地区")

        # Build the patchers first so a single flat `with` can activate all three.
        chain_patch = patch.object(integration._strategy_executor, "execute_chain")
        backfill_patch = patch.object(integration, "_get_backfill_service")
        save_patch = patch.object(
            integration, "_save_extracted_slots", new_callable=AsyncMock
        )
        with chain_patch as chain, backfill_patch as backfill_factory, save_patch:
            chain.return_value = StrategyChainResult(
                slot_key="region",
                success=True,
                final_value="北京",
                final_strategy="rule",
            )
            backfill_service = AsyncMock()
            backfill_service.backfill_single_slot = AsyncMock(
                return_value=MagicMock(
                    is_success=lambda: True,
                    normalized_value="北京",
                    error_message=None,
                )
            )
            backfill_factory.return_value = backfill_service

            outcome = await integration.extract_and_fill(
                user_input="我想查询北京的产品",
                target_slots=["region"],
            )

        assert outcome.success is True
        assert "region" in outcome.extracted_slots
        assert outcome.extracted_slots["region"] == "北京"

    @pytest.mark.asyncio
    async def test_extract_and_fill_extraction_failed(self, integration):
        """A failed strategy chain marks the slot failed and triggers an ask-back prompt."""
        self._stub_slot_definition(integration, "请提供地区")

        with patch.object(integration._strategy_executor, "execute_chain") as chain:
            rule_step = StrategyStepResult(
                strategy="rule",
                success=False,
                failure_reason="无法提取",
            )
            chain.return_value = StrategyChainResult(
                slot_key="region",
                success=False,
                steps=[rule_step],
            )
            integration._slot_manager.get_ask_back_prompt = AsyncMock(
                return_value="请提供地区"
            )

            outcome = await integration.extract_and_fill(
                user_input="测试输入",
                target_slots=["region"],
            )

        assert outcome.success is False
        assert "region" in outcome.failed_slots
        assert outcome.ask_back_triggered is True
        assert "请提供地区" in outcome.ask_back_prompts

    @pytest.mark.asyncio
    async def test_get_missing_required_slots_from_state(self, integration):
        """Missing required slots are read from a supplied SlotState."""
        from app.services.mid.slot_state_aggregator import SlotState

        state = SlotState()
        state.missing_required_slots = [
            {"slot_key": key} for key in ("region", "product")
        ]

        missing = await integration._get_missing_required_slots(state)

        for key in ("region", "product"):
            assert key in missing

    @pytest.mark.asyncio
    async def test_get_missing_required_slots_from_db(self, integration):
        """With no SlotState, missing slots come from the (mocked) slot-definition service."""
        integration._slot_def_service.list_slot_definitions = AsyncMock(
            return_value=[MagicMock(slot_key=key) for key in ("region", "product")]
        )

        missing = await integration._get_missing_required_slots(None)

        for key in ("region", "product"):
            assert key in missing
class TestIntegrateSlotExtraction:
    """Tests for the integrate_slot_extraction convenience function."""

    @pytest.mark.asyncio
    async def test_integrate(self):
        """The wrapper builds one integration, delegates to it and returns its result.

        SlotExtractionIntegration is patched inside the module under test, so
        this checks the wiring only:
        - the constructor receives the caller's session and identifiers;
        - extract_and_fill is awaited with the user input;
        - its result is handed straight back.
        """
        mock_session = AsyncMock()
        user_input = "我想查询北京的产品"
        with patch('app.services.mid.slot_extraction_integration.SlotExtractionIntegration') as mock_cls:
            mock_instance = AsyncMock()
            mock_instance.extract_and_fill = AsyncMock(
                return_value=ExtractionResult(success=True, extracted_slots={"region": "北京"})
            )
            mock_cls.return_value = mock_instance
            result = await integrate_slot_extraction(
                session=mock_session,
                tenant_id="tenant_1",
                session_id="session_1",
                user_input=user_input,
            )

        assert result.success is True
        assert "region" in result.extracted_slots

        # Checking only the returned result would let a wrapper that drops the
        # session/tenant/session ids or the user input still pass. Positional
        # and keyword forwarding are both accepted.
        mock_cls.assert_called_once()
        ctor_call = mock_cls.call_args
        ctor_values = (*ctor_call.args, *ctor_call.kwargs.values())
        assert any(value is mock_session for value in ctor_values)
        assert "tenant_1" in ctor_values
        assert "session_1" in ctor_values

        mock_instance.extract_and_fill.assert_awaited_once()
        fill_call = mock_instance.extract_and_fill.await_args
        assert user_input in (*fill_call.args, *fill_call.kwargs.values())
class TestExtractionTraceFlow:
    """End-to-end check of the trace recorded for a single extracted slot."""

    @pytest.fixture
    def integration(self):
        """Build an integration backed by a mocked async session."""
        return SlotExtractionIntegration(
            session=AsyncMock(),
            tenant_id="tenant_1",
            session_id="session_1",
        )

    @pytest.mark.asyncio
    async def test_full_extraction_flow(self, integration):
        """A rule-based hit yields exactly one trace with all fields populated."""
        slot_def = MagicMock()
        slot_def.type = "string"
        slot_def.validation_rule = None
        slot_def.ask_back_prompt = None
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=slot_def
        )

        # Backfill stub that accepts the value and normalizes it unchanged.
        backfill_service = AsyncMock()
        backfill_service.backfill_single_slot = AsyncMock(
            return_value=MagicMock(
                is_success=lambda: True,
                normalized_value="北京",
                error_message=None,
            )
        )

        chain_patch = patch.object(integration._strategy_executor, "execute_chain")
        backfill_patch = patch.object(integration, "_get_backfill_service")
        save_patch = patch.object(
            integration, "_save_extracted_slots", new_callable=AsyncMock
        )
        with chain_patch as chain, backfill_patch as backfill_factory, save_patch:
            chain.return_value = StrategyChainResult(
                slot_key="region",
                success=True,
                final_value="北京",
                final_strategy="rule",
            )
            backfill_factory.return_value = backfill_service

            result = await integration.extract_and_fill(
                user_input="北京",
                target_slots=["region"],
            )

        assert len(result.traces) == 1
        (trace,) = result.traces
        assert trace.slot_key == "region"
        assert trace.strategy == "rule"
        assert trace.extracted_value == "北京"
        assert trace.validation_passed is True
        assert trace.final_value == "北京"