"""
Mid Platform Dialogue Integration Test.

Integration test script for the dialogue flow of the mid-platform
joint-debugging (integration) UI.

Test focus:
1. Intent confidence parameters
2. Execution mode (generic API vs. ReAct mode)
3. Tool calls in ReAct mode (tool name, arguments, returned result)
4. Knowledge-base queries (whether they hit, and their arguments)
5. Time spent in each stage
6. Prompt template usage
"""

import asyncio
|
||
import json
|
||
import logging
|
||
import time
|
||
import uuid
|
||
from dataclasses import dataclass, field
|
||
from datetime import datetime
|
||
from typing import Any
|
||
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
||
import pytest
|
||
from fastapi.testclient import TestClient
|
||
from httpx import AsyncClient
|
||
|
||
from app.main import app
|
||
from app.models.mid.schemas import (
|
||
DialogueRequest,
|
||
DialogueResponse,
|
||
ExecutionMode,
|
||
FeatureFlags,
|
||
HistoryMessage,
|
||
Segment,
|
||
TraceInfo,
|
||
ToolCallTrace,
|
||
ToolCallStatus,
|
||
ToolType,
|
||
IntentHintOutput,
|
||
HighRiskCheckResult,
|
||
HighRiskScenario,
|
||
)
|
||
|
||
# Module-level logger; used to report failures while parsing service responses.
logger = logging.getLogger(name=__name__)


@dataclass
class TimingRecord:
    """Timing of a single stage of a dialogue round-trip.

    Fields are positional in the order ``stage, start_time, end_time,
    duration_ms`` because callers construct instances positionally.
    """

    stage: str          # stage label, e.g. "http_request"
    start_time: float   # clock reading (seconds) when the stage began
    end_time: float     # clock reading (seconds) when the stage finished
    duration_ms: int    # elapsed time in whole milliseconds

    def to_dict(self) -> dict:
        """Return a compact view containing only the stage name and its duration."""
        return dict(stage=self.stage, duration_ms=self.duration_ms)


@dataclass
class DialogueTestResult:
    """Aggregated observations for a single dialogue request.

    Holds the reply text plus the execution details extracted from the
    service's trace: timings, execution mode and intent, ReAct tool usage,
    knowledge-base lookup info, prompt template and guardrail/fallback status.
    """

    request_id: str
    user_message: str
    response_text: str = ""

    # --- timing -------------------------------------------------------------
    timing_records: list[TimingRecord] = field(default_factory=list)
    total_duration_ms: int = 0

    # --- routing / intent ---------------------------------------------------
    execution_mode: ExecutionMode = ExecutionMode.AGENT
    intent: str | None = None
    confidence: float | None = None

    # --- ReAct tool usage ---------------------------------------------------
    react_iterations: int = 0
    tools_used: list[str] = field(default_factory=list)
    tool_calls: list[dict] = field(default_factory=list)

    # --- knowledge-base lookup ----------------------------------------------
    kb_tool_called: bool = False
    kb_hit: bool = False
    kb_query: str | None = None
    kb_filter: dict | None = None
    kb_hits_count: int = 0

    # --- prompt template ----------------------------------------------------
    prompt_template_used: str | None = None
    prompt_template_scene: str | None = None

    # --- guardrail / fallback -----------------------------------------------
    guardrail_triggered: bool = False
    fallback_reason_code: str | None = None

    # Trace payload returned by the service (dumped to a dict), if any.
    raw_trace: dict | None = None

    def to_summary(self) -> dict:
        """Return a flat summary dict for logging or JSON output.

        The user message is truncated to its first 100 characters. Only the
        number of tool calls is included, not their payloads.
        """
        breakdown = [record.to_dict() for record in self.timing_records]
        return dict(
            request_id=self.request_id,
            user_message=self.user_message[:100],
            execution_mode=self.execution_mode.value,
            intent=self.intent,
            confidence=self.confidence,
            total_duration_ms=self.total_duration_ms,
            timing_breakdown=breakdown,
            react_iterations=self.react_iterations,
            tools_used=self.tools_used,
            tool_calls_count=len(self.tool_calls),
            kb_tool_called=self.kb_tool_called,
            kb_hit=self.kb_hit,
            kb_hits_count=self.kb_hits_count,
            prompt_template_used=self.prompt_template_used,
            guardrail_triggered=self.guardrail_triggered,
            fallback_reason_code=self.fallback_reason_code,
        )


