336 lines
12 KiB
Python
336 lines
12 KiB
Python
"""
Tests for Slot Extraction Integration.

[AC-MRS-SLOT-EXTRACT-01] slot extraction 集成测试 (slot extraction integration tests)
"""
|
|
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.models.mid.schemas import SlotSource
|
|
from app.services.mid.slot_extraction_integration import (
|
|
ExtractionResult,
|
|
ExtractionTrace,
|
|
SlotExtractionIntegration,
|
|
integrate_slot_extraction,
|
|
)
|
|
from app.services.mid.slot_strategy_executor import (
|
|
StrategyChainResult,
|
|
StrategyStepResult,
|
|
)
|
|
|
|
|
|
class TestExtractionTrace:
    """Unit tests for the ExtractionTrace record."""

    def test_init(self):
        """A trace created from only a slot key has no strategy and is not yet validated."""
        fresh = ExtractionTrace(slot_key="region")

        assert fresh.slot_key == "region"
        assert fresh.strategy is None
        assert fresh.validation_passed is False

    def test_to_dict(self):
        """to_dict() exposes the slot key, strategy, extracted value and validation flag."""
        fields = {
            "slot_key": "region",
            "strategy": "rule",
            "extracted_value": "北京",
            "validation_passed": True,
            "final_value": "北京",
            "execution_time_ms": 10.5,
        }
        payload = ExtractionTrace(**fields).to_dict()

        for key in ("slot_key", "strategy", "extracted_value"):
            assert payload[key] == fields[key]
        assert payload["validation_passed"] is True
|
|
|
|
|
|
class TestExtractionResult:
    """Unit tests for the ExtractionResult aggregate."""

    def test_init(self):
        """A default-constructed result is unsuccessful and holds no slots."""
        empty = ExtractionResult()

        assert empty.success is False
        assert empty.extracted_slots == {}
        assert empty.failed_slots == []

    def test_to_dict(self):
        """to_dict() carries over the success flag, slot collections and ask-back flag."""
        populated = ExtractionResult(
            success=True,
            extracted_slots={"region": "北京"},
            failed_slots=["product"],
            traces=[ExtractionTrace(slot_key="region")],
            total_execution_time_ms=50.0,
            ask_back_triggered=True,
            ask_back_prompts=["请提供产品信息"],
        )

        serialized = populated.to_dict()

        assert serialized["success"] is True
        assert serialized["extracted_slots"] == {"region": "北京"}
        assert serialized["failed_slots"] == ["product"]
        assert serialized["ask_back_triggered"] is True
|
|
|
|
|
|
class TestSlotExtractionIntegration:
    """Behavioural tests for SlotExtractionIntegration with its collaborators mocked out."""

    @pytest.fixture
    def mock_session(self):
        """Stand-in for the async session handed to the integration."""
        return AsyncMock()

    @pytest.fixture
    def integration(self, mock_session):
        """An integration bound to a fixed tenant/session pair and the mocked session."""
        return SlotExtractionIntegration(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
        )

    @staticmethod
    def _string_slot_definition(prompt):
        """Build a mock string-typed slot definition with no validation rule."""
        definition = MagicMock()
        definition.type = "string"
        definition.validation_rule = None
        definition.ask_back_prompt = prompt
        return definition

    def test_default_strategies(self, integration):
        """The default strategy chain is rule extraction followed by the LLM."""
        assert integration.DEFAULT_STRATEGIES == ["rule", "llm"]

    def test_get_source_for_strategy(self, integration):
        """Each strategy name maps onto the expected SlotSource value."""
        expected_sources = {
            "rule": SlotSource.RULE_EXTRACTED.value,
            "llm": SlotSource.LLM_INFERRED.value,
            "user_input": SlotSource.USER_CONFIRMED.value,
        }
        for strategy_name, source in expected_sources.items():
            assert integration._get_source_for_strategy(strategy_name) == source

    def test_get_confidence_for_source(self, integration):
        """Confidence decreases from user-confirmed to rule-extracted to LLM-inferred."""
        cases = [
            (SlotSource.USER_CONFIRMED.value, 1.0),
            (SlotSource.RULE_EXTRACTED.value, 0.9),
            (SlotSource.LLM_INFERRED.value, 0.7),
        ]
        for source, confidence in cases:
            assert integration._get_confidence_for_source(source) == confidence

    @pytest.mark.asyncio
    async def test_extract_and_fill_no_target_slots(self, integration):
        """An empty target list succeeds trivially with nothing extracted."""
        outcome = await integration.extract_and_fill(
            user_input="测试输入",
            target_slots=[],
        )

        assert outcome.success is True
        assert outcome.extracted_slots == {}

    @pytest.mark.asyncio
    async def test_extract_and_fill_slot_not_found(self, integration):
        """A slot key without a definition is reported as failed, with the reason traced."""
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(return_value=None)
        integration._slot_manager.get_ask_back_prompt = AsyncMock(return_value=None)

        outcome = await integration.extract_and_fill(
            user_input="测试输入",
            target_slots=["unknown_slot"],
        )

        assert outcome.success is False
        assert "unknown_slot" in outcome.failed_slots
        assert outcome.traces[0].failure_reason == "Slot definition not found"

    @pytest.mark.asyncio
    async def test_extract_and_fill_extraction_success(self, integration):
        """A successful rule-strategy chain plus a successful backfill fills the slot."""
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=self._string_slot_definition("请提供地区")
        )

        chain_result = StrategyChainResult(
            slot_key="region",
            success=True,
            final_value="北京",
            final_strategy="rule",
        )
        backfill_service = AsyncMock()
        backfill_service.backfill_single_slot = AsyncMock(
            return_value=MagicMock(
                is_success=lambda: True,
                normalized_value="北京",
                error_message=None,
            )
        )

        # return_value=... is forwarded to the mock that patch.object creates,
        # equivalent to assigning .return_value on the yielded mock.
        with patch.object(
            integration._strategy_executor, "execute_chain", return_value=chain_result
        ), patch.object(
            integration, "_get_backfill_service", return_value=backfill_service
        ), patch.object(integration, "_save_extracted_slots", new_callable=AsyncMock):
            outcome = await integration.extract_and_fill(
                user_input="我想查询北京的产品",
                target_slots=["region"],
            )

        assert outcome.success is True
        assert "region" in outcome.extracted_slots
        assert outcome.extracted_slots["region"] == "北京"

    @pytest.mark.asyncio
    async def test_extract_and_fill_extraction_failed(self, integration):
        """A failed strategy chain marks the slot failed and triggers an ask-back prompt."""
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=self._string_slot_definition("请提供地区")
        )
        integration._slot_manager.get_ask_back_prompt = AsyncMock(return_value="请提供地区")

        failed_chain = StrategyChainResult(
            slot_key="region",
            success=False,
            steps=[
                StrategyStepResult(
                    strategy="rule",
                    success=False,
                    failure_reason="无法提取",
                ),
            ],
        )

        with patch.object(
            integration._strategy_executor, "execute_chain", return_value=failed_chain
        ):
            outcome = await integration.extract_and_fill(
                user_input="测试输入",
                target_slots=["region"],
            )

        assert outcome.success is False
        assert "region" in outcome.failed_slots
        assert outcome.ask_back_triggered is True
        assert "请提供地区" in outcome.ask_back_prompts

    @pytest.mark.asyncio
    async def test_get_missing_required_slots_from_state(self, integration):
        """Missing required slot keys are read from a supplied SlotState."""
        from app.services.mid.slot_state_aggregator import SlotState

        wanted = ("region", "product")
        state = SlotState()
        state.missing_required_slots = [{"slot_key": key} for key in wanted]

        missing = await integration._get_missing_required_slots(state)

        for key in wanted:
            assert key in missing

    @pytest.mark.asyncio
    async def test_get_missing_required_slots_from_db(self, integration):
        """Without a SlotState, missing slots are derived from the slot definition service."""
        wanted = ("region", "product")
        integration._slot_def_service.list_slot_definitions = AsyncMock(
            return_value=[MagicMock(slot_key=key) for key in wanted]
        )

        missing = await integration._get_missing_required_slots(None)

        for key in wanted:
            assert key in missing
