1076 lines
36 KiB
Python
1076 lines
36 KiB
Python
"""
|
|
Dialogue Controller for Mid Platform.
|
|
[AC-MARH-01, AC-MARH-02, AC-MARH-03, AC-MARH-04, AC-MARH-05, AC-MARH-06,
|
|
AC-MARH-07, AC-MARH-08, AC-MARH-09, AC-MARH-10, AC-MARH-11, AC-MARH-12]
|
|
|
|
Core endpoint: POST /mid/dialogue/respond
|
|
"""
|
|
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from typing import Annotated, Any
|
|
|
|
from fastapi import APIRouter, Depends, Request
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.core.database import get_session
|
|
from app.core.tenant import get_tenant_id
|
|
from app.models.mid.schemas import (
|
|
DialogueRequest,
|
|
DialogueResponse,
|
|
ExecutionMode,
|
|
Segment,
|
|
TraceInfo,
|
|
)
|
|
from app.services.mid.agent_orchestrator import AgentOrchestrator
|
|
from app.services.mid.default_kb_tool_runner import DefaultKbToolRunner
|
|
from app.services.mid.feature_flags import FeatureFlagService
|
|
from app.services.mid.high_risk_handler import HighRiskHandler
|
|
from app.services.mid.interrupt_context_enricher import InterruptContextEnricher
|
|
from app.services.mid.kb_search_dynamic_tool import (
|
|
KbSearchDynamicConfig,
|
|
KbSearchDynamicTool,
|
|
)
|
|
from app.services.mid.high_risk_check_tool import (
|
|
HighRiskCheckConfig,
|
|
HighRiskCheckTool,
|
|
register_high_risk_check_tool,
|
|
)
|
|
from app.services.mid.intent_hint_tool import (
|
|
IntentHintConfig,
|
|
IntentHintTool,
|
|
register_intent_hint_tool,
|
|
)
|
|
from app.services.mid.memory_recall_tool import (
|
|
MemoryRecallConfig,
|
|
MemoryRecallTool,
|
|
register_memory_recall_tool,
|
|
)
|
|
from app.services.mid.metrics_collector import MetricsCollector
|
|
from app.services.mid.output_guardrail_executor import OutputGuardrailExecutor
|
|
from app.services.mid.policy_router import IntentMatch, PolicyRouter
|
|
from app.services.mid.runtime_observer import RuntimeObserver
|
|
from app.services.mid.segment_humanizer import HumanizeConfig, SegmentHumanizer
|
|
from app.services.mid.timeout_governor import TimeoutGovernor
|
|
from app.services.mid.tool_registry import ToolRegistry
|
|
from app.services.mid.trace_logger import TraceLogger
|
|
|
|
# Logger shared by every handler in this controller module.
logger = logging.getLogger(name=__name__)

# Every endpoint registered on this router is served under the "/mid" prefix.
router = APIRouter(tags=["Mid Platform Dialogue"], prefix="/mid")

# Lazily populated cache of process-wide service instances, keyed by a short
# service name; filled in by the get_* dependency providers below.
_mid_services: dict[str, Any] = dict()
|
|
|
|
|
|
def get_policy_router() -> PolicyRouter:
    """Return the cached PolicyRouter, constructing and caching it on first use."""
    cache_key = "policy_router"
    if cache_key in _mid_services:
        return _mid_services[cache_key]
    policy_router = PolicyRouter()
    _mid_services[cache_key] = policy_router
    return policy_router
|
|
|
|
|
|
def get_high_risk_handler() -> HighRiskHandler:
    """Return the cached HighRiskHandler, constructing and caching it on first use."""
    cache_key = "high_risk_handler"
    if cache_key in _mid_services:
        return _mid_services[cache_key]
    handler = HighRiskHandler()
    _mid_services[cache_key] = handler
    return handler
|
|
|
|
|
|
def get_timeout_governor() -> TimeoutGovernor:
    """Return the cached TimeoutGovernor, constructing and caching it on first use."""
    cache_key = "timeout_governor"
    if cache_key in _mid_services:
        return _mid_services[cache_key]
    governor = TimeoutGovernor()
    _mid_services[cache_key] = governor
    return governor
|
|
|
|
|
|
def get_feature_flag_service() -> FeatureFlagService:
    """Return the cached FeatureFlagService, constructing and caching it on first use."""
    cache_key = "feature_flag_service"
    if cache_key in _mid_services:
        return _mid_services[cache_key]
    flag_service = FeatureFlagService()
    _mid_services[cache_key] = flag_service
    return flag_service
|
|
|
|
|
|
def get_trace_logger() -> TraceLogger:
    """Return the cached TraceLogger, constructing and caching it on first use."""
    cache_key = "trace_logger"
    if cache_key in _mid_services:
        return _mid_services[cache_key]
    tracer = TraceLogger()
    _mid_services[cache_key] = tracer
    return tracer
|
|
|
|
|
|
def get_metrics_collector() -> MetricsCollector:
    """Return the cached MetricsCollector, constructing and caching it on first use."""
    cache_key = "metrics_collector"
    if cache_key in _mid_services:
        return _mid_services[cache_key]
    collector = MetricsCollector()
    _mid_services[cache_key] = collector
    return collector
