"""
Mid Platform Dialogue Integration Test.
中台联调界面对话过程集成测试脚本
测试重点:
1. 意图置信度参数
2. 执行模式通用API vs ReAct模式
3. ReAct模式下的工具调用工具名称、入参、返回结果
4. 知识库查询(是否命中、入参)
5. 各部分耗时
6. 提示词模板使用情况
"""
import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi.testclient import TestClient
from httpx import AsyncClient

from app.main import app
from app.models.mid.schemas import (
    DialogueRequest,
    DialogueResponse,
    ExecutionMode,
    FeatureFlags,
    HistoryMessage,
    Segment,
    TraceInfo,
    ToolCallTrace,
    ToolCallStatus,
    ToolType,
    IntentHintOutput,
    HighRiskCheckResult,
    HighRiskScenario,
)

logger = logging.getLogger(__name__)


@dataclass
class TimingRecord:
    """A latency record for a single stage."""

    stage: str
    start_time: float
    end_time: float
    duration_ms: int

    def to_dict(self) -> dict:
        return {
            "stage": self.stage,
            "duration_ms": self.duration_ms,
        }
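
# Example (to_dict() reports only the stage and its duration):
#   TimingRecord("llm_generation", 0.2, 1.5, 1300).to_dict()
#   -> {"stage": "llm_generation", "duration_ms": 1300}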


@dataclass
class DialogueTestResult:
    """Result of a single dialogue test run."""

    request_id: str
    user_message: str
    response_text: str = ""
    timing_records: list[TimingRecord] = field(default_factory=list)
    total_duration_ms: int = 0
    execution_mode: ExecutionMode = ExecutionMode.AGENT
    intent: str | None = None
    confidence: float | None = None
    react_iterations: int = 0
    tools_used: list[str] = field(default_factory=list)
    tool_calls: list[dict] = field(default_factory=list)
    kb_tool_called: bool = False
    kb_hit: bool = False
    kb_query: str | None = None
    kb_filter: dict | None = None
    kb_hits_count: int = 0
    prompt_template_used: str | None = None
    prompt_template_scene: str | None = None
    guardrail_triggered: bool = False
    fallback_reason_code: str | None = None
    raw_trace: dict | None = None

    def to_summary(self) -> dict:
        return {
            "request_id": self.request_id,
            "user_message": self.user_message[:100],
            "execution_mode": self.execution_mode.value,
            "intent": self.intent,
            "confidence": self.confidence,
            "total_duration_ms": self.total_duration_ms,
            "timing_breakdown": [t.to_dict() for t in self.timing_records],
            "react_iterations": self.react_iterations,
            "tools_used": self.tools_used,
            "tool_calls_count": len(self.tool_calls),
            "kb_tool_called": self.kb_tool_called,
            "kb_hit": self.kb_hit,
            "kb_hits_count": self.kb_hits_count,
            "prompt_template_used": self.prompt_template_used,
            "guardrail_triggered": self.guardrail_triggered,
            "fallback_reason_code": self.fallback_reason_code,
        }
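
# Illustrative to_summary() output, with made-up values (the "agent" string
# for ExecutionMode.AGENT is an assumption, not verified against the schema):
#
#     {
#         "request_id": "c0ffee00-...",
#         "user_message": "退款流程是什么?",
#         "execution_mode": "agent",
#         "intent": "refund",
#         "confidence": 0.85,
#         "total_duration_ms": 1600,
#         "timing_breakdown": [{"stage": "http_request", "duration_ms": 1450}],
#         "react_iterations": 2,
#         "tools_used": ["kb_search_dynamic"],
#         "kb_tool_called": True,
#         "kb_hit": True,
#         "kb_hits_count": 2,
#         ...
#     }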


class DialogueIntegrationTester:
    """Driver for dialogue integration tests."""

    def __init__(
        self,
        base_url: str = "http://localhost:8000",
        tenant_id: str = "test_tenant",
        api_key: str | None = None,
    ):
        self.base_url = base_url
        self.tenant_id = tenant_id
        self.api_key = api_key
        self.session_id = f"test_session_{uuid.uuid4().hex[:8]}"
        self.user_id = f"test_user_{uuid.uuid4().hex[:8]}"

    def _get_headers(self) -> dict:
        headers = {
            "Content-Type": "application/json",
            "X-Tenant-Id": self.tenant_id,
        }
        if self.api_key:
            headers["X-API-Key"] = self.api_key
        return headers

    async def send_dialogue(
        self,
        user_message: str,
        history: list[dict] | None = None,
        scene: str | None = None,
        feature_flags: dict | None = None,
    ) -> DialogueTestResult:
        """Send a dialogue request and record detailed diagnostics."""
        request_id = str(uuid.uuid4())
        result = DialogueTestResult(
            request_id=request_id,
            user_message=user_message,
            response_text="",
        )
        overall_start = time.time()

        request_body = {
            "session_id": self.session_id,
            "user_id": self.user_id,
            "user_message": user_message,
            "history": history or [],
        }
        if scene:
            request_body["scene"] = scene
        if feature_flags:
            request_body["feature_flags"] = feature_flags

        timing_records = []
        try:
            async with AsyncClient(base_url=self.base_url, timeout=120.0) as client:
                request_start = time.time()
                response = await client.post(
                    "/mid/dialogue/respond",
                    json=request_body,
                    headers=self._get_headers(),
                )
                request_end = time.time()
                timing_records.append(TimingRecord(
                    stage="http_request",
                    start_time=request_start,
                    end_time=request_end,
                    duration_ms=int((request_end - request_start) * 1000),
                ))
                if response.status_code != 200:
                    result.response_text = f"Error: {response.status_code}"
                    result.fallback_reason_code = f"http_error_{response.status_code}"
                    return result
                response_data = response.json()
        except Exception as e:
            result.response_text = f"Exception: {str(e)}"
            result.fallback_reason_code = "request_exception"
            result.total_duration_ms = int((time.time() - overall_start) * 1000)
            return result

        overall_end = time.time()
        result.total_duration_ms = int((overall_end - overall_start) * 1000)
        result.timing_records = timing_records

        try:
            dialogue_response = DialogueResponse(**response_data)
            if dialogue_response.segments:
                result.response_text = "\n".join(s.text for s in dialogue_response.segments)
            trace = dialogue_response.trace
            result.raw_trace = trace.model_dump() if trace else None
            if trace:
                result.execution_mode = trace.mode
                result.intent = trace.intent
                result.react_iterations = trace.react_iterations or 0
                result.tools_used = trace.tools_used or []
                result.kb_tool_called = trace.kb_tool_called or False
                result.kb_hit = trace.kb_hit or False
                result.guardrail_triggered = trace.guardrail_triggered or False
                result.fallback_reason_code = trace.fallback_reason_code
                if trace.tool_calls:
                    result.tool_calls = [tc.model_dump() for tc in trace.tool_calls]
                    # Extract knowledge-base query details from the KB tool call.
                    for tc in trace.tool_calls:
                        if tc.tool_name == "kb_search_dynamic":
                            result.kb_tool_called = True
                            if tc.arguments:
                                result.kb_query = tc.arguments.get("query")
                                result.kb_filter = tc.arguments.get("context")
                            if tc.result and isinstance(tc.result, dict):
                                result.kb_hits_count = len(tc.result.get("hits", []))
                                result.kb_hit = result.kb_hits_count > 0
                if trace.duration_ms:
                    timing_records.append(TimingRecord(
                        stage="server_processing",
                        start_time=overall_start,
                        end_time=overall_end,
                        duration_ms=trace.duration_ms,
                    ))
                if trace.scene:
                    result.prompt_template_scene = trace.scene
        except Exception as e:
            logger.error(f"Failed to parse response: {e}")
            result.response_text = str(response_data)
        return result
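
    # send_dialogue records two timing views: "http_request" measured on the
    # client, and "server_processing" taken from trace.duration_ms when the
    # server reports it. The gap between the two approximates network and
    # serialization overhead.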

    def print_result(self, result: DialogueTestResult):
        """Pretty-print a test result."""
        print("\n" + "=" * 80)
        print(f"[Dialogue Test Result] Request ID: {result.request_id}")
        print("=" * 80)
        print(f"\n[User Message] {result.user_message}")
        print(f"[Response] {result.response_text[:200]}...")
        print(f"\n[Execution Mode] {result.execution_mode.value}")
        print(f"[Intent] intent={result.intent}, confidence={result.confidence}")
        print(f"\n[Timing] total: {result.total_duration_ms}ms")
        for tr in result.timing_records:
            print(f"  - {tr.stage}: {tr.duration_ms}ms")
        if result.execution_mode == ExecutionMode.AGENT:
            print("\n[ReAct Mode]")
            print(f"  - iterations: {result.react_iterations}")
            print(f"  - tools used: {result.tools_used}")
        if result.tool_calls:
            print("\n[Tool Call Details]")
            for i, tc in enumerate(result.tool_calls, 1):
                print(f"  [{i}] tool: {tc.get('tool_name')}")
                print(f"      status: {tc.get('status')}")
                print(f"      duration: {tc.get('duration_ms')}ms")
                if tc.get('arguments'):
                    print(f"      arguments: {json.dumps(tc.get('arguments'), ensure_ascii=False)[:200]}")
                if tc.get('result'):
                    result_str = str(tc.get('result'))[:300]
                    print(f"      result: {result_str}")
        print("\n[Knowledge Base Query]")
        print(f"  - called: {result.kb_tool_called}")
        print(f"  - hit: {result.kb_hit}")
        if result.kb_query:
            print(f"  - query: {result.kb_query}")
        if result.kb_filter:
            print(f"  - filter: {json.dumps(result.kb_filter, ensure_ascii=False)[:200]}")
        print(f"  - hits count: {result.kb_hits_count}")
        print("\n[Prompt Template]")
        print(f"  - scene: {result.prompt_template_scene}")
        print(f"  - template used: {result.prompt_template_used or 'default template'}")
        print("\n[Other]")
        print(f"  - guardrail triggered: {result.guardrail_triggered}")
        print(f"  - fallback reason: {result.fallback_reason_code or ''}")
        print("\n" + "=" * 80)


class TestMidDialogueIntegration:
    """Mid-platform dialogue integration tests."""

    @pytest.fixture
    def tester(self):
        return DialogueIntegrationTester(
            base_url="http://localhost:8000",
            tenant_id="test_tenant",
        )

    @pytest.fixture
    def mock_llm_client(self):
        """Mock LLM client."""
        mock = AsyncMock()
        mock.generate = AsyncMock(return_value=MagicMock(
            content="这是测试回复",
            has_tool_calls=False,
            tool_calls=[],
        ))
        return mock

    @pytest.fixture
    def mock_kb_tool(self):
        """Mock knowledge-base tool."""
        mock = AsyncMock()
        mock.execute = AsyncMock(return_value=MagicMock(
            success=True,
            hits=[
                {"id": "1", "content": "测试知识库内容", "score": 0.9},
            ],
            applied_filter={"scene": "test"},
            missing_required_slots=[],
            fallback_reason_code=None,
            duration_ms=100,
            tool_trace=None,
        ))
        return mock
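
    # Note: mock_llm_client and mock_kb_tool are not wired into the
    # live-server tests below; they are provided for in-process patching.
    # A minimal sketch (the patch target is hypothetical -- substitute the
    # real module path of the LLM client in this codebase):
    #
    #     with patch("app.services.llm_client.LLMClient.generate",
    #                mock_llm_client.generate):
    #         ...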

    @pytest.mark.asyncio
    async def test_simple_greeting(self, tester: DialogueIntegrationTester):
        """Test a simple greeting."""
        result = await tester.send_dialogue(
            user_message="你好",
        )
        tester.print_result(result)
        assert result.request_id is not None
        assert result.total_duration_ms > 0

    @pytest.mark.asyncio
    async def test_kb_query(self, tester: DialogueIntegrationTester):
        """Test a knowledge-base query."""
        result = await tester.send_dialogue(
            user_message="退款流程是什么?",
            scene="after_sale",
        )
        tester.print_result(result)
        assert result.request_id is not None

    @pytest.mark.asyncio
    async def test_high_risk_scenario(self, tester: DialogueIntegrationTester):
        """Test a high-risk scenario."""
        result = await tester.send_dialogue(
            user_message="我要投诉你们的服务",
        )
        tester.print_result(result)
        assert result.request_id is not None

    @pytest.mark.asyncio
    async def test_transfer_request(self, tester: DialogueIntegrationTester):
        """Test a transfer-to-human request."""
        result = await tester.send_dialogue(
            user_message="帮我转人工客服",
        )
        tester.print_result(result)
        assert result.request_id is not None

    @pytest.mark.asyncio
    async def test_with_history(self, tester: DialogueIntegrationTester):
        """Test a dialogue that carries conversation history."""
        result = await tester.send_dialogue(
            user_message="那退款要多久呢?",
            history=[
                {"role": "user", "content": "我想退款"},
                {"role": "assistant", "content": "好的,请问您要退款的订单号是多少?"},
            ],
        )
        tester.print_result(result)
        assert result.request_id is not None


class TestDialogueWithMock:
    """Dialogue tests using mocks."""

    @pytest.fixture
    def mock_app(self):
        """Create a test application with the dialogue router mounted."""
        from fastapi import FastAPI
        from app.api.mid.dialogue import router

        app = FastAPI()
        app.include_router(router)
        return app

    @pytest.fixture
    def client(self, mock_app):
        return TestClient(mock_app)

    @pytest.fixture
    def mock_session(self):
        """Mock database session."""
        mock = AsyncMock()
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = []
        mock.execute.return_value = mock_result
        return mock

    @pytest.fixture
    def mock_llm(self):
        """Mock LLM response."""
        mock_response = MagicMock()
        mock_response.content = "这是测试回复内容"
        mock_response.has_tool_calls = False
        mock_response.tool_calls = []
        return mock_response

    def test_dialogue_request_structure(self, client: TestClient):
        """Validate the dialogue request structure."""
        request_body = {
            "session_id": "test_session_001",
            "user_id": "test_user_001",
            "user_message": "你好",
            "history": [],
            "scene": "open_consult",
        }
        print("\n[Request Structure Test]")
        print(f"Request Body: {json.dumps(request_body, ensure_ascii=False, indent=2)}")
        assert request_body["session_id"] == "test_session_001"
        assert request_body["user_message"] == "你好"

    def test_trace_info_structure(self):
        """Validate the trace-info structure."""
        trace = TraceInfo(
            mode=ExecutionMode.AGENT,
            intent="greeting",
            request_id=str(uuid.uuid4()),
            generation_id=str(uuid.uuid4()),
            kb_tool_called=True,
            kb_hit=True,
            react_iterations=2,
            tools_used=["kb_search_dynamic"],
            tool_calls=[
                ToolCallTrace(
                    tool_name="kb_search_dynamic",
                    tool_type=ToolType.INTERNAL,
                    duration_ms=150,
                    status=ToolCallStatus.OK,
                    arguments={"query": "测试查询", "scene": "test"},
                    result={"hits": [{"content": "测试内容"}]},
                ),
            ],
        )
        print("\n[TraceInfo Structure Test]")
        print(f"Mode: {trace.mode.value}")
        print(f"Intent: {trace.intent}")
        print(f"KB Tool Called: {trace.kb_tool_called}")
        print(f"KB Hit: {trace.kb_hit}")
        print(f"React Iterations: {trace.react_iterations}")
        print(f"Tools Used: {trace.tools_used}")
        if trace.tool_calls:
            print("\n[Tool Calls]")
            for tc in trace.tool_calls:
                print(f"  - Tool: {tc.tool_name}")
                print(f"    Status: {tc.status.value}")
                print(f"    Duration: {tc.duration_ms}ms")
                if tc.arguments:
                    print(f"    Arguments: {json.dumps(tc.arguments, ensure_ascii=False)}")
        assert trace.mode == ExecutionMode.AGENT
        assert trace.kb_tool_called is True
        assert len(trace.tool_calls) == 1

    def test_intent_hint_output_structure(self):
        """Validate the intent-hint output structure."""
        hint = IntentHintOutput(
            intent="refund",
            confidence=0.85,
            response_type="flow",
            suggested_mode=ExecutionMode.MICRO_FLOW,
            target_flow_id="flow_refund_001",
            high_risk_detected=False,
            duration_ms=50,
        )
        print("\n[IntentHintOutput Structure Test]")
        print(f"Intent: {hint.intent}")
        print(f"Confidence: {hint.confidence}")
        print(f"Response Type: {hint.response_type}")
        print(f"Suggested Mode: {hint.suggested_mode.value if hint.suggested_mode else None}")
        print(f"High Risk Detected: {hint.high_risk_detected}")
        print(f"Duration: {hint.duration_ms}ms")
        assert hint.intent == "refund"
        assert hint.confidence == 0.85
        assert hint.suggested_mode == ExecutionMode.MICRO_FLOW

    def test_high_risk_check_result_structure(self):
        """Validate the high-risk check result structure."""
        result = HighRiskCheckResult(
            matched=True,
            risk_scenario=HighRiskScenario.REFUND,
            confidence=0.95,
            recommended_mode=ExecutionMode.MICRO_FLOW,
            rule_id="rule_refund_001",
            reason="检测到退款关键词",
            duration_ms=30,
        )
        print("\n[HighRiskCheckResult Structure Test]")
        print(f"Matched: {result.matched}")
        print(f"Risk Scenario: {result.risk_scenario.value if result.risk_scenario else None}")
        print(f"Confidence: {result.confidence}")
        print(f"Recommended Mode: {result.recommended_mode.value if result.recommended_mode else None}")
        print(f"Rule ID: {result.rule_id}")
        print(f"Duration: {result.duration_ms}ms")
        assert result.matched is True
        assert result.risk_scenario == HighRiskScenario.REFUND


class TestPromptTemplateUsage:
    """Prompt template usage tests."""

    def test_template_resolution(self):
        """Test template resolution."""
        from app.services.prompt.variable_resolver import VariableResolver

        resolver = VariableResolver()
        template = "你好,{{user_name}}!我是{{bot_name}},很高兴为您服务。"
        variables = [
            {"key": "user_name", "value": "张三"},
            {"key": "bot_name", "value": "智能客服"},
        ]
        resolved = resolver.resolve(template, variables)
        print("\n[Template Resolution Test]")
        print(f"Template: {template}")
        print(f"Variables: {json.dumps(variables, ensure_ascii=False)}")
        print(f"Resolved: {resolved}")
        assert resolved == "你好,张三!我是智能客服,很高兴为您服务。"

    def test_template_with_extra_context(self):
        """Test template resolution with extra context."""
        from app.services.prompt.variable_resolver import VariableResolver

        resolver = VariableResolver()
        template = "当前场景:{{scene}},用户问题:{{query}}"
        extra_context = {
            "scene": "售后服务",
            "query": "退款流程",
        }
        resolved = resolver.resolve(template, [], extra_context)
        print("\n[Template Resolution With Context Test]")
        print(f"Template: {template}")
        print(f"Extra Context: {json.dumps(extra_context, ensure_ascii=False)}")
        print(f"Resolved: {resolved}")
        assert "售后服务" in resolved
        assert "退款流程" in resolved


class TestToolCallRecording:
    """Tool-call recording tests."""

    def test_tool_call_trace_creation(self):
        """Test creation of a tool-call trace."""
        trace = ToolCallTrace(
            tool_name="kb_search_dynamic",
            tool_type=ToolType.INTERNAL,
            duration_ms=150,
            status=ToolCallStatus.OK,
            args_digest="query=退款流程",
            result_digest="hits=3",
            arguments={
                "query": "退款流程是什么",
                "scene": "after_sale",
                "context": {"product_type": "vip"},
            },
            result={
                "success": True,
                "hits": [
                    {"id": "1", "content": "退款流程说明...", "score": 0.95},
                    {"id": "2", "content": "退款注意事项...", "score": 0.88},
                    {"id": "3", "content": "退款时效说明...", "score": 0.82},
                ],
                "applied_filter": {"product_type": "vip"},
            },
        )
        print("\n[Tool Call Trace Test]")
        print(f"Tool Name: {trace.tool_name}")
        print(f"Tool Type: {trace.tool_type.value}")
        print(f"Status: {trace.status.value}")
        print(f"Duration: {trace.duration_ms}ms")
        print("\n[Arguments]")
        if trace.arguments:
            for key, value in trace.arguments.items():
                print(f"  - {key}: {value}")
        print("\n[Result]")
        if trace.result:
            if isinstance(trace.result, dict):
                print(f"  - success: {trace.result.get('success')}")
                print(f"  - hits count: {len(trace.result.get('hits', []))}")
                for i, hit in enumerate(trace.result.get('hits', [])[:2], 1):
                    print(f"  - hit[{i}]: score={hit.get('score')}, content={hit.get('content')[:30]}...")
        assert trace.tool_name == "kb_search_dynamic"
        assert trace.status == ToolCallStatus.OK
        assert trace.arguments is not None
        assert trace.result is not None

    def test_tool_call_timeout_trace(self):
        """Test a tool-call trace that timed out."""
        trace = ToolCallTrace(
            tool_name="kb_search_dynamic",
            tool_type=ToolType.INTERNAL,
            duration_ms=2000,
            status=ToolCallStatus.TIMEOUT,
            error_code="TOOL_TIMEOUT",
            arguments={"query": "测试查询"},
        )
        print("\n[Tool Call Timeout Trace Test]")
        print(f"Tool Name: {trace.tool_name}")
        print(f"Status: {trace.status.value}")
        print(f"Error Code: {trace.error_code}")
        print(f"Duration: {trace.duration_ms}ms")
        assert trace.status == ToolCallStatus.TIMEOUT
        assert trace.error_code == "TOOL_TIMEOUT"


class TestExecutionModeRouting:
    """Execution-mode routing tests."""

    def test_policy_router_decision(self):
        """Test policy router decisions."""
        from app.services.mid.policy_router import PolicyRouter, IntentMatch

        router = PolicyRouter()
        test_cases = [
            {
                "name": "normal dialogue -> Agent mode",
                "user_message": "你好,请问有什么可以帮助我的?",
                "session_mode": "BOT_ACTIVE",
                "expected_mode": ExecutionMode.AGENT,
            },
            {
                "name": "high-risk refund -> Micro Flow mode",
                "user_message": "我要退款",
                "session_mode": "BOT_ACTIVE",
                "expected_mode": ExecutionMode.MICRO_FLOW,
            },
            {
                "name": "transfer request -> Transfer mode",
                "user_message": "帮我转人工",
                "session_mode": "BOT_ACTIVE",
                "expected_mode": ExecutionMode.TRANSFER,
            },
            {
                "name": "human-active session -> Transfer mode",
                "user_message": "你好",
                "session_mode": "HUMAN_ACTIVE",
                "expected_mode": ExecutionMode.TRANSFER,
            },
        ]
        print("\n[Policy Router Decision Test]")
        for tc in test_cases:
            result = router.route(
                user_message=tc["user_message"],
                session_mode=tc["session_mode"],
            )
            print(f"\nCase: {tc['name']}")
            print(f"  user message: {tc['user_message']}")
            print(f"  session mode: {tc['session_mode']}")
            print(f"  expected mode: {tc['expected_mode'].value}")
            print(f"  actual mode: {result.mode.value}")
            assert result.mode == tc["expected_mode"], f"mode mismatch: {result.mode} != {tc['expected_mode']}"
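
    # The cases above imply a precedence the router is expected to honor:
    # HUMAN_ACTIVE sessions route to TRANSFER regardless of message content,
    # explicit transfer requests go to TRANSFER, high-risk intents go to
    # MICRO_FLOW, and everything else falls through to AGENT. This ordering
    # is inferred from the fixtures here, not from the PolicyRouter source.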


class TestKBSearchDynamic:
    """Dynamic knowledge-base search tests."""

    def test_kb_search_result_structure(self):
        """Validate the KB search result structure."""
        from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicResult

        result = KbSearchDynamicResult(
            success=True,
            hits=[
                {
                    "id": "chunk_001",
                    "content": "退款流程1. 登录账户 2. 进入订单页面 3. 点击退款按钮...",
                    "score": 0.92,
                    "metadata": {"kb_id": "kb_001", "doc_id": "doc_001"},
                },
                {
                    "id": "chunk_002",
                    "content": "退款时效一般3-5个工作日到账...",
                    "score": 0.85,
                    "metadata": {"kb_id": "kb_001", "doc_id": "doc_002"},
                },
            ],
            applied_filter={"scene": "after_sale", "product_type": "vip"},
            missing_required_slots=[],
            filter_debug={"source": "slot_state"},
            filter_sources={"scene": "slot", "product_type": "context"},
            duration_ms=120,
        )
        print("\n[KB Search Result Structure Test]")
        print(f"Success: {result.success}")
        print(f"Hits Count: {len(result.hits)}")
        print(f"Applied Filter: {json.dumps(result.applied_filter, ensure_ascii=False)}")
        print(f"Filter Sources: {json.dumps(result.filter_sources, ensure_ascii=False)}")
        print(f"Duration: {result.duration_ms}ms")
        print("\n[Hits]")
        for i, hit in enumerate(result.hits, 1):
            print(f"  [{i}] ID: {hit['id']}")
            print(f"      Score: {hit['score']}")
            print(f"      Content: {hit['content'][:50]}...")
        assert result.success is True
        assert len(result.hits) == 2
        assert result.applied_filter is not None

    def test_kb_search_missing_slots(self):
        """Test a KB search with missing required slots."""
        from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicResult

        result = KbSearchDynamicResult(
            success=False,
            hits=[],
            applied_filter={},
            missing_required_slots=[
                {
                    "field_key": "order_id",
                    "label": "订单号",
                    "reason": "必填字段缺失",
                    "ask_back_prompt": "请提供您的订单号",
                },
            ],
            filter_debug={"source": "builder"},
            fallback_reason_code="MISSING_REQUIRED_SLOTS",
            duration_ms=50,
        )
        print("\n[KB Search Missing Slots Test]")
        print(f"Success: {result.success}")
        print(f"Fallback Reason: {result.fallback_reason_code}")
        print(f"Missing Slots: {json.dumps(result.missing_required_slots, ensure_ascii=False, indent=2)}")
        assert result.success is False
        assert result.fallback_reason_code == "MISSING_REQUIRED_SLOTS"
        assert len(result.missing_required_slots) == 1


class TestTimingBreakdown:
    """Latency breakdown tests."""

    def test_timing_breakdown_structure(self):
        """Validate the latency breakdown structure."""
        timings = [
            TimingRecord("intent_matching", 0, 0.05, 50),
            TimingRecord("high_risk_check", 0.05, 0.08, 30),
            TimingRecord("kb_search", 0.08, 0.2, 120),
            TimingRecord("llm_generation", 0.2, 1.5, 1300),
            TimingRecord("output_guardrail", 1.5, 1.55, 50),
            TimingRecord("response_formatting", 1.55, 1.6, 50),
        ]
        total = sum(t.duration_ms for t in timings)
        print("\n[Timing Breakdown Test]")
        print(f"{'stage':<25} {'duration(ms)':<12} {'share':<10}")
        print("-" * 47)
        for t in timings:
            percentage = (t.duration_ms / total * 100) if total > 0 else 0
            print(f"{t.stage:<25} {t.duration_ms:<12} {percentage:.1f}%")
        print("-" * 47)
        print(f"{'total':<25} {total:<12} {'100.0%':<10}")
        # 50 + 30 + 120 + 1300 + 50 + 50 = 1600
        assert total == 1600


def run_manual_test():
    """Run the tester manually from the command line."""
    import argparse

    parser = argparse.ArgumentParser(description="Mid-platform dialogue integration test")
    parser.add_argument("--url", default="http://localhost:8000", help="Service URL")
    parser.add_argument("--tenant", default="test_tenant", help="Tenant ID")
    parser.add_argument("--api-key", default=None, help="API key")
    parser.add_argument("--message", default="你好", help="Test message")
    parser.add_argument("--scene", default=None, help="Scene identifier")
    parser.add_argument("--interactive", action="store_true", help="Interactive mode")
    args = parser.parse_args()

    tester = DialogueIntegrationTester(
        base_url=args.url,
        tenant_id=args.tenant,
        api_key=args.api_key,
    )
    if args.interactive:
        print("\n=== Mid-Platform Dialogue Integration Test - Interactive Mode ===")
        print("Type 'quit' to exit\n")
        while True:
            try:
                message = input("Message: ").strip()
                if message.lower() == "quit":
                    break
                scene = input("Scene (optional, press Enter to skip): ").strip() or None
                result = asyncio.run(tester.send_dialogue(
                    user_message=message,
                    scene=scene,
                ))
                tester.print_result(result)
            except KeyboardInterrupt:
                break
    else:
        result = asyncio.run(tester.send_dialogue(
            user_message=args.message,
            scene=args.scene,
        ))
        tester.print_result(result)


if __name__ == "__main__":
    run_manual_test()