ai-robot-core/ai-service/tests/test_dialogue_slot_integrat...

309 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Tests for Dialogue API with Slot State Integration.
[AC-MRS-SLOT-META-03] 对话 API 与槽位状态集成测试
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.api.mid.dialogue import _generate_ask_back_for_missing_slots
from app.models.mid.schemas import ExecutionMode, Segment, TraceInfo
from app.services.mid.slot_state_aggregator import SlotState
class TestGenerateAskBackResponse:
    """Exercise `_generate_ask_back_for_missing_slots` with different missing-slot lists."""

    @pytest.mark.asyncio
    async def test_generate_ask_back_with_prompt(self):
        """A missing slot carrying `ask_back_prompt` yields exactly that prompt."""
        region_slot = {
            "slot_key": "region",
            "label": "地区",
            "ask_back_prompt": "请问您在哪个地区?",
        }

        reply = await _generate_ask_back_for_missing_slots(
            slot_state=SlotState(),
            missing_slots=[region_slot],
            session=AsyncMock(),
            tenant_id="test_tenant",
        )

        assert reply == "请问您在哪个地区?"

    @pytest.mark.asyncio
    async def test_generate_ask_back_generic(self):
        """A missing slot without `ask_back_prompt` yields a reply that mentions its label."""
        # Deliberately no "ask_back_prompt" key in this slot description.
        product_line_slot = {
            "slot_key": "product_line",
            "label": "产品线",
        }

        reply = await _generate_ask_back_for_missing_slots(
            slot_state=SlotState(),
            missing_slots=[product_line_slot],
            session=AsyncMock(),
            tenant_id="test_tenant",
        )

        assert "产品线" in reply

    @pytest.mark.asyncio
    async def test_generate_ask_back_empty_slots(self):
        """An empty missing-slot list yields a reply containing "更多信息"."""
        reply = await _generate_ask_back_for_missing_slots(
            slot_state=SlotState(),
            missing_slots=[],
            session=AsyncMock(),
            tenant_id="test_tenant",
        )

        assert "更多信息" in reply
class TestDialogueAskBackResponse:
    """Check the shape of a `DialogueResponse` that carries an ask-back question."""

    def test_dialogue_response_with_ask_back(self):
        """Build an ask-back response and verify its segment and trace fields read back unchanged."""
        from app.models.mid.schemas import DialogueResponse

        ask_back_segment = Segment(text="请问您咨询的是哪个产品线?", delay_after=0)
        ask_back_trace = TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id="test_request_id",
            generation_id="test_generation_id",
            fallback_reason_code="missing_required_slots",
            kb_tool_called=True,
            kb_hit=False,
        )

        result = DialogueResponse(segments=[ask_back_segment], trace=ask_back_trace)

        # Exactly one segment, holding the ask-back question.
        assert len(result.segments) == 1
        assert "哪个产品线" in result.segments[0].text

        # The trace keeps the fallback reason and the KB flags it was built with.
        trace = result.trace
        assert trace.fallback_reason_code == "missing_required_slots"
        assert trace.kb_tool_called is True
        assert trace.kb_hit is False
class TestSlotStateAggregationFlow:
    """Tests for the slot state produced by `SlotStateAggregator.aggregate`."""

    @pytest.mark.asyncio
    async def test_memory_slots_included_in_state(self):
        """Slots supplied through `memory_slots` show up in the aggregated state.

        The slot definition service is patched to return no definitions, so the
        only input is the single memory slot; its value must appear in
        `filled_slots` and its source in `slot_sources`.
        """
        from app.models.mid.schemas import MemorySlot, SlotSource
        from app.services.mid.slot_state_aggregator import SlotStateAggregator

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )
        memory_slots = {
            "product_line": MemorySlot(
                key="product_line",
                value="vip_course",
                source=SlotSource.USER_CONFIRMED,
                confidence=1.0,
            )
        }

        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[],
        ):
            state = await aggregator.aggregate(
                memory_slots=memory_slots,
                current_input_slots=None,
                context=None,
            )

        assert "product_line" in state.filled_slots
        assert state.filled_slots["product_line"] == "vip_course"
        assert state.slot_sources["product_line"] == "user_confirmed"

    @pytest.mark.asyncio
    async def test_missing_slots_identified(self):
        """A required slot definition with no value is reported as missing.

        The definition service is patched to return one required definition
        while no memory slots are supplied; the aggregated state must list that
        slot, together with its ask-back prompt, in `missing_required_slots`.
        """
        # MagicMock comes from the module-level import; no local re-import needed.
        from app.services.mid.slot_state_aggregator import SlotStateAggregator

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        # Stand-in for a slot definition that is marked as required.
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "region"
        mock_slot_def.required = True
        mock_slot_def.ask_back_prompt = "请问您在哪个地区?"
        mock_slot_def.linked_field_id = None

        with patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        ):
            state = await aggregator.aggregate(
                memory_slots={},
                current_input_slots=None,
                context=None,
            )

        assert len(state.missing_required_slots) == 1
        assert state.missing_required_slots[0]["slot_key"] == "region"
        assert state.missing_required_slots[0]["ask_back_prompt"] == "请问您在哪个地区?"