|
|
|
|
|
|
def get_tool_registry() -> ToolRegistry:
    """Return the cached ToolRegistry, building it on first use.

    The registry is wired to the shared TimeoutGovernor returned by
    ``get_timeout_governor()``.
    """
    cache_key = "tool_registry"
    if cache_key in _mid_services:
        return _mid_services[cache_key]
    registry = ToolRegistry(timeout_governor=get_timeout_governor())
    _mid_services[cache_key] = registry
    return registry
|
|
|
|
|
|
# Per-process "already registered" guards consulted by the
# ensure_*_registered() helpers; each flips to True once its tool has been
# registered into the shared ToolRegistry.
_kb_search_dynamic_registered: bool = False  # kb_search_dynamic [AC-MARH-05]
_intent_hint_registered: bool = False  # intent_hint [AC-IDMP-02]
_high_risk_check_registered: bool = False  # high_risk_check [AC-IDMP-05]
_memory_recall_registered: bool = False  # memory_recall [AC-IDMP-13]
|
|
|
|
|
|
def ensure_kb_search_dynamic_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-MARH-05] Register the kb_search_dynamic tool into *registry* once per process.

    Later calls return immediately once ``_kb_search_dynamic_registered`` is
    set. The flag is only set after registration returns, so a failed
    registration is retried on the next call.

    NOTE(review): only the ``session`` from the first successful call is
    handed to ``register_kb_search_dynamic_tool``; later sessions are ignored.
    ``respond_dialogue`` passes its per-request session dependency here, so
    confirm the registered tool does not keep using that session after the
    first request finishes.
    """
    global _kb_search_dynamic_registered

    if not _kb_search_dynamic_registered:
        from app.services.mid.kb_search_dynamic_tool import register_kb_search_dynamic_tool

        search_config = KbSearchDynamicConfig(
            enabled=True,
            top_k=5,
            timeout_ms=2000,
            min_score_threshold=0.5,
        )
        register_kb_search_dynamic_tool(
            registry=registry,
            session=session,
            timeout_governor=get_timeout_governor(),
            config=search_config,
        )
        _kb_search_dynamic_registered = True
        logger.info("[AC-MARH-05] kb_search_dynamic tool registered to registry")
|
|
|
|
|
|
def ensure_intent_hint_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-IDMP-02, AC-IDMP-16] Register the intent_hint tool into *registry* once per process.

    The module flag ``_intent_hint_registered`` is set only after
    ``register_intent_hint_tool`` returns, so a failed attempt is retried.

    NOTE(review): the first caller's ``session`` is the one captured by the
    registration. ``respond_dialogue`` passes a per-request session here;
    verify the tool does not rely on it beyond that request.
    """
    global _intent_hint_registered

    if not _intent_hint_registered:
        hint_config = IntentHintConfig(
            enabled=True,
            timeout_ms=500,
            top_n=3,
            low_confidence_threshold=0.3,
        )
        register_intent_hint_tool(registry=registry, session=session, config=hint_config)
        _intent_hint_registered = True
        logger.info("[AC-IDMP-02] intent_hint tool registered to registry")
|
|
|
|
|
|
def ensure_high_risk_check_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-IDMP-05, AC-IDMP-20] Register the high_risk_check tool into *registry* once per process.

    The module flag ``_high_risk_check_registered`` is set only after
    ``register_high_risk_check_tool`` returns, so a failed attempt is retried.

    NOTE(review): the first caller's ``session`` is the one captured by the
    registration. ``respond_dialogue`` passes a per-request session here;
    verify the tool does not rely on it beyond that request.
    """
    global _high_risk_check_registered

    if not _high_risk_check_registered:
        risk_config = HighRiskCheckConfig(
            enabled=True,
            timeout_ms=500,
            default_confidence=0.9,
        )
        register_high_risk_check_tool(registry=registry, session=session, config=risk_config)
        _high_risk_check_registered = True
        logger.info("[AC-IDMP-05] high_risk_check tool registered to registry")
|
|
|
|
|
|
def ensure_memory_recall_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-IDMP-13] Register the memory_recall tool into *registry* once per process.

    The module flag ``_memory_recall_registered`` is set only after
    ``register_memory_recall_tool`` returns, so a failed attempt is retried.
    The tool is wired to the shared TimeoutGovernor.

    NOTE(review): the first caller's ``session`` is the one captured by the
    registration. ``respond_dialogue`` passes a per-request session here;
    verify the tool does not rely on it beyond that request.
    """
    global _memory_recall_registered

    if not _memory_recall_registered:
        recall_config = MemoryRecallConfig(
            enabled=True,
            timeout_ms=1000,
            max_recent_messages=8,
        )
        register_memory_recall_tool(
            registry=registry,
            session=session,
            timeout_governor=get_timeout_governor(),
            config=recall_config,
        )
        _memory_recall_registered = True
        logger.info("[AC-IDMP-13] memory_recall tool registered to registry")
|
|
|
|
|
|
def get_output_guardrail_executor() -> OutputGuardrailExecutor:
    """Return the cached OutputGuardrailExecutor, constructing and caching it on first use."""
    cache_key = "output_guardrail_executor"
    if cache_key in _mid_services:
        return _mid_services[cache_key]
    executor = OutputGuardrailExecutor()
    _mid_services[cache_key] = executor
    return executor
|
|
|
|
|
|
def get_interrupt_context_enricher() -> InterruptContextEnricher:
    """Return the cached InterruptContextEnricher, constructing and caching it on first use."""
    cache_key = "interrupt_context_enricher"
    if cache_key in _mid_services:
        return _mid_services[cache_key]
    enricher = InterruptContextEnricher()
    _mid_services[cache_key] = enricher
    return enricher