class DialogueIntegrationTester:
    """HTTP driver that sends dialogue requests to the mid-platform service.

    Each instance generates one random session id / user id pair at
    construction, so every request sent through the same tester carries the
    same session and user identifiers.
    """

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        tenant_id: str = "test_tenant",
        api_key: str | None = None,
        timeout: float = 120.0,
    ):
        """Create a tester.

        Args:
            base_url: Root URL of the service under test.
            tenant_id: Value sent in the ``X-Tenant-Id`` header.
            api_key: Optional value for the ``X-API-Key`` header.
            timeout: HTTP client timeout in seconds (defaults to 120).
        """
        self.base_url = base_url
        self.tenant_id = tenant_id
        self.api_key = api_key
        self.timeout = timeout
        self.session_id = f"test_session_{uuid.uuid4().hex[:8]}"
        self.user_id = f"test_user_{uuid.uuid4().hex[:8]}"

    def _get_headers(self) -> dict:
        """Build request headers; ``X-API-Key`` is only sent when an API key is set."""
        headers = {
            "Content-Type": "application/json",
            "X-Tenant-Id": self.tenant_id,
        }
        if self.api_key:
            headers["X-API-Key"] = self.api_key
        return headers

    def _build_request_body(
        self,
        user_message: str,
        history: list[dict] | None,
        scene: str | None,
        feature_flags: dict | None,
    ) -> dict:
        """Assemble the JSON body for ``/mid/dialogue/respond``; falsy optional keys are omitted."""
        body = {
            "session_id": self.session_id,
            "user_id": self.user_id,
            "user_message": user_message,
            "history": history or [],
        }
        if scene:
            body["scene"] = scene
        if feature_flags:
            body["feature_flags"] = feature_flags
        return body

    @staticmethod
    def _elapsed_ms(start: float, end: float) -> int:
        """Convert two ``time.perf_counter()`` readings into whole milliseconds."""
        return int((end - start) * 1000)

    async def send_dialogue(
        self,
        user_message: str,
        history: list[dict] | None = None,
        scene: str | None = None,
        feature_flags: dict | None = None,
    ) -> DialogueTestResult:
        """Send one dialogue turn and collect timing and trace details.

        Transport and HTTP errors are not raised. They are reported through
        ``result.response_text`` and ``result.fallback_reason_code``
        (``http_error_<status>`` or ``request_exception``). In every case
        ``total_duration_ms`` and the timings collected so far are filled in.

        Args:
            user_message: The user utterance to send.
            history: Prior turns, e.g. ``{"role": ..., "content": ...}`` dicts.
            scene: Optional scene identifier forwarded to the service.
            feature_flags: Optional feature flags forwarded as-is.

        Returns:
            A populated ``DialogueTestResult``.
        """
        result = DialogueTestResult(
            request_id=str(uuid.uuid4()),
            user_message=user_message,
            response_text="",
        )
        request_body = self._build_request_body(user_message, history, scene, feature_flags)

        # Attach the list up-front so timings gathered before an early return
        # or an exception are still reported on the result.
        timing_records: list[TimingRecord] = []
        result.timing_records = timing_records

        # perf_counter is monotonic, so durations are immune to wall-clock jumps.
        overall_start = time.perf_counter()
        try:
            async with AsyncClient(base_url=self.base_url, timeout=self.timeout) as client:
                request_start = time.perf_counter()
                response = await client.post(
                    "/mid/dialogue/respond",
                    json=request_body,
                    headers=self._get_headers(),
                )
                request_end = time.perf_counter()
                timing_records.append(TimingRecord(
                    stage="http_request",
                    start_time=request_start,
                    end_time=request_end,
                    duration_ms=self._elapsed_ms(request_start, request_end),
                ))

                if response.status_code != 200:
                    result.response_text = f"Error: {response.status_code}"
                    result.fallback_reason_code = f"http_error_{response.status_code}"
                    result.total_duration_ms = self._elapsed_ms(overall_start, time.perf_counter())
                    return result

                response_data = response.json()
        except Exception as e:  # test boundary: record the failure instead of raising
            result.response_text = f"Exception: {str(e)}"
            result.fallback_reason_code = "request_exception"
            result.total_duration_ms = self._elapsed_ms(overall_start, time.perf_counter())
            return result

        overall_end = time.perf_counter()
        result.total_duration_ms = self._elapsed_ms(overall_start, overall_end)

        try:
            dialogue_response = DialogueResponse(**response_data)
            if dialogue_response.segments:
                result.response_text = "\n".join(s.text for s in dialogue_response.segments)

            trace = dialogue_response.trace
            result.raw_trace = trace.model_dump() if trace else None
            if trace:
                self._apply_trace(result, trace, overall_start, overall_end)
        except Exception as e:
            # Keep the raw payload visible when it does not match the schema.
            logger.exception("Failed to parse response: %s", e)
            result.response_text = str(response_data)

        return result

    def _apply_trace(
        self,
        result: DialogueTestResult,
        trace: TraceInfo,
        overall_start: float,
        overall_end: float,
    ) -> None:
        """Copy server-side trace fields (mode, intent, tools, KB, guardrail, timing) onto *result*."""
        result.execution_mode = trace.mode
        result.intent = trace.intent
        result.react_iterations = trace.react_iterations or 0
        result.tools_used = trace.tools_used or []
        result.kb_tool_called = trace.kb_tool_called or False
        result.kb_hit = trace.kb_hit or False
        result.guardrail_triggered = trace.guardrail_triggered or False
        result.fallback_reason_code = trace.fallback_reason_code

        if trace.tool_calls:
            result.tool_calls = [tc.model_dump() for tc in trace.tool_calls]
            self._apply_kb_tool_calls(result, trace.tool_calls)

        if trace.duration_ms:
            # Server-reported processing time, placed within the client-side window.
            result.timing_records.append(TimingRecord(
                stage="server_processing",
                start_time=overall_start,
                end_time=overall_end,
                duration_ms=trace.duration_ms,
            ))

        if trace.scene:
            result.prompt_template_scene = trace.scene

    @staticmethod
    def _apply_kb_tool_calls(result: DialogueTestResult, tool_calls: list) -> None:
        """Derive KB query, filter and hit information from ``kb_search_dynamic`` calls.

        Hit counts are summed over every KB call that returned a dict result.
        ``kb_hit`` is only ever upgraded to True, so a later empty search cannot
        mask an earlier hit. The query and filter come from the last KB call
        that carried arguments.
        """
        kb_calls = [tc for tc in tool_calls if tc.tool_name == "kb_search_dynamic"]
        if not kb_calls:
            return

        result.kb_tool_called = True
        total_hits = 0
        saw_result = False
        for tc in kb_calls:
            if tc.arguments:
                result.kb_query = tc.arguments.get("query")
                result.kb_filter = tc.arguments.get("context")
            if tc.result and isinstance(tc.result, dict):
                saw_result = True
                # `or []` guards against an explicit "hits": None in the payload.
                total_hits += len(tc.result.get("hits") or [])

        if saw_result:
            result.kb_hits_count = total_hits
            result.kb_hit = result.kb_hit or total_hits > 0

    def print_result(self, result: DialogueTestResult):
        """Pretty-print a ``DialogueTestResult`` to stdout for manual inspection.

        The ReAct / tool-call section is printed only for ``ExecutionMode.AGENT``.
        """
        separator = "=" * 80
        print("\n" + separator)
        print(f"[对话测试结果] Request ID: {result.request_id}")
        print(separator)

        print(f"\n[用户消息] {result.user_message}")
        # Append an ellipsis only when the preview actually truncates the text.
        preview = result.response_text[:200]
        suffix = "..." if len(result.response_text) > 200 else ""
        print(f"[回复内容] {preview}{suffix}")

        print(f"\n[执行模式] {result.execution_mode.value}")
        print(f"[意图识别] intent={result.intent}, confidence={result.confidence}")

        print(f"\n[耗时统计] 总耗时: {result.total_duration_ms}ms")
        for tr in result.timing_records:
            print(f" - {tr.stage}: {tr.duration_ms}ms")

        if result.execution_mode == ExecutionMode.AGENT:
            print("\n[ReAct模式]")
            print(f" - 迭代次数: {result.react_iterations}")
            print(f" - 使用的工具: {result.tools_used}")

            if result.tool_calls:
                print("\n[工具调用详情]")
                for i, tc in enumerate(result.tool_calls, 1):
                    print(f" [{i}] 工具: {tc.get('tool_name')}")
                    print(f" 状态: {tc.get('status')}")
                    print(f" 耗时: {tc.get('duration_ms')}ms")
                    if tc.get('arguments'):
                        print(f" 入参: {json.dumps(tc.get('arguments'), ensure_ascii=False)[:200]}")
                    if tc.get('result'):
                        result_str = str(tc.get('result'))[:300]
                        print(f" 结果: {result_str}")

        print("\n[知识库查询]")
        print(f" - 是否调用: {result.kb_tool_called}")
        print(f" - 是否命中: {result.kb_hit}")
        if result.kb_query:
            print(f" - 查询内容: {result.kb_query}")
        if result.kb_filter:
            print(f" - 过滤条件: {json.dumps(result.kb_filter, ensure_ascii=False)[:200]}")
        print(f" - 命中数量: {result.kb_hits_count}")

        print("\n[提示词模板]")
        print(f" - 场景: {result.prompt_template_scene}")
        print(f" - 使用模板: {result.prompt_template_used or '默认模板'}")

        print("\n[其他信息]")
        print(f" - 护栏触发: {result.guardrail_triggered}")
        print(f" - 降级原因: {result.fallback_reason_code or '无'}")

        print("\n" + separator)


