309 lines
10 KiB
Python
309 lines
10 KiB
Python
"""
Tests for Dialogue API with Slot State Integration.

[AC-MRS-SLOT-META-03] 对话 API 与槽位状态集成测试
(Integration tests for the dialogue API and slot state.)
"""
|
||
|
||
import pytest
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
from app.api.mid.dialogue import _generate_ask_back_for_missing_slots
|
||
from app.models.mid.schemas import ExecutionMode, Segment, TraceInfo
|
||
from app.services.mid.slot_state_aggregator import SlotState
|
||
|
||
|
||
class TestGenerateAskBackResponse:
    """Tests for building the ask-back (follow-up question) text for missing slots."""

    @staticmethod
    async def _ask_back(missing_slots):
        """Call ``_generate_ask_back_for_missing_slots`` with a fresh, empty
        ``SlotState`` and a mocked async session for ``test_tenant``.

        Args:
            missing_slots: List of missing-slot dicts passed through unchanged.

        Returns:
            Whatever ``_generate_ask_back_for_missing_slots`` returns.
        """
        return await _generate_ask_back_for_missing_slots(
            slot_state=SlotState(),
            missing_slots=missing_slots,
            session=AsyncMock(),
            tenant_id="test_tenant",
        )

    @pytest.mark.asyncio
    async def test_generate_ask_back_with_prompt(self):
        """A slot's configured ``ask_back_prompt`` is returned as the reply."""
        region_slot = {
            "slot_key": "region",
            "label": "地区",
            "ask_back_prompt": "请问您在哪个地区?",
        }

        reply = await self._ask_back([region_slot])

        assert reply == "请问您在哪个地区?"

    @pytest.mark.asyncio
    async def test_generate_ask_back_generic(self):
        """Without an ``ask_back_prompt``, the generic reply mentions the slot label."""
        # Deliberately no "ask_back_prompt" key, forcing the generic template.
        product_slot = {
            "slot_key": "product_line",
            "label": "产品线",
        }

        reply = await self._ask_back([product_slot])

        assert "产品线" in reply

    @pytest.mark.asyncio
    async def test_generate_ask_back_empty_slots(self):
        """An empty missing-slot list still yields a reply asking for more information."""
        reply = await self._ask_back([])

        assert "更多信息" in reply
|
||
|
||
|
||
class TestDialogueAskBackResponse:
    """Tests for the structure of a dialogue response that carries an ask-back."""

    def test_dialogue_response_with_ask_back(self):
        """An ask-back response keeps its single segment and its trace flags."""
        from app.models.mid.schemas import DialogueResponse

        ask_back_segment = Segment(text="请问您咨询的是哪个产品线?", delay_after=0)
        trace_info = TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id="test_request_id",
            generation_id="test_generation_id",
            fallback_reason_code="missing_required_slots",
            kb_tool_called=True,
            kb_hit=False,
        )

        response = DialogueResponse(segments=[ask_back_segment], trace=trace_info)

        # Exactly one segment, and it carries the ask-back question text.
        assert len(response.segments) == 1
        assert "哪个产品线" in response.segments[0].text

        trace = response.trace
        assert trace.fallback_reason_code == "missing_required_slots"
        assert trace.kb_tool_called is True
        assert trace.kb_hit is False
|
||
|
||
|
||
class TestSlotStateAggregationFlow:
    """Tests for how ``SlotStateAggregator.aggregate`` builds a slot state."""

    @pytest.mark.asyncio
    async def test_memory_slots_included_in_state(self):
        """Slots recalled from memory appear as filled, with their source recorded."""
        from app.models.mid.schemas import MemorySlot, SlotSource
        from app.services.mid.slot_state_aggregator import SlotStateAggregator

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        memory_slots = {
            "product_line": MemorySlot(
                key="product_line",
                value="vip_course",
                source=SlotSource.USER_CONFIRMED,
                confidence=1.0,
            )
        }

        # Stub the slot definitions with an empty list so the outcome depends
        # only on memory_slots. NOTE(review): list_slot_definitions is presumably
        # async; patch.object picks AsyncMock automatically for async targets
        # (Python 3.8+) — confirm against the service.
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[],
        ):
            state = await aggregator.aggregate(
                memory_slots=memory_slots,
                current_input_slots=None,
                context=None,
            )

        assert "product_line" in state.filled_slots
        assert state.filled_slots["product_line"] == "vip_course"
        assert state.slot_sources["product_line"] == "user_confirmed"

    @pytest.mark.asyncio
    async def test_missing_slots_identified(self):
        """A required slot that receives no value is reported as missing."""
        from app.services.mid.slot_state_aggregator import SlotStateAggregator

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        # One required slot definition with an ask-back prompt and no linked
        # metadata field.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "region"
        mock_slot_def.required = True
        mock_slot_def.ask_back_prompt = "请问您在哪个地区?"
        mock_slot_def.linked_field_id = None

        # No memory, current-input, or context values are supplied, so the
        # required "region" slot has nothing to fill it.
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ):
            state = await aggregator.aggregate(
                memory_slots={},
                current_input_slots=None,
                context=None,
            )

        assert len(state.missing_required_slots) == 1
        assert state.missing_required_slots[0]["slot_key"] == "region"
        assert state.missing_required_slots[0]["ask_back_prompt"] == "请问您在哪个地区?"