|
|
|
|
|
|
def get_default_kb_tool_runner() -> DefaultKbToolRunner:
    """Return the cached DefaultKbToolRunner, building it on first use.

    The runner is wired to the shared TimeoutGovernor returned by
    ``get_timeout_governor()``.
    """
    cache_key = "default_kb_tool_runner"
    if cache_key in _mid_services:
        return _mid_services[cache_key]
    kb_runner = DefaultKbToolRunner(timeout_governor=get_timeout_governor())
    _mid_services[cache_key] = kb_runner
    return kb_runner
|
|
|
|
|
|
def get_segment_humanizer() -> SegmentHumanizer:
    """Return the cached SegmentHumanizer, constructing and caching it on first use."""
    cache_key = "segment_humanizer"
    if cache_key in _mid_services:
        return _mid_services[cache_key]
    humanizer = SegmentHumanizer()
    _mid_services[cache_key] = humanizer
    return humanizer
|
|
|
|
|
|
def get_runtime_observer() -> RuntimeObserver:
    """Return the cached RuntimeObserver, constructing and caching it on first use."""
    cache_key = "runtime_observer"
    if cache_key in _mid_services:
        return _mid_services[cache_key]
    observer = RuntimeObserver()
    _mid_services[cache_key] = observer
    return observer