class TestMidDialogueIntegration:
    """Live dialogue tests that talk to a service at ``http://localhost:8000``.

    Each test sends one turn through ``DialogueIntegrationTester``, prints the
    collected details, and checks only basic invariants of the result.
    """

    @pytest.fixture
    def tester(self):
        return DialogueIntegrationTester(base_url="http://localhost:8000", tenant_id="test_tenant")

    @pytest.fixture
    def mock_llm_client(self):
        """Mocked LLM client whose ``generate`` returns a plain reply without tool calls."""
        reply = MagicMock(content="这是测试回复", has_tool_calls=False, tool_calls=[])
        llm = AsyncMock()
        llm.generate = AsyncMock(return_value=reply)
        return llm

    @pytest.fixture
    def mock_kb_tool(self):
        """Mocked knowledge-base tool whose ``execute`` returns one successful hit."""
        kb_result = MagicMock(
            success=True,
            hits=[{"id": "1", "content": "测试知识库内容", "score": 0.9}],
            applied_filter={"scene": "test"},
            missing_required_slots=[],
            fallback_reason_code=None,
            duration_ms=100,
            tool_trace=None,
        )
        kb_tool = AsyncMock()
        kb_tool.execute = AsyncMock(return_value=kb_result)
        return kb_tool

    @staticmethod
    async def _send_and_print(tester, **kwargs):
        """Send one dialogue turn through *tester*, print it, and return the result."""
        outcome = await tester.send_dialogue(**kwargs)
        tester.print_result(outcome)
        return outcome

    @pytest.mark.asyncio
    async def test_simple_greeting(self, tester: DialogueIntegrationTester):
        """A simple greeting round-trips and reports a non-zero total duration."""
        outcome = await self._send_and_print(tester, user_message="你好")
        assert outcome.request_id is not None
        assert outcome.total_duration_ms > 0

    @pytest.mark.asyncio
    async def test_kb_query(self, tester: DialogueIntegrationTester):
        """A knowledge-base style question sent with the after-sale scene."""
        outcome = await self._send_and_print(
            tester, user_message="退款流程是什么?", scene="after_sale"
        )
        assert outcome.request_id is not None

    @pytest.mark.asyncio
    async def test_high_risk_scenario(self, tester: DialogueIntegrationTester):
        """A complaint message, expected to exercise the high-risk path."""
        outcome = await self._send_and_print(tester, user_message="我要投诉你们的服务")
        assert outcome.request_id is not None

    @pytest.mark.asyncio
    async def test_transfer_request(self, tester: DialogueIntegrationTester):
        """An explicit request to be transferred to a human agent."""
        outcome = await self._send_and_print(tester, user_message="帮我转人工客服")
        assert outcome.request_id is not None

    @pytest.mark.asyncio
    async def test_with_history(self, tester: DialogueIntegrationTester):
        """A follow-up question that relies on earlier conversation history."""
        prior_turns = [
            {"role": "user", "content": "我想退款"},
            {"role": "assistant", "content": "好的,请问您要退款的订单号是多少?"},
        ]
        outcome = await self._send_and_print(
            tester, user_message="那退款要多久呢?", history=prior_turns
        )
        assert outcome.request_id is not None


