309 lines
10 KiB
Python
309 lines
10 KiB
Python
|
|
"""
|
|||
|
|
Tests for Dialogue API with Slot State Integration.

[AC-MRS-SLOT-META-03] Integration tests for the dialogue API and slot state.
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import pytest
|
|||
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|||
|
|
|
|||
|
|
from app.api.mid.dialogue import _generate_ask_back_for_missing_slots
|
|||
|
|
from app.models.mid.schemas import ExecutionMode, Segment, TraceInfo
|
|||
|
|
from app.services.mid.slot_state_aggregator import SlotState
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestGenerateAskBackResponse:
    """Tests for generating the ask-back (follow-up question) response."""

    @staticmethod
    async def _ask_back(missing_slots):
        """Invoke the function under test with a fresh SlotState and a mocked session."""
        return await _generate_ask_back_for_missing_slots(
            slot_state=SlotState(),
            missing_slots=missing_slots,
            session=AsyncMock(),
            tenant_id="test_tenant",
        )

    @pytest.mark.asyncio
    async def test_generate_ask_back_with_prompt(self):
        """A slot that carries an ask_back_prompt yields exactly that prompt."""
        region_slot = {
            "slot_key": "region",
            "label": "地区",
            "ask_back_prompt": "请问您在哪个地区?",
        }

        reply = await self._ask_back([region_slot])

        assert reply == "请问您在哪个地区?"

    @pytest.mark.asyncio
    async def test_generate_ask_back_generic(self):
        """Without an ask_back_prompt the reply still mentions the slot label."""
        product_line_slot = {
            "slot_key": "product_line",
            "label": "产品线",
            # no ask_back_prompt key here
        }

        reply = await self._ask_back([product_line_slot])

        assert "产品线" in reply

    @pytest.mark.asyncio
    async def test_generate_ask_back_empty_slots(self):
        """An empty missing-slot list still produces a reply containing "更多信息"."""
        reply = await self._ask_back([])

        assert "更多信息" in reply
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestDialogueAskBackResponse:
    """Tests for the dialogue response that carries an ask-back question."""

    def test_dialogue_response_with_ask_back(self):
        """Check the structure of an ask-back DialogueResponse."""
        from app.models.mid.schemas import DialogueResponse

        # Build the parts first, in the same order the response consumes them.
        ask_back_segment = Segment(text="请问您咨询的是哪个产品线?", delay_after=0)
        trace_info = TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id="test_request_id",
            generation_id="test_generation_id",
            fallback_reason_code="missing_required_slots",
            kb_tool_called=True,
            kb_hit=False,
        )
        reply = DialogueResponse(segments=[ask_back_segment], trace=trace_info)

        # Exactly one segment, holding the ask-back text.
        assert len(reply.segments) == 1
        assert "哪个产品线" in reply.segments[0].text

        # The trace fields given at construction are readable back unchanged.
        assert reply.trace.fallback_reason_code == "missing_required_slots"
        assert reply.trace.kb_tool_called is True
        assert reply.trace.kb_hit is False
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestSlotStateAggregationFlow:
    """Tests for the slot state aggregation flow."""

    @pytest.mark.asyncio
    async def test_memory_slots_included_in_state(self):
        """Slots supplied as memory_slots show up in the aggregated state."""
        from app.models.mid.schemas import MemorySlot, SlotSource
        from app.services.mid.slot_state_aggregator import SlotStateAggregator

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        memory_slots = {
            "product_line": MemorySlot(
                key="product_line",
                value="vip_course",
                source=SlotSource.USER_CONFIRMED,
                confidence=1.0,
            )
        }

        # No slot definitions: only the memory slots feed the aggregation.
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[],
        ):
            state = await aggregator.aggregate(
                memory_slots=memory_slots,
                current_input_slots=None,
                context=None,
            )

        assert "product_line" in state.filled_slots
        assert state.filled_slots["product_line"] == "vip_course"
        assert state.slot_sources["product_line"] == "user_confirmed"

    @pytest.mark.asyncio
    async def test_missing_slots_identified(self):
        """A required slot with no value is reported as missing."""
        # MagicMock comes from the module-level unittest.mock import; the
        # former function-scope re-import was redundant.
        from app.services.mid.slot_state_aggregator import SlotStateAggregator

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        # Simulate a single required slot definition.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "region"
        mock_slot_def.required = True
        mock_slot_def.ask_back_prompt = "请问您在哪个地区?"
        mock_slot_def.linked_field_id = None

        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ):
            state = await aggregator.aggregate(
                memory_slots={},
                current_input_slots=None,
                context=None,
            )

        assert len(state.missing_required_slots) == 1
        assert state.missing_required_slots[0]["slot_key"] == "region"
        assert state.missing_required_slots[0]["ask_back_prompt"] == "请问您在哪个地区?"
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestSlotMetadataLinkage:
    """Tests for the linkage between slots and metadata fields."""

    @pytest.mark.asyncio
    async def test_slot_to_field_mapping(self):
        """A slot with a linked_field_id is mapped to the linked field's key."""
        # MagicMock and patch come from the module-level unittest.mock import;
        # the former function-scope re-import was redundant.
        from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
        from app.services.mid.slot_state_aggregator import SlotStateAggregator

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        # Simulated slot definition that carries a linked_field_id.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "product"
        mock_slot_def.linked_field_id = "field-uuid-123"
        mock_slot_def.required = False
        mock_slot_def.type = "string"
        mock_slot_def.options = None

        # Simulated metadata field that the slot links to.
        mock_field = MagicMock()
        mock_field.field_key = "product_line"
        mock_field.label = "产品线"
        mock_field.type = "string"
        mock_field.required = False
        mock_field.options = None

        # Both patches are entered in the same order as the former nested
        # `with` blocks: slot definitions first, then the field lookup.
        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ), patch.object(
            MetadataFieldDefinitionService,
            "get_field_definition",
            return_value=mock_field,
        ):
            state = await aggregator.aggregate(
                memory_slots={},
                current_input_slots=None,
                context=None,
            )

        # The slot key must now resolve to the linked field's key.
        assert state.slot_to_field_map.get("product") == "product_line"