|
|
|
|
|
|
@router.post(
    "/dialogue/respond",
    operation_id="respondDialogue",
    summary="Generate mid platform dialogue response",
    description="""
    [AC-MARH-01~12] Core dialogue response endpoint for mid platform.

    Returns segments[] with trace info including:
    - guardrail_triggered, guardrail_rule_id
    - interrupt_consumed
    - kb_tool_called, kb_hit
    - timeout_profile, segment_stats
    """,
)
async def respond_dialogue(
    request: Request,
    dialogue_request: DialogueRequest,
    session: Annotated[AsyncSession, Depends(get_session)],
    policy_router: PolicyRouter = Depends(get_policy_router),
    high_risk_handler: HighRiskHandler = Depends(get_high_risk_handler),
    timeout_governor: TimeoutGovernor = Depends(get_timeout_governor),
    feature_flag_service: FeatureFlagService = Depends(get_feature_flag_service),
    trace_logger: TraceLogger = Depends(get_trace_logger),
    metrics_collector: MetricsCollector = Depends(get_metrics_collector),
    output_guardrail_executor: OutputGuardrailExecutor = Depends(get_output_guardrail_executor),
    interrupt_context_enricher: InterruptContextEnricher = Depends(get_interrupt_context_enricher),
    default_kb_tool_runner: DefaultKbToolRunner = Depends(get_default_kb_tool_runner),
    segment_humanizer: SegmentHumanizer = Depends(get_segment_humanizer),
    runtime_observer: RuntimeObserver = Depends(get_runtime_observer),
) -> DialogueResponse:
    """[AC-MARH-01~12] Generate a dialogue response made of segments plus a trace.

    Flow:
        1. Resolve the tenant (raises MissingTenantIdException when absent).
        2. Start runtime observation, the trace and the metrics session, and
           make sure the mid-platform tools are registered.
        3. Enrich context with interrupted segments (AC-MARH-03/04).
        4. Resolve feature flags; on ``rollback_to_legacy`` return the legacy
           placeholder reply (AC-MARH-17).
        5. Run the high_risk_check tool; a matched scenario short-circuits
           normal routing (AC-IDMP-05/20).
        6. Run the intent_hint tool and rule-based intent matching, then route
           to AGENT / MICRO_FLOW / FIXED / TRANSFER.
        7. Apply the output guardrail (AC-MARH-01/02) and the segment
           humanizer (AC-MARH-10/11).
        8. Finish observation, record metrics, end the trace (AC-MARH-12).

    The legacy and high-risk early returns also end the trace, so every
    request that starts one gets an end_trace call. Any exception inside the
    main ``try`` is logged with its traceback and turned into a generic
    FIXED-mode apology reply with ``fallback_reason_code="service_error"``.

    Raises:
        MissingTenantIdException: no tenant id is available for the request.
    """
    start_time = time.time()
    tenant_id = get_tenant_id()

    if not tenant_id:
        from app.core.exceptions import MissingTenantIdException

        raise MissingTenantIdException()

    request_id = str(uuid.uuid4())
    generation_id = str(uuid.uuid4())
    session_id = dialogue_request.session_id

    logger.info(
        f"[AC-MARH-01] Dialogue request: tenant={tenant_id}, "
        f"session={session_id}, request_id={request_id}"
    )

    runtime_observer.start_observation(
        tenant_id=tenant_id,
        session_id=session_id,
        request_id=request_id,
        generation_id=generation_id,
    )
    trace = trace_logger.start_trace(
        tenant_id=tenant_id,
        session_id=session_id,
        request_id=request_id,
        generation_id=generation_id,
    )
    metrics_collector.start_session(session_id)

    tool_registry = get_tool_registry()
    ensure_kb_search_dynamic_registered(tool_registry, session)
    ensure_intent_hint_registered(tool_registry, session)
    ensure_high_risk_check_registered(tool_registry, session)
    ensure_memory_recall_registered(tool_registry, session)

    try:
        interrupt_ctx = interrupt_context_enricher.enrich(
            dialogue_request.interrupted_segments,
            generation_id,
        )
        runtime_observer.record_interrupt(request_id, interrupt_ctx.consumed)

        feature_flags = dialogue_request.feature_flags or feature_flag_service.get_flags(
            session_id
        )

        if feature_flags.rollback_to_legacy:
            logger.info(f"[AC-MARH-17] Rollback to legacy for session: {session_id}")
            legacy_response = await _handle_legacy_response(
                tenant_id=tenant_id,
                request=dialogue_request,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
            )
            # Record in the audit trail the same mode/reason the reply carries.
            trace_logger.update_trace(
                request_id=request_id,
                mode=ExecutionMode.FIXED,
                fallback_reason_code="rollback_to_legacy",
            )
            _finish_trace(trace_logger, request_id, tenant_id, session_id, start_time)
            return legacy_response

        high_risk_check_tool = HighRiskCheckTool(
            session=session,
            config=HighRiskCheckConfig(enabled=True, timeout_ms=500),
        )
        high_risk_result = await high_risk_check_tool.execute(
            message=dialogue_request.user_message,
            tenant_id=tenant_id,
        )

        if high_risk_result.duration_ms > 0:
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[high_risk_check_tool.create_trace(high_risk_result, tenant_id)],
            )

        logger.info(
            f"[AC-IDMP-05, AC-IDMP-20] High risk check result: "
            f"matched={high_risk_result.matched}, scenario={high_risk_result.risk_scenario}, "
            f"duration_ms={high_risk_result.duration_ms}"
        )

        if high_risk_result.matched and high_risk_result.risk_scenario:
            logger.info(
                f"[AC-IDMP-05] High-risk matched from tool: {high_risk_result.risk_scenario.value}"
            )
            high_risk_response = await _handle_high_risk_check_response(
                tenant_id=tenant_id,
                request=dialogue_request,
                high_risk_result=high_risk_result,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
                session=session,
            )
            # The handler updates the trace but does not end it; end it here.
            _finish_trace(trace_logger, request_id, tenant_id, session_id, start_time)
            return high_risk_response

        intent_hint_tool = IntentHintTool(
            session=session,
            config=IntentHintConfig(enabled=True, timeout_ms=500),
        )
        intent_hint = await intent_hint_tool.execute(
            message=dialogue_request.user_message,
            tenant_id=tenant_id,
            history=[h.model_dump() for h in dialogue_request.history] if dialogue_request.history else None,
        )

        if intent_hint.duration_ms > 0:
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[intent_hint_tool.create_trace(intent_hint)],
            )

        logger.info(
            f"[AC-IDMP-02] Intent hint result: intent={intent_hint.intent}, "
            f"confidence={intent_hint.confidence}, suggested_mode={intent_hint.suggested_mode}"
        )

        intent_match = await _match_intent(tenant_id, dialogue_request, session)

        router_result = policy_router.route(
            user_message=dialogue_request.user_message,
            session_mode="BOT_ACTIVE",
            feature_flags=feature_flags,
            intent_match=intent_match,
            intent_hint=intent_hint,
        )

        runtime_observer.update_mode(request_id, router_result.mode, router_result.intent)
        runtime_observer.record_timeout_profile(request_id, timeout_governor.profile)

        trace_logger.update_trace(
            request_id=request_id,
            mode=router_result.mode,
            intent=router_result.intent,
            fallback_reason_code=router_result.fallback_reason_code,
        )

        if router_result.mode == ExecutionMode.AGENT:
            response = await _execute_agent_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                request_id=request_id,
                trace=trace,
                trace_logger=trace_logger,
                timeout_governor=timeout_governor,
                metrics_collector=metrics_collector,
                default_kb_tool_runner=default_kb_tool_runner,
                runtime_observer=runtime_observer,
                interrupt_ctx=interrupt_ctx,
                start_time=start_time,
                session=session,
                tool_registry=tool_registry,
            )
        elif router_result.mode == ExecutionMode.MICRO_FLOW:
            response = await _execute_micro_flow_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                router_result=router_result,
                trace=trace,
                trace_logger=trace_logger,
                session=session,
                start_time=start_time,
            )
        elif router_result.mode == ExecutionMode.FIXED:
            response = await _execute_fixed_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                router_result=router_result,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
            )
        else:
            response = await _execute_transfer_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                router_result=router_result,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
            )

        filtered_segments, guardrail_result = await output_guardrail_executor.filter_segments(
            response.segments, tenant_id
        )
        runtime_observer.record_guardrail(
            request_id, guardrail_result.triggered, guardrail_result.rule_id
        )

        final_segments, segment_stats = segment_humanizer.humanize(
            "\n\n".join(s.text for s in filtered_segments),
            _build_humanize_config(dialogue_request),
        )
        runtime_observer.record_segment_stats(request_id, segment_stats)

        final_trace = runtime_observer.end_observation(request_id)
        final_trace.segment_stats = segment_stats
        final_trace.guardrail_triggered = guardrail_result.triggered
        final_trace.guardrail_rule_id = guardrail_result.rule_id

        latency_ms = int((time.time() - start_time) * 1000)

        metrics_collector.record_turn(
            session_id=session_id,
            tenant_id=tenant_id,
            latency_ms=latency_ms,
            task_completed=True,
        )

        trace_logger.end_trace(
            request_id=request_id,
            tenant_id=tenant_id,
            session_id=session_id,
            latency_ms=latency_ms,
        )

        logger.info(
            f"[AC-MARH-12] Audit record: request_id={request_id}, "
            f"mode={final_trace.mode.value}, latency_ms={latency_ms}, "
            f"guardrail={guardrail_result.triggered}, kb_hit={final_trace.kb_hit}"
        )

        return DialogueResponse(
            segments=final_segments,
            trace=final_trace,
        )

    except Exception as e:
        latency_ms = int((time.time() - start_time) * 1000)
        # logger.exception keeps the traceback; the message text is unchanged.
        logger.exception(f"[AC-IDMP-06] Dialogue error: {e}")

        trace_logger.update_trace(
            request_id=request_id,
            mode=ExecutionMode.FIXED,
            fallback_reason_code=f"error: {str(e)[:50]}",
        )

        trace_logger.end_trace(
            request_id=request_id,
            tenant_id=tenant_id,
            session_id=session_id,
            latency_ms=latency_ms,
        )

        return DialogueResponse(
            segments=[Segment(
                text="抱歉,服务暂时不可用,请稍后重试或联系人工客服。",
                delay_after=0,
            )],
            trace=TraceInfo(
                mode=ExecutionMode.FIXED,
                request_id=request_id,
                generation_id=generation_id,
                fallback_reason_code="service_error",
            ),
        )