|
||
|
||
|
||
class TestSlotMetadataLinkage:
    """Tests for linking slot definitions to metadata field definitions."""

    @pytest.mark.asyncio
    async def test_slot_to_field_mapping(self):
        """A slot with ``linked_field_id`` is mapped to the linked field's ``field_key``."""
        # MagicMock / patch come from the module-level unittest.mock import.
        from app.services.mid.slot_state_aggregator import SlotStateAggregator
        from app.services.metadata_field_definition_service import MetadataFieldDefinitionService

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        # Optional slot definition that points at a metadata field by id.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "product"
        mock_slot_def.linked_field_id = "field-uuid-123"
        mock_slot_def.required = False
        mock_slot_def.type = "string"
        mock_slot_def.options = None

        # The metadata field the slot is linked to; its key differs from the
        # slot key so the mapping is observable.
        mock_field = MagicMock()
        mock_field.field_key = "product_line"
        mock_field.label = "产品线"
        mock_field.type = "string"
        mock_field.required = False
        mock_field.options = None

        # Both patches are entered in order (slot definitions first) and exited
        # in reverse, exactly like the equivalent nested `with` blocks.
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ), patch.object(
            MetadataFieldDefinitionService,
            "get_field_definition",
            return_value=mock_field,
        ):
            state = await aggregator.aggregate(
                memory_slots={},
                current_input_slots=None,
                context=None,
            )

        # The slot key must map to the linked field's key.
        assert state.slot_to_field_map.get("product") == "product_line"
|
||
|
||
|
||
class TestBackwardCompatibility:
    """Tests that KB search still works for callers that do not pass ``slot_state``."""

    @staticmethod
    async def _search_without_slot_state(context):
        """Run ``KbSearchDynamicTool.execute`` for a fixed query with ``slot_state=None``.

        ``MetadataFilterBuilder._get_filterable_fields`` and the tool's
        ``_retrieve_with_timeout`` are both patched to return ``[]``, so the
        result does not depend on real field definitions or retrieval output.

        Args:
            context: The context mapping forwarded to ``execute``.

        Returns:
            The result object returned by ``execute``.
        """
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicConfig,
            KbSearchDynamicTool,
        )
        from app.services.mid.metadata_filter_builder import MetadataFilterBuilder

        tool = KbSearchDynamicTool(
            session=AsyncMock(),
            config=KbSearchDynamicConfig(enabled=True),
        )

        with patch.object(
            MetadataFilterBuilder,
            "_get_filterable_fields",
            return_value=[],
        ), patch.object(
            tool,
            "_retrieve_with_timeout",
            return_value=[],
        ):
            return await tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context=context,
                slot_state=None,  # legacy caller: no slot_state supplied
            )

    @pytest.mark.asyncio
    async def test_kb_search_without_slot_state(self):
        """KB search succeeds with no ``slot_state`` and an empty context."""
        result = await self._search_without_slot_state(context={})

        # Must succeed without reporting a fallback reason.
        assert result.success is True
        assert result.fallback_reason_code is None

    @pytest.mark.asyncio
    async def test_legacy_context_filter(self):
        """A plain key/value context is used as the applied filter."""
        legacy_context = {"product_line": "vip_course", "region": "beijing"}

        result = await self._search_without_slot_state(context=legacy_context)

        assert result.success is True
        # The simple context is expected to be applied directly as the filter.
        assert result.applied_filter.get("product_line") == "vip_course"
|