class TestSlotMetadataLinkage:
    """Tests for the link between slot definitions and metadata field definitions."""

    @pytest.mark.asyncio
    async def test_slot_to_field_mapping(self):
        """A slot definition with `linked_field_id` is mapped to the linked field's key.

        The slot definition service is patched to return one slot ("product")
        linked to a field id, and `MetadataFieldDefinitionService.get_field_definition`
        is patched to return a field whose key is "product_line". The aggregated
        state must expose that pairing in `slot_to_field_map`.
        """
        # MagicMock and patch come from the module-level import; only the
        # project classes not imported at the top of the file are pulled in here.
        from app.services.mid.slot_state_aggregator import SlotStateAggregator
        from app.services.metadata_field_definition_service import MetadataFieldDefinitionService

        mock_session = AsyncMock()
        aggregator = SlotStateAggregator(
            session=mock_session,
            tenant_id="test_tenant",
        )

        # Stand-in slot definition (carrying a linked_field_id).
        mock_slot_def = MagicMock()
        mock_slot_def.slot_key = "product"
        mock_slot_def.linked_field_id = "field-uuid-123"
        mock_slot_def.required = False
        mock_slot_def.type = "string"
        mock_slot_def.options = None

        # Stand-in for the metadata field the slot is linked to.
        mock_field = MagicMock()
        mock_field.field_key = "product_line"
        mock_field.label = "产品线"
        mock_field.type = "string"
        mock_field.required = False
        mock_field.options = None

        slot_defs_patch = patch.object(
            aggregator._slot_def_service,
            "list_slot_definitions",
            return_value=[mock_slot_def],
        )
        field_lookup_patch = patch.object(
            MetadataFieldDefinitionService,
            "get_field_definition",
            return_value=mock_field,
        )
        # Entered left to right, i.e. in the same order as the former nested blocks.
        with slot_defs_patch, field_lookup_patch:
            state = await aggregator.aggregate(
                memory_slots={},
                current_input_slots=None,
                context=None,
            )

        # The slot key must now resolve to the linked field's key.
        assert state.slot_to_field_map.get("product") == "product_line"
class TestBackwardCompatibility:
    """KB search must keep working for callers that do not pass a `slot_state`."""

    @staticmethod
    async def _search_without_slot_state(context):
        """Run `KbSearchDynamicTool.execute` with `slot_state=None` and stubbed collaborators.

        `MetadataFilterBuilder._get_filterable_fields` and the tool's own
        `_retrieve_with_timeout` are both patched to return an empty list.
        Returns whatever `execute` returns for the given `context`.
        """
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicConfig,
            KbSearchDynamicTool,
        )
        from app.services.mid.metadata_filter_builder import MetadataFilterBuilder

        tool = KbSearchDynamicTool(
            session=AsyncMock(),
            config=KbSearchDynamicConfig(enabled=True),
        )

        # The filter builder reports no filterable fields, retrieval finds nothing.
        no_filterable_fields = patch.object(
            MetadataFilterBuilder,
            "_get_filterable_fields",
            return_value=[],
        )
        empty_retrieval = patch.object(
            tool,
            "_retrieve_with_timeout",
            return_value=[],
        )
        with no_filterable_fields, empty_retrieval:
            return await tool.execute(
                query="退款政策",
                tenant_id="test_tenant",
                context=context,
                slot_state=None,  # slot_state intentionally not provided
            )

    @pytest.mark.asyncio
    async def test_kb_search_without_slot_state(self):
        """With an empty context and no `slot_state`, the search still succeeds."""
        outcome = await self._search_without_slot_state({})

        # Expected to succeed without reporting any fallback reason.
        assert outcome.success is True
        assert outcome.fallback_reason_code is None

    @pytest.mark.asyncio
    async def test_legacy_context_filter(self):
        """A plain key/value context is reflected in the applied filter."""
        legacy_context = {"product_line": "vip_course", "region": "beijing"}

        outcome = await self._search_without_slot_state(legacy_context)

        # Expected to succeed.
        assert outcome.success is True
        # The simple context value should appear in the applied filter as-is.
        assert outcome.applied_filter.get("product_line") == "vip_course"