def _build_humanize_config(dialogue_request: DialogueRequest) -> HumanizeConfig | None:
    """Translate the request's optional humanize settings into a HumanizeConfig.

    Returns None when the request carries no ``humanize_config``, so the
    humanizer uses its own defaults. ``enabled``, ``min_delay_ms`` and
    ``max_delay_ms`` fall back to their defaults (True / 50 / 500) only when
    they are None. Explicit falsy values such as ``enabled=False`` or
    ``min_delay_ms=0`` are honored; the previous ``x or default`` form
    silently overrode them. An empty ``length_bucket_strategy`` still means
    "simple".
    """
    requested = dialogue_request.humanize_config
    if not requested:
        return None

    def _or_default(value: Any, default: Any) -> Any:
        return default if value is None else value

    return HumanizeConfig(
        enabled=_or_default(requested.enabled, True),
        min_delay_ms=_or_default(requested.min_delay_ms, 50),
        max_delay_ms=_or_default(requested.max_delay_ms, 500),
        length_bucket_strategy=requested.length_bucket_strategy or "simple",
    )


def _finish_trace(
    trace_logger: TraceLogger,
    request_id: str,
    tenant_id: str,
    session_id: str,
    start_time: float,
) -> int:
    """End the request's trace with the elapsed latency and return that latency in ms.

    Used by the early-return paths of ``respond_dialogue`` so their traces are
    closed the same way as the main and error paths.
    """
    latency_ms = int((time.time() - start_time) * 1000)
    trace_logger.end_trace(
        request_id=request_id,
        tenant_id=tenant_id,
        session_id=session_id,
        latency_ms=latency_ms,
    )
    return latency_ms
|
|
|
|
|
|
async def _match_intent(
    tenant_id: str,
    request: DialogueRequest,
    session: AsyncSession,
) -> IntentMatch | None:
    """Match the user message against the tenant's enabled intent rules.

    Returns an IntentMatch built from the rule in IntentRouter's match result
    (confidence is fixed at 0.8). Returns None when:
    - the tenant has no enabled rules,
    - nothing matches, or
    - any step raises; the error is logged as a warning and not propagated.
    """
    try:
        from app.services.intent.router import IntentRouter
        from app.services.intent.rule_service import IntentRuleService

        enabled_rules = await IntentRuleService(session).get_enabled_rules_for_matching(tenant_id)
        if not enabled_rules:
            return None

        match = IntentRouter().match(request.user_message, enabled_rules)
        if not match:
            return None

        rule = match.rule
        return IntentMatch(
            intent_id=str(rule.id),
            intent_name=rule.name,
            confidence=0.8,
            response_type=rule.response_type,
            target_kb_ids=rule.target_kb_ids,
            flow_id=str(rule.flow_id) if rule.flow_id else None,
            fixed_reply=rule.fixed_reply,
            transfer_message=rule.transfer_message,
        )
    except Exception as e:
        logger.warning(f"[AC-IDMP-02] Intent match failed: {e}")
        return None
|
|
|
|
|
|
async def _handle_legacy_response(
    tenant_id: str,
    request: DialogueRequest,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
) -> DialogueResponse:
    """Build the reply used when the session is rolled back to the legacy pipeline.

    NOTE(review): this does not invoke any legacy pipeline. It returns a
    single fixed placeholder segment with a FIXED-mode trace tagged
    ``fallback_reason_code="rollback_to_legacy"``. ``tenant_id``, ``request``,
    ``trace_logger`` and ``start_time`` are accepted for interface
    compatibility and are not used.
    """
    return DialogueResponse(
        segments=[Segment(
            text="正在使用传统模式处理您的请求...",
            delay_after=0,
        )],
        trace=TraceInfo(
            mode=ExecutionMode.FIXED,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            fallback_reason_code="rollback_to_legacy",
        ),
    )