class TestDialogueWithMock:
    """Dialogue schema tests that use mocks and in-process objects only."""

    @pytest.fixture
    def mock_app(self):
        """Build a bare FastAPI app that mounts only the mid-platform dialogue router."""
        from fastapi import FastAPI
        from app.api.mid.dialogue import router

        test_app = FastAPI()
        test_app.include_router(router)
        return test_app

    @pytest.fixture
    def client(self, mock_app):
        return TestClient(mock_app)

    @pytest.fixture
    def mock_session(self):
        """Mock async DB session whose ``execute(...).scalars().all()`` yields ``[]``."""
        query_result = MagicMock()
        query_result.scalars.return_value.all.return_value = []
        session = AsyncMock()
        session.execute.return_value = query_result
        return session

    @pytest.fixture
    def mock_llm(self):
        """Mock LLM response carrying plain text and no tool calls."""
        return MagicMock(content="这是测试回复内容", has_tool_calls=False, tool_calls=[])

    def test_dialogue_request_structure(self, client: TestClient):
        """Print and sanity-check a minimal dialogue request body."""
        request_body = dict(
            session_id="test_session_001",
            user_id="test_user_001",
            user_message="你好",
            history=[],
            scene="open_consult",
        )

        print("\n[测试请求结构]")
        print(f"Request Body: {json.dumps(request_body, ensure_ascii=False, indent=2)}")

        assert request_body["session_id"] == "test_session_001"
        assert request_body["user_message"] == "你好"

    def test_trace_info_structure(self):
        """Build a TraceInfo holding one KB tool call and check its key fields."""
        kb_call = ToolCallTrace(
            tool_name="kb_search_dynamic",
            tool_type=ToolType.INTERNAL,
            duration_ms=150,
            status=ToolCallStatus.OK,
            arguments={"query": "测试查询", "scene": "test"},
            result={"hits": [{"content": "测试内容"}]},
        )
        trace = TraceInfo(
            mode=ExecutionMode.AGENT,
            intent="greeting",
            request_id=str(uuid.uuid4()),
            generation_id=str(uuid.uuid4()),
            kb_tool_called=True,
            kb_hit=True,
            react_iterations=2,
            tools_used=["kb_search_dynamic"],
            tool_calls=[kb_call],
        )

        print("\n[TraceInfo 结构测试]")
        overview = [
            ("Mode", trace.mode.value),
            ("Intent", trace.intent),
            ("KB Tool Called", trace.kb_tool_called),
            ("KB Hit", trace.kb_hit),
            ("React Iterations", trace.react_iterations),
            ("Tools Used", trace.tools_used),
        ]
        for label, value in overview:
            print(f"{label}: {value}")

        if trace.tool_calls:
            print("\n[Tool Calls]")
            for call in trace.tool_calls:
                print(f" - Tool: {call.tool_name}")
                print(f" Status: {call.status.value}")
                print(f" Duration: {call.duration_ms}ms")
                if call.arguments:
                    print(f" Arguments: {json.dumps(call.arguments, ensure_ascii=False)}")

        assert trace.mode == ExecutionMode.AGENT
        assert trace.kb_tool_called is True
        assert len(trace.tool_calls) == 1

    def test_intent_hint_output_structure(self):
        """Build an IntentHintOutput suggesting a micro-flow and check its fields."""
        hint = IntentHintOutput(
            intent="refund",
            confidence=0.85,
            response_type="flow",
            suggested_mode=ExecutionMode.MICRO_FLOW,
            target_flow_id="flow_refund_001",
            high_risk_detected=False,
            duration_ms=50,
        )

        suggested = hint.suggested_mode.value if hint.suggested_mode else None
        print("\n[IntentHintOutput 结构测试]")
        print(f"Intent: {hint.intent}")
        print(f"Confidence: {hint.confidence}")
        print(f"Response Type: {hint.response_type}")
        print(f"Suggested Mode: {suggested}")
        print(f"High Risk Detected: {hint.high_risk_detected}")
        print(f"Duration: {hint.duration_ms}ms")

        assert hint.intent == "refund"
        assert hint.confidence == 0.85
        assert hint.suggested_mode == ExecutionMode.MICRO_FLOW

    def test_high_risk_check_result_structure(self):
        """Build a matched refund HighRiskCheckResult and check its fields."""
        check = HighRiskCheckResult(
            matched=True,
            risk_scenario=HighRiskScenario.REFUND,
            confidence=0.95,
            recommended_mode=ExecutionMode.MICRO_FLOW,
            rule_id="rule_refund_001",
            reason="检测到退款关键词",
            duration_ms=30,
        )

        scenario = check.risk_scenario.value if check.risk_scenario else None
        recommended = check.recommended_mode.value if check.recommended_mode else None
        print("\n[HighRiskCheckResult 结构测试]")
        print(f"Matched: {check.matched}")
        print(f"Risk Scenario: {scenario}")
        print(f"Confidence: {check.confidence}")
        print(f"Recommended Mode: {recommended}")
        print(f"Rule ID: {check.rule_id}")
        print(f"Duration: {check.duration_ms}ms")

        assert check.matched is True
        assert check.risk_scenario == HighRiskScenario.REFUND