|
|||
|
|
|
|||
|
|
|
|||
|
|
class TestBackwardCompatibility:
    """Backward-compatibility tests for KB search without a slot state."""

    @pytest.mark.asyncio
    async def test_kb_search_without_slot_state(self):
        """KB search still succeeds when no slot_state is passed."""
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicConfig,
            KbSearchDynamicTool,
        )
        from app.services.mid.metadata_filter_builder import MetadataFilterBuilder

        session = AsyncMock()
        tool = KbSearchDynamicTool(
            session=session,
            config=KbSearchDynamicConfig(enabled=True),
        )

        # The filter builder reports no filterable fields and retrieval
        # returns no hits.
        no_filterable_fields = patch.object(
            MetadataFilterBuilder,
            "_get_filterable_fields",
            return_value=[],
        )
        empty_retrieval = patch.object(
            tool,
            "_retrieve_with_timeout",
            return_value=[],
        )

        with no_filterable_fields, empty_retrieval:
            outcome = await tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context={},
                slot_state=None,  # slot_state is not provided
            )

        # The call should succeed without a fallback reason.
        assert outcome.success is True
        assert outcome.fallback_reason_code is None

    @pytest.mark.asyncio
    async def test_legacy_context_filter(self):
        """A legacy plain context dict is still used to build the filter."""
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicConfig,
            KbSearchDynamicTool,
        )
        from app.services.mid.metadata_filter_builder import MetadataFilterBuilder

        session = AsyncMock()
        tool = KbSearchDynamicTool(
            session=session,
            config=KbSearchDynamicConfig(enabled=True),
        )

        # A simple, flat context.
        legacy_context = {"product_line": "vip_course", "region": "beijing"}

        no_filterable_fields = patch.object(
            MetadataFilterBuilder,
            "_get_filterable_fields",
            return_value=[],
        )
        empty_retrieval = patch.object(
            tool,
            "_retrieve_with_timeout",
            return_value=[],
        )

        with no_filterable_fields, empty_retrieval:
            outcome = await tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context=legacy_context,
                slot_state=None,
            )

        # The call should succeed.
        assert outcome.success is True
        # A simple context should be used directly as the filter.
        assert outcome.applied_filter.get("product_line") == "vip_course"
|