|
|
|
|
|
|
async def _handle_high_risk_check_response(
    tenant_id: str,
    request: DialogueRequest,
    high_risk_result: Any,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
    session: AsyncSession,
) -> DialogueResponse:
    """[AC-IDMP-05, AC-IDMP-20] Build the reply for a scenario flagged by the high_risk_check tool.

    High-risk handling takes precedence over ordinary intent routing.

    Args:
        high_risk_result: a ``HighRiskCheckResult`` or a mapping of its fields
            (coerced via ``HighRiskCheckResult(**mapping)``).
        trace: the request's trace; its request/generation ids are copied.
        trace_logger: updated with the recommended mode and a
            ``high_risk_<scenario>`` reason code.
        session: used to look up the matched HighRiskPolicy when a rule_id
            is present.

    Returns:
        A TRANSFER reply (scenario-specific message) when the recommended mode
        is TRANSFER; otherwise a MICRO_FLOW "special handling" reply. Both
        traces carry ``rule_id`` as the fallback reason, or
        ``high_risk_<scenario>`` when no rule_id is set. Previously the
        TRANSFER trace could carry None.

    ``tenant_id``, ``request`` and ``start_time`` are unused and kept for
    interface compatibility.
    """
    from app.models.mid.schemas import HighRiskCheckResult

    if not isinstance(high_risk_result, HighRiskCheckResult):
        high_risk_result = HighRiskCheckResult(**high_risk_result)

    recommended_mode = high_risk_result.recommended_mode or ExecutionMode.MICRO_FLOW
    risk_scenario = high_risk_result.risk_scenario
    scenario_code = f"high_risk_{risk_scenario.value if risk_scenario else 'unknown'}"
    reason_code = high_risk_result.rule_id or scenario_code
    policy_set = [risk_scenario] if risk_scenario else None

    trace_logger.update_trace(
        request_id=trace.request_id or "",
        mode=recommended_mode,
        fallback_reason_code=scenario_code,
    )

    def _reply(text: str, mode: ExecutionMode) -> DialogueResponse:
        """Single-segment response carrying the shared high-risk trace fields."""
        return DialogueResponse(
            segments=[Segment(text=text, delay_after=0)],
            trace=TraceInfo(
                mode=mode,
                request_id=trace.request_id,
                generation_id=trace.generation_id,
                high_risk_policy_set=policy_set,
                fallback_reason_code=reason_code,
            ),
        )

    if recommended_mode == ExecutionMode.TRANSFER:
        transfer_messages = {
            "complaint_escalation": "检测到您可能需要投诉处理,正在为您转接人工客服...",
            "refund": "您的退款请求需要人工处理,正在为您转接...",
        }
        default_transfer = "正在为您转接人工客服..."
        transfer_msg = (
            transfer_messages.get(risk_scenario.value, default_transfer)
            if risk_scenario
            else default_transfer
        )
        return _reply(transfer_msg, ExecutionMode.TRANSFER)

    # NOTE(review): any non-TRANSFER recommendation (including AGENT/FIXED) is
    # answered as MICRO_FLOW, although trace_logger recorded recommended_mode.
    special_handling = "检测到您的请求需要特殊处理,正在为您安排..."

    if high_risk_result.rule_id:
        try:
            from sqlalchemy import select

            from app.models.entities import HighRiskPolicy

            stmt = select(HighRiskPolicy).where(
                HighRiskPolicy.id == uuid.UUID(high_risk_result.rule_id)
            )
            result = await session.execute(stmt)
            policy = result.scalar_one_or_none()

            if policy and policy.flow_id:
                # NOTE(review): the reply here is identical to the fallback
                # below; the policy's flow_id is not started yet.
                return _reply(special_handling, ExecutionMode.MICRO_FLOW)
        except Exception as e:
            logger.warning(f"[AC-IDMP-05] Failed to load high risk policy: {e}")

    return _reply(special_handling, ExecutionMode.MICRO_FLOW)
|
|
|
|
|
|
async def _handle_high_risk_response(
    tenant_id: str,
    request: DialogueRequest,
    high_risk_match: Any,
    high_risk_handler: HighRiskHandler,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
) -> DialogueResponse:
    """Build the reply for a high-risk match resolved through ``HighRiskHandler``.

    The handler's routing result decides the reply:
    - TRANSFER: the handler's transfer_message, or a default hand-off text.
    - anything else: a MICRO_FLOW "special handling" message.

    The handler's mode and fallback reason are written to ``trace_logger``.
    The returned trace lists ``high_risk_match.scenario`` as the triggered
    policy. ``tenant_id``, ``request`` and ``start_time`` are unused and kept
    for interface compatibility.
    """
    router_result = high_risk_handler.handle(high_risk_match)

    trace_logger.update_trace(
        request_id=trace.request_id or "",
        mode=router_result.mode,
        fallback_reason_code=router_result.fallback_reason_code,
    )

    if router_result.mode == ExecutionMode.TRANSFER:
        reply_text = router_result.transfer_message or "正在为您转接人工客服..."
        reply_mode = ExecutionMode.TRANSFER
    else:
        reply_text = "检测到您的请求需要特殊处理,正在为您安排..."
        reply_mode = ExecutionMode.MICRO_FLOW

    return DialogueResponse(
        segments=[Segment(
            text=reply_text,
            delay_after=0,
        )],
        trace=TraceInfo(
            mode=reply_mode,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            high_risk_policy_set=[high_risk_match.scenario],
            fallback_reason_code=router_result.fallback_reason_code,
        ),
    )