class TestPromptTemplateUsage:
    """Tests for prompt-template variable substitution via ``VariableResolver``."""

    def test_template_resolution(self):
        """``{"key", "value"}`` variable pairs replace ``{{key}}`` placeholders."""
        from app.services.prompt.variable_resolver import VariableResolver

        template = "你好,{{user_name}}!我是{{bot_name}},很高兴为您服务。"
        variables = [
            {"key": "user_name", "value": "张三"},
            {"key": "bot_name", "value": "智能客服"},
        ]
        resolved = VariableResolver().resolve(template, variables)

        print("\n[模板解析测试]")
        print(f"原始模板: {template}")
        print(f"变量: {json.dumps(variables, ensure_ascii=False)}")
        print(f"解析结果: {resolved}")

        assert resolved == "你好,张三!我是智能客服,很高兴为您服务。"

    def test_template_with_extra_context(self):
        """With no variables, placeholders are filled from the extra-context mapping."""
        from app.services.prompt.variable_resolver import VariableResolver

        template = "当前场景:{{scene}},用户问题:{{query}}"
        extra_context = dict(scene="售后服务", query="退款流程")
        resolved = VariableResolver().resolve(template, [], extra_context)

        print("\n[带上下文的模板解析测试]")
        print(f"原始模板: {template}")
        print(f"额外上下文: {json.dumps(extra_context, ensure_ascii=False)}")
        print(f"解析结果: {resolved}")

        for expected_fragment in ("售后服务", "退款流程"):
            assert expected_fragment in resolved