|
|
|
|
|
|
class TestIntegrateSlotExtraction:
    """Tests for the integrate_slot_extraction convenience function."""

    @pytest.mark.asyncio
    async def test_integrate(self):
        """The wrapper surfaces the extraction result of a (patched) SlotExtractionIntegration."""
        canned = ExtractionResult(success=True, extracted_slots={"region": "北京"})
        fake_integration = AsyncMock()
        fake_integration.extract_and_fill = AsyncMock(return_value=canned)

        # Constructing SlotExtractionIntegration inside the module yields fake_integration.
        with patch(
            "app.services.mid.slot_extraction_integration.SlotExtractionIntegration",
            return_value=fake_integration,
        ):
            outcome = await integrate_slot_extraction(
                session=AsyncMock(),
                tenant_id="tenant_1",
                session_id="session_1",
                user_input="我想查询北京的产品",
            )

        assert outcome.success is True
        assert "region" in outcome.extracted_slots
|
|
|
|
|
|
class TestExtractionTraceFlow:
    """End-to-end check that a successful extraction is recorded in its trace."""

    @pytest.fixture
    def integration(self):
        """An integration bound to a fixed tenant/session pair and a throwaway mocked session."""
        return SlotExtractionIntegration(
            session=AsyncMock(),
            tenant_id="tenant_1",
            session_id="session_1",
        )

    @pytest.mark.asyncio
    async def test_full_extraction_flow(self, integration):
        """A rule hit followed by a successful backfill yields a single, fully populated trace."""
        slot_def = MagicMock()
        slot_def.type = "string"
        slot_def.validation_rule = None
        slot_def.ask_back_prompt = None
        integration._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=slot_def
        )

        backfill_outcome = MagicMock(
            is_success=lambda: True,
            normalized_value="北京",
            error_message=None,
        )
        backfill_service = AsyncMock()
        backfill_service.backfill_single_slot = AsyncMock(return_value=backfill_outcome)

        # Patchers are built up front and entered together in the original order.
        chain_patch = patch.object(
            integration._strategy_executor,
            "execute_chain",
            return_value=StrategyChainResult(
                slot_key="region",
                success=True,
                final_value="北京",
                final_strategy="rule",
            ),
        )
        backfill_patch = patch.object(
            integration, "_get_backfill_service", return_value=backfill_service
        )
        save_patch = patch.object(
            integration, "_save_extracted_slots", new_callable=AsyncMock
        )

        with chain_patch, backfill_patch, save_patch:
            outcome = await integration.extract_and_fill(
                user_input="北京",
                target_slots=["region"],
            )

        assert len(outcome.traces) == 1
        recorded = outcome.traces[0]
        assert recorded.slot_key == "region"
        assert recorded.strategy == "rule"
        assert recorded.extracted_value == "北京"
        assert recorded.validation_passed is True
        assert recorded.final_value == "北京"
|