|
|
|
|
|
|
async def _execute_agent_mode(
    tenant_id: str,
    request: DialogueRequest,
    request_id: str,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    timeout_governor: TimeoutGovernor,
    metrics_collector: MetricsCollector,
    default_kb_tool_runner: DefaultKbToolRunner,
    runtime_observer: RuntimeObserver,
    interrupt_ctx: Any = None,
    start_time: float = 0,
    session: AsyncSession | None = None,
    tool_registry: ToolRegistry | None = None,
) -> DialogueResponse:
    """[AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-IDMP-13] Answer via the ReAct agent.

    Steps:
    1. Obtain an LLM client. On failure, log and pass ``None`` to the
       orchestrator.
    2. Build the agent context from history, consumed interrupt content
       (AC-MARH-03), and, when a session and user_id are present, recalled
       memory (AC-IDMP-13).
    3. Retrieve knowledge. With a session and tool registry this uses
       KbSearchDynamicTool; otherwise the default KB runner. The top three
       hits (first 200 chars each) are added as ``kb_context``.
    4. Run AgentOrchestrator (max 5 iterations) and split its answer into
       segments.

    Tool traces are appended to ``trace_logger``. KB outcome and ReAct stats
    are reported to ``runtime_observer``. ``metrics_collector`` and
    ``start_time`` are unused and kept for interface compatibility.
    """
    from app.services.llm.factory import get_llm_config_manager

    try:
        llm_client = get_llm_config_manager().get_client()
    except Exception as e:
        logger.warning(f"[AC-MARH-07] Failed to get LLM client: {e}")
        llm_client = None

    base_context: dict[str, Any] = (
        {"history": [h.model_dump() for h in request.history]} if request.history else {}
    )

    if interrupt_ctx and interrupt_ctx.consumed:
        base_context["interrupted_content"] = interrupt_ctx.interrupted_content
        base_context["interrupted_segment_ids"] = interrupt_ctx.interrupted_segment_ids
        logger.info(
            f"[AC-MARH-03] Agent context enriched with interrupt: "
            f"{len(interrupt_ctx.interrupted_content or '')} chars"
        )

    memory_context = ""
    memory_missing_slots: list[str] = []
    if session and request.user_id:
        memory_recall_tool = MemoryRecallTool(
            session=session,
            timeout_governor=timeout_governor,
            config=MemoryRecallConfig(
                enabled=True,
                timeout_ms=1000,
                max_recent_messages=8,
            ),
        )

        memory_result = await memory_recall_tool.execute(
            tenant_id=tenant_id,
            user_id=request.user_id,
            session_id=request.session_id,
        )

        if memory_result.duration_ms > 0:
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[memory_recall_tool.create_trace(memory_result, tenant_id)],
            )

        memory_context = memory_result.get_context_for_prompt()
        memory_missing_slots = memory_result.missing_slots

        if memory_context:
            base_context["memory_context"] = memory_context
            logger.info(
                f"[AC-IDMP-13] Memory recall succeeded: "
                f"profile={len(memory_result.profile)}, facts={len(memory_result.facts)}, "
                f"slots={len(memory_result.slots)}, missing_slots={len(memory_missing_slots)}, "
                f"duration_ms={memory_result.duration_ms}"
            )
        elif memory_result.fallback_reason_code:
            logger.warning(
                f"[AC-IDMP-13] Memory recall fallback: reason={memory_result.fallback_reason_code}"
            )

    kb_hits: list[Any] = []
    kb_success = False
    kb_fallback_reason = None
    kb_applied_filter: dict[str, Any] = {}
    kb_missing_slots: list[Any] = []

    if session and tool_registry:
        kb_tool = KbSearchDynamicTool(
            session=session,
            timeout_governor=timeout_governor,
            config=KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=2000,
                min_score_threshold=0.5,
            ),
        )

        kb_dynamic_result = await kb_tool.execute(
            query=request.user_message,
            tenant_id=tenant_id,
            scene="open_consult",
            top_k=5,
            context=base_context,
        )

        kb_success = kb_dynamic_result.success
        # Guard against a None hit list so len()/slicing below cannot crash.
        kb_hits = kb_dynamic_result.hits or []
        kb_fallback_reason = kb_dynamic_result.fallback_reason_code
        kb_applied_filter = kb_dynamic_result.applied_filter
        kb_missing_slots = kb_dynamic_result.missing_required_slots

        if kb_dynamic_result.tool_trace:
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[kb_dynamic_result.tool_trace],
            )

        logger.info(
            f"[AC-MARH-05] KB dynamic search: success={kb_success}, "
            f"hits={len(kb_hits)}, filter={kb_applied_filter}, "
            f"missing_slots={kb_missing_slots}"
        )
    else:
        kb_result = await default_kb_tool_runner.execute(
            tenant_id=tenant_id,
            query=request.user_message,
        )
        kb_success = kb_result.success
        kb_hits = kb_result.hits or []
        kb_fallback_reason = kb_result.fallback_reason_code

    kb_hit = bool(kb_success) and len(kb_hits) > 0

    runtime_observer.record_kb(
        request_id,
        tool_called=True,
        hit=kb_hit,
        fallback_reason=kb_fallback_reason,
    )

    if kb_hit:
        # A hit whose "content" is present but None used to raise TypeError on
        # slicing, which turned a successful retrieval into the generic error reply.
        kb_context = "\n".join(
            f"[知识库] {str(hit.get('content') or '')[:200]}"
            for hit in kb_hits[:3]
        )
        base_context["kb_context"] = kb_context
        logger.info(
            f"[AC-MARH-05] KB retrieval succeeded: hits={len(kb_hits)}"
        )
    elif kb_fallback_reason:
        logger.warning(
            f"[AC-MARH-06] KB retrieval fallback: reason={kb_fallback_reason}"
        )

    orchestrator = AgentOrchestrator(
        max_iterations=5,
        timeout_governor=timeout_governor,
        llm_client=llm_client,
        tool_registry=tool_registry,
    )

    final_answer, react_ctx, agent_trace = await orchestrator.execute(
        user_message=request.user_message,
        context=base_context,
    )

    runtime_observer.record_react(request_id, react_ctx.iteration, react_ctx.tool_calls)

    trace_logger.update_trace(
        request_id=request_id,
        react_iterations=react_ctx.iteration,
        tool_calls=react_ctx.tool_calls,
    )

    segments = _text_to_segments(final_answer)

    return DialogueResponse(
        segments=segments,
        trace=TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            react_iterations=react_ctx.iteration,
            tools_used=[tc.tool_name for tc in react_ctx.tool_calls] if react_ctx.tool_calls else None,
            tool_calls=react_ctx.tool_calls,
            timeout_profile=timeout_governor.profile,
            kb_tool_called=True,
            kb_hit=kb_hit,
            fallback_reason_code=kb_fallback_reason,
        ),
    )