class TestToolCallRecording:
    """Tests for how tool calls are recorded in ``ToolCallTrace``."""

    def test_tool_call_trace_creation(self):
        """A successful KB call keeps its full arguments and result payload."""
        hits = [
            {"id": "1", "content": "退款流程说明...", "score": 0.95},
            {"id": "2", "content": "退款注意事项...", "score": 0.88},
            {"id": "3", "content": "退款时效说明...", "score": 0.82},
        ]
        trace = ToolCallTrace(
            tool_name="kb_search_dynamic",
            tool_type=ToolType.INTERNAL,
            duration_ms=150,
            status=ToolCallStatus.OK,
            args_digest="query=退款流程",
            result_digest="hits=3",
            arguments={
                "query": "退款流程是什么",
                "scene": "after_sale",
                "context": {"product_type": "vip"},
            },
            result={
                "success": True,
                "hits": hits,
                "applied_filter": {"product_type": "vip"},
            },
        )

        print("\n[工具调用追踪测试]")
        print(f"Tool Name: {trace.tool_name}")
        print(f"Tool Type: {trace.tool_type.value}")
        print(f"Status: {trace.status.value}")
        print(f"Duration: {trace.duration_ms}ms")

        print("\n[入参详情]")
        for key, value in (trace.arguments or {}).items():
            print(f" - {key}: {value}")

        print("\n[返回结果]")
        payload = trace.result
        if payload and isinstance(payload, dict):
            print(f" - success: {payload.get('success')}")
            returned_hits = payload.get('hits', [])
            print(f" - hits count: {len(returned_hits)}")
            for idx, hit in enumerate(returned_hits[:2], start=1):
                print(f" - hit[{idx}]: score={hit.get('score')}, content={hit.get('content')[:30]}...")

        assert trace.tool_name == "kb_search_dynamic"
        assert trace.status == ToolCallStatus.OK
        assert trace.arguments is not None
        assert trace.result is not None

    def test_tool_call_timeout_trace(self):
        """A timed-out call records the TIMEOUT status and its error code."""
        trace = ToolCallTrace(
            tool_name="kb_search_dynamic",
            tool_type=ToolType.INTERNAL,
            duration_ms=2000,
            status=ToolCallStatus.TIMEOUT,
            error_code="TOOL_TIMEOUT",
            arguments={"query": "测试查询"},
        )

        print("\n[工具调用超时追踪测试]")
        for line in (
            f"Tool Name: {trace.tool_name}",
            f"Status: {trace.status.value}",
            f"Error Code: {trace.error_code}",
            f"Duration: {trace.duration_ms}ms",
        ):
            print(line)

        assert trace.status == ToolCallStatus.TIMEOUT
        assert trace.error_code == "TOOL_TIMEOUT"