|
|
|
|
|
|
async def _execute_micro_flow_mode(
    tenant_id: str,
    request: DialogueRequest,
    router_result: Any,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    session: AsyncSession,
    start_time: float,
) -> DialogueResponse:
    """Start the routed micro flow and reply with its first step.

    Falls back to a generic "please wait" segment, carrying the router's
    fallback_reason_code, when:
    - the router named no target flow,
    - the flow engine yields no first step, or
    - anything inside the start attempt raises (logged as a warning).
    """
    flow_id = router_result.target_flow_id
    if flow_id:
        try:
            from app.services.flow.engine import FlowEngine

            _, opening_text = await FlowEngine(session).start(
                tenant_id=tenant_id,
                session_id=request.session_id,
                flow_id=flow_id,
            )

            if opening_text:
                opening_segments = _text_to_segments(opening_text)
                return DialogueResponse(
                    segments=opening_segments,
                    trace=TraceInfo(
                        mode=ExecutionMode.MICRO_FLOW,
                        request_id=trace.request_id,
                        generation_id=trace.generation_id,
                        intent=router_result.intent,
                    ),
                )
        except Exception as e:
            logger.warning(f"[AC-IDMP-05] Micro flow start failed: {e}")

    waiting_segment = Segment(
        text="正在为您处理,请稍候...",
        delay_after=0,
    )
    return DialogueResponse(
        segments=[waiting_segment],
        trace=TraceInfo(
            mode=ExecutionMode.MICRO_FLOW,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            intent=router_result.intent,
            fallback_reason_code=router_result.fallback_reason_code,
        ),
    )
|
|
|
|
|
|
async def _execute_fixed_mode(
    tenant_id: str,
    request: DialogueRequest,
    router_result: Any,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
) -> DialogueResponse:
    """Reply with the router's fixed text, or a generic acknowledgement when none is set."""
    reply_text = router_result.fixed_reply
    if not reply_text:
        reply_text = "收到您的消息,我们会尽快处理。"

    reply_segments = _text_to_segments(reply_text)
    reply_trace = TraceInfo(
        mode=ExecutionMode.FIXED,
        request_id=trace.request_id,
        generation_id=trace.generation_id,
        intent=router_result.intent,
        fallback_reason_code=router_result.fallback_reason_code,
    )
    return DialogueResponse(segments=reply_segments, trace=reply_trace)
|
|
|
|
|
|
async def _execute_transfer_mode(
    tenant_id: str,
    request: DialogueRequest,
    router_result: Any,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
) -> DialogueResponse:
    """Hand off to a human agent, using the router's transfer message or a default one."""
    handoff_text = router_result.transfer_message
    if not handoff_text:
        handoff_text = "正在为您转接人工客服,请稍候..."

    handoff_segments = _text_to_segments(handoff_text)
    handoff_trace = TraceInfo(
        mode=ExecutionMode.TRANSFER,
        request_id=trace.request_id,
        generation_id=trace.generation_id,
        intent=router_result.intent,
        fallback_reason_code=router_result.fallback_reason_code,
    )
    return DialogueResponse(segments=handoff_segments, trace=handoff_trace)
|
|
|
|
|
|
def _text_to_segments(text: str | None, delay_after: int = 100) -> list[Segment]:
    """Split *text* into one Segment per blank-line-separated paragraph.

    Paragraphs are separated by ``"\\n\\n"``, stripped, and empty ones are
    dropped. Every segment except the last gets ``delay_after``; the last gets
    0. If no non-blank paragraph remains, a single segment carrying the raw
    text is returned, so callers always receive at least one segment.

    Args:
        text: Reply text to split. None is treated as an empty string;
            previously it raised AttributeError.
        delay_after: Value for ``Segment.delay_after`` on non-final segments.
            Defaults to 100, the previously hard-coded value.
    """
    text = text or ""
    paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()] or [text]
    last_index = len(paragraphs) - 1

    return [
        Segment(text=p, delay_after=delay_after if i < last_index else 0)
        for i, p in enumerate(paragraphs)
    ]
|