class TestExecutionModeRouting:
    """Tests for how ``PolicyRouter`` picks an execution mode."""

    def test_policy_router_decision(self):
        """Each (message, session mode) pair is routed to the expected mode."""
        from app.services.mid.policy_router import PolicyRouter, IntentMatch

        policy_router = PolicyRouter()

        # (case name, user message, session mode, expected execution mode)
        cases = [
            ("正常对话 -> Agent模式", "你好,请问有什么可以帮助我的?", "BOT_ACTIVE", ExecutionMode.AGENT),
            ("高风险退款 -> Micro Flow模式", "我要退款", "BOT_ACTIVE", ExecutionMode.MICRO_FLOW),
            ("转人工请求 -> Transfer模式", "帮我转人工", "BOT_ACTIVE", ExecutionMode.TRANSFER),
            ("人工模式 -> Transfer模式", "你好", "HUMAN_ACTIVE", ExecutionMode.TRANSFER),
        ]

        print("\n[策略路由器决策测试]")

        for case_name, message, session_mode, expected in cases:
            decision = policy_router.route(
                user_message=message,
                session_mode=session_mode,
            )

            print(f"\n测试用例: {case_name}")
            print(f" 用户消息: {message}")
            print(f" 会话模式: {session_mode}")
            print(f" 期望模式: {expected.value}")
            print(f" 实际模式: {decision.mode.value}")

            assert decision.mode == expected, f"模式不匹配: {decision.mode} != {expected}"


class TestKBSearchDynamic:
    """Tests for the ``KbSearchDynamicResult`` payload shape."""

    def test_kb_search_result_structure(self):
        """A successful search exposes its hits, applied filter and filter sources."""
        from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicResult

        chunks = [
            {
                "id": "chunk_001",
                "content": "退款流程:1. 登录账户 2. 进入订单页面 3. 点击退款按钮...",
                "score": 0.92,
                "metadata": {"kb_id": "kb_001", "doc_id": "doc_001"},
            },
            {
                "id": "chunk_002",
                "content": "退款时效:一般3-5个工作日到账...",
                "score": 0.85,
                "metadata": {"kb_id": "kb_001", "doc_id": "doc_002"},
            },
        ]
        search = KbSearchDynamicResult(
            success=True,
            hits=chunks,
            applied_filter={"scene": "after_sale", "product_type": "vip"},
            missing_required_slots=[],
            filter_debug={"source": "slot_state"},
            filter_sources={"scene": "slot", "product_type": "context"},
            duration_ms=120,
        )

        print("\n[知识库检索结果结构测试]")
        print(f"Success: {search.success}")
        print(f"Hits Count: {len(search.hits)}")
        print(f"Applied Filter: {json.dumps(search.applied_filter, ensure_ascii=False)}")
        print(f"Filter Sources: {json.dumps(search.filter_sources, ensure_ascii=False)}")
        print(f"Duration: {search.duration_ms}ms")

        print("\n[命中详情]")
        for position, hit in enumerate(search.hits, start=1):
            print(f" [{position}] ID: {hit['id']}")
            print(f" Score: {hit['score']}")
            print(f" Content: {hit['content'][:50]}...")

        assert search.success is True
        assert len(search.hits) == 2
        assert search.applied_filter is not None

    def test_kb_search_missing_slots(self):
        """A search blocked by a missing required slot reports it with a fallback code."""
        from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicResult

        missing = [
            {
                "field_key": "order_id",
                "label": "订单号",
                "reason": "必填字段缺失",
                "ask_back_prompt": "请提供您的订单号",
            },
        ]
        search = KbSearchDynamicResult(
            success=False,
            hits=[],
            applied_filter={},
            missing_required_slots=missing,
            filter_debug={"source": "builder"},
            fallback_reason_code="MISSING_REQUIRED_SLOTS",
            duration_ms=50,
        )

        print("\n[知识库检索缺失槽位测试]")
        print(f"Success: {search.success}")
        print(f"Fallback Reason: {search.fallback_reason_code}")
        print(f"Missing Slots: {json.dumps(search.missing_required_slots, ensure_ascii=False, indent=2)}")

        assert search.success is False
        assert search.fallback_reason_code == "MISSING_REQUIRED_SLOTS"
        assert len(search.missing_required_slots) == 1


class TestTimingBreakdown:
    """Tests for the per-stage timing table layout and totals."""

    def test_timing_breakdown_structure(self):
        """Print a sample stage breakdown and check that the durations sum to 1600 ms."""
        # (stage, start_time, end_time, duration_ms)
        stage_specs = (
            ("intent_matching", 0, 0.05, 50),
            ("high_risk_check", 0.05, 0.08, 30),
            ("kb_search", 0.08, 0.2, 120),
            ("llm_generation", 0.2, 1.5, 1300),
            ("output_guardrail", 1.5, 1.55, 50),
            ("response_formatting", 1.55, 1.6, 50),
        )
        timings = [TimingRecord(*spec) for spec in stage_specs]
        total = sum(record.duration_ms for record in timings)

        rule = "-" * 45
        print("\n[耗时分解测试]")
        print(f"{'阶段':<25} {'耗时(ms)':<10} {'占比':<10}")
        print(rule)

        for record in timings:
            share = record.duration_ms / total * 100 if total > 0 else 0
            print(f"{record.stage:<25} {record.duration_ms:<10} {share:.1f}%")

        print(rule)
        print(f"{'总计':<25} {total:<10} {'100.0%':<10}")

        assert total == 1600


def run_manual_test():
    """Command-line entry point for ad-hoc testing against a live service.

    By default a single ``--message`` is sent and its result printed. With
    ``--interactive`` the user is prompted repeatedly for a message and an
    optional scene. Input is skipped when the message is blank, and the loop
    ends on ``quit``, Ctrl-C, or end of input (Ctrl-D / exhausted pipe).
    """
    import argparse

    parser = argparse.ArgumentParser(description="中台对话集成测试")
    parser.add_argument("--url", default="http://localhost:8000", help="服务地址")
    parser.add_argument("--tenant", default="test_tenant", help="租户ID")
    parser.add_argument("--api-key", default=None, help="API Key")
    parser.add_argument("--message", default="你好", help="测试消息")
    parser.add_argument("--scene", default=None, help="场景标识")
    parser.add_argument("--interactive", action="store_true", help="交互模式")

    args = parser.parse_args()

    tester = DialogueIntegrationTester(
        base_url=args.url,
        tenant_id=args.tenant,
        api_key=args.api_key,
    )

    if not args.interactive:
        result = asyncio.run(tester.send_dialogue(
            user_message=args.message,
            scene=args.scene,
        ))
        tester.print_result(result)
        return

    print("\n=== 中台对话集成测试 - 交互模式 ===")
    print("输入 'quit' 退出\n")

    while True:
        try:
            message = input("请输入消息: ").strip()
            if message.lower() == "quit":
                break
            if not message:
                # Nothing to send; prompt again rather than posting an empty turn.
                continue

            scene = input("请输入场景(可选,直接回车跳过): ").strip() or None

            result = asyncio.run(tester.send_dialogue(
                user_message=message,
                scene=scene,
            ))

            tester.print_result(result)

        except (KeyboardInterrupt, EOFError):
            # input() raises EOFError on Ctrl-D or when piped input runs out;
            # treat it like Ctrl-C and leave the loop cleanly.
            break


if __name__ == "__main__":
    # Executed directly as a script: run the argparse-driven manual tester.
    run_manual_test()