# Source: ai-robot-core/ai-service/app/api/mid/dialogue.py
# (1077 lines, 36 KiB, Python)

"""
Dialogue Controller for Mid Platform.
[AC-MARH-01, AC-MARH-02, AC-MARH-03, AC-MARH-04, AC-MARH-05, AC-MARH-06,
AC-MARH-07, AC-MARH-08, AC-MARH-09, AC-MARH-10, AC-MARH-11, AC-MARH-12]
Core endpoint: POST /mid/dialogue/respond
"""
import logging
import time
import uuid
from typing import Annotated, Any
from fastapi import APIRouter, Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.tenant import get_tenant_id
from app.models.mid.schemas import (
DialogueRequest,
DialogueResponse,
ExecutionMode,
Segment,
TraceInfo,
)
from app.services.mid.agent_orchestrator import AgentOrchestrator
from app.services.mid.default_kb_tool_runner import DefaultKbToolRunner
from app.services.mid.feature_flags import FeatureFlagService
from app.services.mid.high_risk_handler import HighRiskHandler
from app.services.mid.interrupt_context_enricher import InterruptContextEnricher
from app.services.mid.kb_search_dynamic_tool import (
KbSearchDynamicConfig,
KbSearchDynamicTool,
)
from app.services.mid.high_risk_check_tool import (
HighRiskCheckConfig,
HighRiskCheckTool,
register_high_risk_check_tool,
)
from app.services.mid.intent_hint_tool import (
IntentHintConfig,
IntentHintTool,
register_intent_hint_tool,
)
from app.services.mid.memory_recall_tool import (
MemoryRecallConfig,
MemoryRecallTool,
register_memory_recall_tool,
)
from app.services.mid.metrics_collector import MetricsCollector
from app.services.mid.output_guardrail_executor import OutputGuardrailExecutor
from app.services.mid.policy_router import IntentMatch, PolicyRouter
from app.services.mid.runtime_observer import RuntimeObserver
from app.services.mid.segment_humanizer import HumanizeConfig, SegmentHumanizer
from app.services.mid.timeout_governor import TimeoutGovernor
from app.services.mid.tool_registry import ToolRegistry
from app.services.mid.trace_logger import TraceLogger
# Module-level logger for this controller.
logger = logging.getLogger(__name__)
# All routes in this module are mounted under the /mid prefix.
router = APIRouter(prefix="/mid", tags=["Mid Platform Dialogue"])
# Process-wide cache of lazily-created singleton service instances,
# keyed by service name (see the get_* dependency helpers below).
_mid_services: dict[str, Any] = {}
def get_policy_router() -> PolicyRouter:
    """Return the process-wide PolicyRouter, creating it on first use."""
    instance = _mid_services.get("policy_router")
    if instance is None:
        instance = PolicyRouter()
        _mid_services["policy_router"] = instance
    return instance
def get_high_risk_handler() -> HighRiskHandler:
    """Return the process-wide HighRiskHandler, creating it on first use."""
    instance = _mid_services.get("high_risk_handler")
    if instance is None:
        instance = HighRiskHandler()
        _mid_services["high_risk_handler"] = instance
    return instance
def get_timeout_governor() -> TimeoutGovernor:
    """Return the process-wide TimeoutGovernor, creating it on first use."""
    instance = _mid_services.get("timeout_governor")
    if instance is None:
        instance = TimeoutGovernor()
        _mid_services["timeout_governor"] = instance
    return instance
def get_feature_flag_service() -> FeatureFlagService:
    """Return the process-wide FeatureFlagService, creating it on first use."""
    instance = _mid_services.get("feature_flag_service")
    if instance is None:
        instance = FeatureFlagService()
        _mid_services["feature_flag_service"] = instance
    return instance
def get_trace_logger() -> TraceLogger:
    """Return the process-wide TraceLogger, creating it on first use."""
    instance = _mid_services.get("trace_logger")
    if instance is None:
        instance = TraceLogger()
        _mid_services["trace_logger"] = instance
    return instance
def get_metrics_collector() -> MetricsCollector:
    """Return the process-wide MetricsCollector, creating it on first use."""
    instance = _mid_services.get("metrics_collector")
    if instance is None:
        instance = MetricsCollector()
        _mid_services["metrics_collector"] = instance
    return instance
def get_tool_registry() -> ToolRegistry:
    """Return the process-wide ToolRegistry, creating it on first use.

    The registry is wired to the shared TimeoutGovernor singleton.
    """
    registry = _mid_services.get("tool_registry")
    if registry is None:
        registry = ToolRegistry(timeout_governor=get_timeout_governor())
        _mid_services["tool_registry"] = registry
    return registry
# One-shot registration guards: each tool is registered into the shared
# ToolRegistry at most once per process by the ensure_*_registered helpers.
# NOTE(review): registration captures the AsyncSession of the first request
# that hits the endpoint — confirm the registered tools do not reuse that
# (likely request-scoped) session on later requests.
_kb_search_dynamic_registered: bool = False
_intent_hint_registered: bool = False
_high_risk_check_registered: bool = False
_memory_recall_registered: bool = False
def ensure_kb_search_dynamic_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-MARH-05] Ensure kb_search_dynamic tool is registered.

    Idempotent per process: a module-level flag guards re-registration.
    NOTE(review): the tool is registered with the first request's
    AsyncSession — confirm later requests do not depend on it staying open.
    """
    global _kb_search_dynamic_registered
    if _kb_search_dynamic_registered:
        return
    # Imported here rather than at module top (presumably to avoid an import
    # cycle — preserved from the original).
    from app.services.mid.kb_search_dynamic_tool import register_kb_search_dynamic_tool

    register_kb_search_dynamic_tool(
        registry=registry,
        session=session,
        timeout_governor=get_timeout_governor(),
        config=KbSearchDynamicConfig(
            enabled=True,
            top_k=5,
            timeout_ms=2000,
            min_score_threshold=0.5,
        ),
    )
    _kb_search_dynamic_registered = True
    logger.info("[AC-MARH-05] kb_search_dynamic tool registered to registry")
def ensure_intent_hint_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-IDMP-02, AC-IDMP-16] Ensure intent_hint tool is registered.

    Idempotent per process: a module-level flag guards re-registration.
    """
    global _intent_hint_registered
    if _intent_hint_registered:
        return
    register_intent_hint_tool(
        registry=registry,
        session=session,
        config=IntentHintConfig(
            enabled=True,
            timeout_ms=500,
            top_n=3,
            low_confidence_threshold=0.3,
        ),
    )
    _intent_hint_registered = True
    logger.info("[AC-IDMP-02] intent_hint tool registered to registry")
def ensure_high_risk_check_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-IDMP-05, AC-IDMP-20] Ensure high_risk_check tool is registered.

    Idempotent per process: a module-level flag guards re-registration.
    """
    global _high_risk_check_registered
    if _high_risk_check_registered:
        return
    register_high_risk_check_tool(
        registry=registry,
        session=session,
        config=HighRiskCheckConfig(
            enabled=True,
            timeout_ms=500,
            default_confidence=0.9,
        ),
    )
    _high_risk_check_registered = True
    logger.info("[AC-IDMP-05] high_risk_check tool registered to registry")
def ensure_memory_recall_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-IDMP-13] Ensure memory_recall tool is registered.

    Idempotent per process: a module-level flag guards re-registration.
    """
    global _memory_recall_registered
    if _memory_recall_registered:
        return
    register_memory_recall_tool(
        registry=registry,
        session=session,
        timeout_governor=get_timeout_governor(),
        config=MemoryRecallConfig(
            enabled=True,
            timeout_ms=1000,
            max_recent_messages=8,
        ),
    )
    _memory_recall_registered = True
    logger.info("[AC-IDMP-13] memory_recall tool registered to registry")
def get_output_guardrail_executor() -> OutputGuardrailExecutor:
    """Return the process-wide OutputGuardrailExecutor, creating it on first use."""
    instance = _mid_services.get("output_guardrail_executor")
    if instance is None:
        instance = OutputGuardrailExecutor()
        _mid_services["output_guardrail_executor"] = instance
    return instance
def get_interrupt_context_enricher() -> InterruptContextEnricher:
    """Return the process-wide InterruptContextEnricher, creating it on first use."""
    instance = _mid_services.get("interrupt_context_enricher")
    if instance is None:
        instance = InterruptContextEnricher()
        _mid_services["interrupt_context_enricher"] = instance
    return instance
def get_default_kb_tool_runner() -> DefaultKbToolRunner:
    """Return the process-wide DefaultKbToolRunner, creating it on first use.

    The runner is wired to the shared TimeoutGovernor singleton.
    """
    runner = _mid_services.get("default_kb_tool_runner")
    if runner is None:
        runner = DefaultKbToolRunner(timeout_governor=get_timeout_governor())
        _mid_services["default_kb_tool_runner"] = runner
    return runner
def get_segment_humanizer() -> SegmentHumanizer:
    """Return the process-wide SegmentHumanizer, creating it on first use."""
    instance = _mid_services.get("segment_humanizer")
    if instance is None:
        instance = SegmentHumanizer()
        _mid_services["segment_humanizer"] = instance
    return instance
def get_runtime_observer() -> RuntimeObserver:
    """Return the process-wide RuntimeObserver, creating it on first use."""
    instance = _mid_services.get("runtime_observer")
    if instance is None:
        instance = RuntimeObserver()
        _mid_services["runtime_observer"] = instance
    return instance
@router.post(
    "/dialogue/respond",
    operation_id="respondDialogue",
    summary="Generate mid platform dialogue response",
    description="""
[AC-MARH-01~12] Core dialogue response endpoint for mid platform.
Returns segments[] with trace info including:
- guardrail_triggered, guardrail_rule_id
- interrupt_consumed
- kb_tool_called, kb_hit
- timeout_profile, segment_stats
""",
)
async def respond_dialogue(
    request: Request,
    dialogue_request: DialogueRequest,
    session: Annotated[AsyncSession, Depends(get_session)],
    policy_router: PolicyRouter = Depends(get_policy_router),
    high_risk_handler: HighRiskHandler = Depends(get_high_risk_handler),
    timeout_governor: TimeoutGovernor = Depends(get_timeout_governor),
    feature_flag_service: FeatureFlagService = Depends(get_feature_flag_service),
    trace_logger: TraceLogger = Depends(get_trace_logger),
    metrics_collector: MetricsCollector = Depends(get_metrics_collector),
    output_guardrail_executor: OutputGuardrailExecutor = Depends(get_output_guardrail_executor),
    interrupt_context_enricher: InterruptContextEnricher = Depends(get_interrupt_context_enricher),
    default_kb_tool_runner: DefaultKbToolRunner = Depends(get_default_kb_tool_runner),
    segment_humanizer: SegmentHumanizer = Depends(get_segment_humanizer),
    runtime_observer: RuntimeObserver = Depends(get_runtime_observer),
) -> DialogueResponse:
    """
    [AC-MARH-01~12] Generate dialogue response with segments and trace.

    Flow:
    1. Validate request and get tenant context
    2. Start runtime observation
    3. Process interrupted segments (AC-MARH-03/04)
    4. Check feature flags for grayscale/rollback
    5. Detect high-risk scenarios
    6. Route to appropriate execution mode
    7. For Agent mode: call KB tool (AC-MARH-05/06)
    8. Execute output guardrail (AC-MARH-01/02)
    9. Apply segment humanizer (AC-MARH-10/11)
    10. Collect trace and return (AC-MARH-12)

    Any unexpected error degrades to a FIXED-mode apology response instead of
    propagating, so the caller always receives segments.

    Raises:
        MissingTenantIdException: when no tenant id is present in context
            (raised before the degradation try-block).
    """
    start_time = time.time()
    tenant_id = get_tenant_id()
    if not tenant_id:
        # Lazy import: the exception class is only needed on this error path.
        from app.core.exceptions import MissingTenantIdException
        raise MissingTenantIdException()
    request_id = str(uuid.uuid4())
    generation_id = str(uuid.uuid4())
    logger.info(
        f"[AC-MARH-01] Dialogue request: tenant={tenant_id}, "
        f"session={dialogue_request.session_id}, request_id={request_id}"
    )
    # Observation/trace/metrics start before the try-block so the error path
    # can still close the trace by request_id. (Fixed: the unused
    # `runtime_ctx` binding was dropped — only the side effect matters.)
    runtime_observer.start_observation(
        tenant_id=tenant_id,
        session_id=dialogue_request.session_id,
        request_id=request_id,
        generation_id=generation_id,
    )
    trace = trace_logger.start_trace(
        tenant_id=tenant_id,
        session_id=dialogue_request.session_id,
        request_id=request_id,
        generation_id=generation_id,
    )
    metrics_collector.start_session(dialogue_request.session_id)
    tool_registry = get_tool_registry()
    ensure_kb_search_dynamic_registered(tool_registry, session)
    ensure_intent_hint_registered(tool_registry, session)
    ensure_high_risk_check_registered(tool_registry, session)
    ensure_memory_recall_registered(tool_registry, session)
    try:
        # [AC-MARH-03/04] Fold any interrupted segments into the generation context.
        interrupt_ctx = interrupt_context_enricher.enrich(
            dialogue_request.interrupted_segments,
            generation_id,
        )
        runtime_observer.record_interrupt(request_id, interrupt_ctx.consumed)
        # Request-supplied flags take precedence over the per-session service.
        feature_flags = dialogue_request.feature_flags or feature_flag_service.get_flags(
            dialogue_request.session_id
        )
        if feature_flags.rollback_to_legacy:
            logger.info(f"[AC-MARH-17] Rollback to legacy for session: {dialogue_request.session_id}")
            return await _handle_legacy_response(
                tenant_id=tenant_id,
                request=dialogue_request,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
            )
        # [AC-IDMP-05, AC-IDMP-20] High-risk detection runs before intent routing.
        high_risk_check_tool = HighRiskCheckTool(
            session=session,
            config=HighRiskCheckConfig(enabled=True, timeout_ms=500),
        )
        high_risk_result = await high_risk_check_tool.execute(
            message=dialogue_request.user_message,
            tenant_id=tenant_id,
        )
        if high_risk_result.duration_ms > 0:
            # Only record a tool trace when the tool actually ran.
            hr_trace = high_risk_check_tool.create_trace(high_risk_result, tenant_id)
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[hr_trace],
            )
        logger.info(
            f"[AC-IDMP-05, AC-IDMP-20] High risk check result: "
            f"matched={high_risk_result.matched}, scenario={high_risk_result.risk_scenario}, "
            f"duration_ms={high_risk_result.duration_ms}"
        )
        if high_risk_result.matched and high_risk_result.risk_scenario:
            logger.info(
                f"[AC-IDMP-05] High-risk matched from tool: {high_risk_result.risk_scenario.value}"
            )
            return await _handle_high_risk_check_response(
                tenant_id=tenant_id,
                request=dialogue_request,
                high_risk_result=high_risk_result,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
                session=session,
            )
        # [AC-IDMP-02] Intent hint feeds mode selection in the policy router.
        intent_hint_tool = IntentHintTool(
            session=session,
            config=IntentHintConfig(enabled=True, timeout_ms=500),
        )
        intent_hint = await intent_hint_tool.execute(
            message=dialogue_request.user_message,
            tenant_id=tenant_id,
            history=[h.model_dump() for h in dialogue_request.history] if dialogue_request.history else None,
        )
        if intent_hint.duration_ms > 0:
            hint_trace = intent_hint_tool.create_trace(intent_hint)
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[hint_trace],
            )
        logger.info(
            f"[AC-IDMP-02] Intent hint result: intent={intent_hint.intent}, "
            f"confidence={intent_hint.confidence}, suggested_mode={intent_hint.suggested_mode}"
        )
        intent_match = await _match_intent(tenant_id, dialogue_request, session)
        router_result = policy_router.route(
            user_message=dialogue_request.user_message,
            session_mode="BOT_ACTIVE",
            feature_flags=feature_flags,
            intent_match=intent_match,
            intent_hint=intent_hint,
        )
        runtime_observer.update_mode(request_id, router_result.mode, router_result.intent)
        runtime_observer.record_timeout_profile(request_id, timeout_governor.profile)
        trace_logger.update_trace(
            request_id=request_id,
            mode=router_result.mode,
            intent=router_result.intent,
            fallback_reason_code=router_result.fallback_reason_code,
        )
        # Dispatch to the selected execution mode; TRANSFER is the default branch.
        if router_result.mode == ExecutionMode.AGENT:
            response = await _execute_agent_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                request_id=request_id,
                trace=trace,
                trace_logger=trace_logger,
                timeout_governor=timeout_governor,
                metrics_collector=metrics_collector,
                default_kb_tool_runner=default_kb_tool_runner,
                runtime_observer=runtime_observer,
                interrupt_ctx=interrupt_ctx,
                start_time=start_time,
                session=session,
                tool_registry=tool_registry,
            )
        elif router_result.mode == ExecutionMode.MICRO_FLOW:
            response = await _execute_micro_flow_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                router_result=router_result,
                trace=trace,
                trace_logger=trace_logger,
                session=session,
                start_time=start_time,
            )
        elif router_result.mode == ExecutionMode.FIXED:
            response = await _execute_fixed_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                router_result=router_result,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
            )
        else:
            response = await _execute_transfer_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                router_result=router_result,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
            )
        # [AC-MARH-01/02] Output guardrail filters the generated segments.
        filtered_segments, guardrail_result = await output_guardrail_executor.filter_segments(
            response.segments, tenant_id
        )
        runtime_observer.record_guardrail(
            request_id, guardrail_result.triggered, guardrail_result.rule_id
        )
        # [AC-MARH-10/11] Build the humanizer config from the request,
        # defaulting only fields that were left unset.
        # Fixed: the original used `value or default`, which treats every
        # falsy value as missing — an explicit enabled=False was silently
        # turned into True, and min_delay_ms=0 was replaced by 50. Explicit
        # `is None` checks honor those legitimate falsy values.
        humanize_config = None
        if dialogue_request.humanize_config:
            hc = dialogue_request.humanize_config
            humanize_config = HumanizeConfig(
                enabled=True if hc.enabled is None else hc.enabled,
                min_delay_ms=50 if hc.min_delay_ms is None else hc.min_delay_ms,
                max_delay_ms=500 if hc.max_delay_ms is None else hc.max_delay_ms,
                # Empty-string strategy still falls back to "simple".
                length_bucket_strategy=hc.length_bucket_strategy or "simple",
            )
        final_segments, segment_stats = segment_humanizer.humanize(
            "\n\n".join(s.text for s in filtered_segments),
            humanize_config,
        )
        runtime_observer.record_segment_stats(request_id, segment_stats)
        final_trace = runtime_observer.end_observation(request_id)
        final_trace.segment_stats = segment_stats
        final_trace.guardrail_triggered = guardrail_result.triggered
        final_trace.guardrail_rule_id = guardrail_result.rule_id
        latency_ms = int((time.time() - start_time) * 1000)
        metrics_collector.record_turn(
            session_id=dialogue_request.session_id,
            tenant_id=tenant_id,
            latency_ms=latency_ms,
            task_completed=True,
        )
        # [AC-MARH-12] Close the audit trace for this request. (Fixed: the
        # unused `audit` binding was dropped — only the side effect matters.)
        trace_logger.end_trace(
            request_id=request_id,
            tenant_id=tenant_id,
            session_id=dialogue_request.session_id,
            latency_ms=latency_ms,
        )
        logger.info(
            f"[AC-MARH-12] Audit record: request_id={request_id}, "
            f"mode={final_trace.mode.value}, latency_ms={latency_ms}, "
            f"guardrail={guardrail_result.triggered}, kb_hit={final_trace.kb_hit}"
        )
        return DialogueResponse(
            segments=final_segments,
            trace=final_trace,
        )
    except Exception as e:
        # [AC-IDMP-06] Degrade to a FIXED apology response; never propagate.
        latency_ms = int((time.time() - start_time) * 1000)
        # logger.exception preserves the traceback (logger.error dropped it).
        logger.exception(f"[AC-IDMP-06] Dialogue error: {e}")
        trace_logger.update_trace(
            request_id=request_id,
            mode=ExecutionMode.FIXED,
            fallback_reason_code=f"error: {str(e)[:50]}",
        )
        trace_logger.end_trace(
            request_id=request_id,
            tenant_id=tenant_id,
            session_id=dialogue_request.session_id,
            latency_ms=latency_ms,
        )
        return DialogueResponse(
            segments=[Segment(
                text="抱歉,服务暂时不可用,请稍后重试或联系人工客服。",
                delay_after=0,
            )],
            trace=TraceInfo(
                mode=ExecutionMode.FIXED,
                request_id=request_id,
                generation_id=generation_id,
                fallback_reason_code="service_error",
            ),
        )
async def _match_intent(
    tenant_id: str,
    request: DialogueRequest,
    session: AsyncSession,
) -> IntentMatch | None:
    """Match the user's message against the tenant's enabled intent rules.

    Returns None when no rules exist, nothing matches, or matching fails
    (failures are logged and swallowed — intent matching is best-effort).
    """
    try:
        from app.services.intent.router import IntentRouter
        from app.services.intent.rule_service import IntentRuleService

        rules = await IntentRuleService(session).get_enabled_rules_for_matching(tenant_id)
        if not rules:
            return None
        matched = IntentRouter().match(request.user_message, rules)
        if not matched:
            return None
        rule = matched.rule
        return IntentMatch(
            intent_id=str(rule.id),
            intent_name=rule.name,
            # Fixed confidence for rule-based matches (preserved from original).
            confidence=0.8,
            response_type=rule.response_type,
            target_kb_ids=rule.target_kb_ids,
            flow_id=str(rule.flow_id) if rule.flow_id else None,
            fixed_reply=rule.fixed_reply,
            transfer_message=rule.transfer_message,
        )
    except Exception as e:
        logger.warning(f"[AC-IDMP-02] Intent match failed: {e}")
        return None
async def _handle_legacy_response(
    tenant_id: str,
    request: DialogueRequest,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
) -> DialogueResponse:
    """Handle rollback to legacy pipeline.

    Returns a FIXED-mode placeholder reply tagged with the
    ``rollback_to_legacy`` fallback reason so the audit trail records the
    rollback. (Fixed: removed a dead ``latency_ms`` computation that was
    never used.)
    """
    return DialogueResponse(
        segments=[Segment(
            text="正在使用传统模式处理您的请求...",
            delay_after=0,
        )],
        trace=TraceInfo(
            mode=ExecutionMode.FIXED,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            fallback_reason_code="rollback_to_legacy",
        ),
    )
async def _handle_high_risk_check_response(
    tenant_id: str,
    request: DialogueRequest,
    high_risk_result: Any,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
    session: AsyncSession,
) -> DialogueResponse:
    """
    [AC-IDMP-05, AC-IDMP-20] Handle high-risk scenario from high_risk_check tool.

    High risk takes priority over normal intent routing. Depending on the
    tool's recommended mode, this returns either a TRANSFER response (with a
    scenario-specific hand-off message) or a MICRO_FLOW placeholder response.
    When the result carries a rule_id, the matching HighRiskPolicy row is
    loaded to check whether a flow is configured for it.
    """
    from app.models.mid.schemas import HighRiskCheckResult
    # Coerce dict-shaped tool output into the model so attribute access works.
    if not isinstance(high_risk_result, HighRiskCheckResult):
        high_risk_result = HighRiskCheckResult(**high_risk_result)
    # NOTE(review): latency_ms is computed but never used in this function.
    latency_ms = int((time.time() - start_time) * 1000)
    # Default to MICRO_FLOW when the tool gave no mode recommendation.
    recommended_mode = high_risk_result.recommended_mode or ExecutionMode.MICRO_FLOW
    risk_scenario = high_risk_result.risk_scenario
    trace_logger.update_trace(
        request_id=trace.request_id or "",
        mode=recommended_mode,
        fallback_reason_code=f"high_risk_{risk_scenario.value if risk_scenario else 'unknown'}",
    )
    if recommended_mode == ExecutionMode.TRANSFER:
        # Pick a scenario-specific hand-off message where one is defined;
        # otherwise use the generic transfer message.
        transfer_msg = "正在为您转接人工客服..."
        if risk_scenario:
            if risk_scenario.value == "complaint_escalation":
                transfer_msg = "检测到您可能需要投诉处理,正在为您转接人工客服..."
            elif risk_scenario.value == "refund":
                transfer_msg = "您的退款请求需要人工处理,正在为您转接..."
        return DialogueResponse(
            segments=[Segment(
                text=transfer_msg,
                delay_after=0,
            )],
            trace=TraceInfo(
                mode=ExecutionMode.TRANSFER,
                request_id=trace.request_id,
                generation_id=trace.generation_id,
                high_risk_policy_set=[risk_scenario] if risk_scenario else None,
                fallback_reason_code=high_risk_result.rule_id,
            ),
        )
    if high_risk_result.rule_id:
        # Best-effort policy lookup; any failure falls through to the
        # generic MICRO_FLOW response below.
        try:
            from sqlalchemy import select
            from app.models.entities import HighRiskPolicy
            import uuid
            stmt = select(HighRiskPolicy).where(
                HighRiskPolicy.id == uuid.UUID(high_risk_result.rule_id)
            )
            result = await session.execute(stmt)
            policy = result.scalar_one_or_none()
            if policy and policy.flow_id:
                # A flow is configured for this policy.
                # NOTE(review): the flow is not started here — confirm the
                # flow engine is expected to be kicked off elsewhere.
                return DialogueResponse(
                    segments=[Segment(
                        text="检测到您的请求需要特殊处理,正在为您安排...",
                        delay_after=0,
                    )],
                    trace=TraceInfo(
                        mode=ExecutionMode.MICRO_FLOW,
                        request_id=trace.request_id,
                        generation_id=trace.generation_id,
                        high_risk_policy_set=[risk_scenario] if risk_scenario else None,
                        fallback_reason_code=high_risk_result.rule_id,
                    ),
                )
        except Exception as e:
            logger.warning(f"[AC-IDMP-05] Failed to load high risk policy: {e}")
    # Default: generic MICRO_FLOW placeholder response.
    return DialogueResponse(
        segments=[Segment(
            text="检测到您的请求需要特殊处理,正在为您安排...",
            delay_after=0,
        )],
        trace=TraceInfo(
            mode=ExecutionMode.MICRO_FLOW,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            high_risk_policy_set=[risk_scenario] if risk_scenario else None,
            fallback_reason_code=high_risk_result.rule_id or f"high_risk_{risk_scenario.value if risk_scenario else 'unknown'}",
        ),
    )
async def _handle_high_risk_response(
    tenant_id: str,
    request: DialogueRequest,
    high_risk_match: Any,
    high_risk_handler: HighRiskHandler,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
) -> DialogueResponse:
    """Handle high-risk scenario response.

    Delegates routing to the HighRiskHandler, records the chosen mode in the
    trace, then returns either a TRANSFER hand-off message or a generic
    MICRO_FLOW placeholder. (Fixed: removed a dead ``latency_ms`` computation
    that was never used.)
    """
    router_result = high_risk_handler.handle(high_risk_match)
    trace_logger.update_trace(
        request_id=trace.request_id or "",
        mode=router_result.mode,
        fallback_reason_code=router_result.fallback_reason_code,
    )
    if router_result.mode == ExecutionMode.TRANSFER:
        return DialogueResponse(
            segments=[Segment(
                text=router_result.transfer_message or "正在为您转接人工客服...",
                delay_after=0,
            )],
            trace=TraceInfo(
                mode=ExecutionMode.TRANSFER,
                request_id=trace.request_id,
                generation_id=trace.generation_id,
                high_risk_policy_set=[high_risk_match.scenario],
                fallback_reason_code=router_result.fallback_reason_code,
            ),
        )
    return DialogueResponse(
        segments=[Segment(
            text="检测到您的请求需要特殊处理,正在为您安排...",
            delay_after=0,
        )],
        trace=TraceInfo(
            mode=ExecutionMode.MICRO_FLOW,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            high_risk_policy_set=[high_risk_match.scenario],
            fallback_reason_code=router_result.fallback_reason_code,
        ),
    )
async def _execute_agent_mode(
    tenant_id: str,
    request: DialogueRequest,
    request_id: str,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    timeout_governor: TimeoutGovernor,
    metrics_collector: MetricsCollector,
    default_kb_tool_runner: DefaultKbToolRunner,
    runtime_observer: RuntimeObserver,
    interrupt_ctx: Any = None,
    start_time: float = 0,
    session: AsyncSession | None = None,
    tool_registry: ToolRegistry | None = None,
) -> DialogueResponse:
    """[AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-IDMP-13] Execute agent mode with ReAct loop, KB tool, and memory recall.

    Steps:
    1. Acquire an LLM client (failures tolerated: the orchestrator receives None).
    2. Build the base context from history and any interrupted content.
    3. Recall user memory when both a session and user_id are available.
    4. Retrieve KB context — the dynamic KB tool when a session and registry
       exist, otherwise the default KB runner.
    5. Run the ReAct orchestrator and convert its final answer into segments.
    """
    from app.services.llm.factory import get_llm_config_manager
    try:
        llm_manager = get_llm_config_manager()
        llm_client = llm_manager.get_client()
    except Exception as e:
        # [AC-MARH-07] Proceed without an LLM client; None is passed through
        # to the orchestrator.
        logger.warning(f"[AC-MARH-07] Failed to get LLM client: {e}")
        llm_client = None
    base_context = {"history": [h.model_dump() for h in request.history]} if request.history else {}
    if interrupt_ctx and interrupt_ctx.consumed:
        # [AC-MARH-03] Carry the interrupted reply content into this generation.
        base_context["interrupted_content"] = interrupt_ctx.interrupted_content
        base_context["interrupted_segment_ids"] = interrupt_ctx.interrupted_segment_ids
        logger.info(
            f"[AC-MARH-03] Agent context enriched with interrupt: "
            f"{len(interrupt_ctx.interrupted_content or '')} chars"
        )
    memory_context = ""
    memory_missing_slots: list[str] = []
    if session and request.user_id:
        # [AC-IDMP-13] Memory recall requires a DB session and a user identity.
        memory_recall_tool = MemoryRecallTool(
            session=session,
            timeout_governor=timeout_governor,
            config=MemoryRecallConfig(
                enabled=True,
                timeout_ms=1000,
                max_recent_messages=8,
            ),
        )
        memory_result = await memory_recall_tool.execute(
            tenant_id=tenant_id,
            user_id=request.user_id,
            session_id=request.session_id,
        )
        if memory_result.duration_ms > 0:
            # Record the tool call in the trace only when it actually ran.
            memory_trace = memory_recall_tool.create_trace(memory_result, tenant_id)
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[memory_trace],
            )
        memory_context = memory_result.get_context_for_prompt()
        memory_missing_slots = memory_result.missing_slots
        if memory_context:
            base_context["memory_context"] = memory_context
            logger.info(
                f"[AC-IDMP-13] Memory recall succeeded: "
                f"profile={len(memory_result.profile)}, facts={len(memory_result.facts)}, "
                f"slots={len(memory_result.slots)}, missing_slots={len(memory_missing_slots)}, "
                f"duration_ms={memory_result.duration_ms}"
            )
        elif memory_result.fallback_reason_code:
            logger.warning(
                f"[AC-IDMP-13] Memory recall fallback: reason={memory_result.fallback_reason_code}"
            )
    kb_hits = []
    kb_success = False
    kb_fallback_reason = None
    kb_applied_filter = {}
    kb_missing_slots = []
    if session and tool_registry:
        # [AC-MARH-05] Preferred path: dynamic KB search.
        kb_tool = KbSearchDynamicTool(
            session=session,
            timeout_governor=timeout_governor,
            config=KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=2000,
                min_score_threshold=0.5,
            ),
        )
        kb_dynamic_result = await kb_tool.execute(
            query=request.user_message,
            tenant_id=tenant_id,
            scene="open_consult",
            top_k=5,
            context=base_context,
        )
        kb_success = kb_dynamic_result.success
        kb_hits = kb_dynamic_result.hits
        kb_fallback_reason = kb_dynamic_result.fallback_reason_code
        kb_applied_filter = kb_dynamic_result.applied_filter
        kb_missing_slots = kb_dynamic_result.missing_required_slots
        if kb_dynamic_result.tool_trace:
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[kb_dynamic_result.tool_trace],
            )
        logger.info(
            f"[AC-MARH-05] KB dynamic search: success={kb_success}, "
            f"hits={len(kb_hits)}, filter={kb_applied_filter}, "
            f"missing_slots={kb_missing_slots}"
        )
    else:
        # Fallback path: default KB retrieval without dynamic filtering.
        kb_result = await default_kb_tool_runner.execute(
            tenant_id=tenant_id,
            query=request.user_message,
        )
        kb_success = kb_result.success
        kb_hits = kb_result.hits
        kb_fallback_reason = kb_result.fallback_reason_code
    runtime_observer.record_kb(
        request_id,
        tool_called=True,
        hit=kb_success and len(kb_hits) > 0,
        fallback_reason=kb_fallback_reason,
    )
    if kb_success and kb_hits:
        # Inject the top 3 hits (each truncated to 200 chars) into the context.
        kb_context = "\n".join([
            f"[知识库] {hit.get('content', '')[:200]}"
            for hit in kb_hits[:3]
        ])
        base_context["kb_context"] = kb_context
        logger.info(
            f"[AC-MARH-05] KB retrieval succeeded: hits={len(kb_hits)}"
        )
    elif kb_fallback_reason:
        logger.warning(
            f"[AC-MARH-06] KB retrieval fallback: reason={kb_fallback_reason}"
        )
    # [AC-MARH-07] ReAct loop, capped at 5 iterations.
    orchestrator = AgentOrchestrator(
        max_iterations=5,
        timeout_governor=timeout_governor,
        llm_client=llm_client,
        tool_registry=tool_registry,
        tenant_id=tenant_id,
    )
    final_answer, react_ctx, agent_trace = await orchestrator.execute(
        user_message=request.user_message,
        context=base_context,
    )
    runtime_observer.record_react(request_id, react_ctx.iteration, react_ctx.tool_calls)
    trace_logger.update_trace(
        request_id=request_id,
        react_iterations=react_ctx.iteration,
        tool_calls=react_ctx.tool_calls,
    )
    segments = _text_to_segments(final_answer)
    return DialogueResponse(
        segments=segments,
        trace=TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            react_iterations=react_ctx.iteration,
            tools_used=[tc.tool_name for tc in react_ctx.tool_calls] if react_ctx.tool_calls else None,
            tool_calls=react_ctx.tool_calls,
            timeout_profile=timeout_governor.profile,
            kb_tool_called=True,
            kb_hit=kb_success and len(kb_hits) > 0,
            fallback_reason_code=kb_fallback_reason,
        ),
    )
async def _execute_micro_flow_mode(
    tenant_id: str,
    request: DialogueRequest,
    router_result: Any,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    session: AsyncSession,
    start_time: float,
) -> DialogueResponse:
    """Execute micro flow mode.

    Starts the routed flow when one is configured and returns its first step;
    on any failure (or when no flow/first step exists) falls back to a generic
    MICRO_FLOW placeholder response.
    """
    flow_id = router_result.target_flow_id
    if flow_id:
        try:
            from app.services.flow.engine import FlowEngine

            engine = FlowEngine(session)
            instance, first_step = await engine.start(
                tenant_id=tenant_id,
                session_id=request.session_id,
                flow_id=flow_id,
            )
            if first_step:
                return DialogueResponse(
                    segments=_text_to_segments(first_step),
                    trace=TraceInfo(
                        mode=ExecutionMode.MICRO_FLOW,
                        request_id=trace.request_id,
                        generation_id=trace.generation_id,
                        intent=router_result.intent,
                    ),
                )
        except Exception as e:
            logger.warning(f"[AC-IDMP-05] Micro flow start failed: {e}")
    placeholder = Segment(
        text="正在为您处理,请稍候...",
        delay_after=0,
    )
    return DialogueResponse(
        segments=[placeholder],
        trace=TraceInfo(
            mode=ExecutionMode.MICRO_FLOW,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            intent=router_result.intent,
            fallback_reason_code=router_result.fallback_reason_code,
        ),
    )
async def _execute_fixed_mode(
    tenant_id: str,
    request: DialogueRequest,
    router_result: Any,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
) -> DialogueResponse:
    """Build a FIXED-mode response from the router's canned reply.

    Falls back to a generic acknowledgement when the router supplies none.
    """
    reply = router_result.fixed_reply
    if not reply:
        reply = "收到您的消息,我们会尽快处理。"
    trace_info = TraceInfo(
        mode=ExecutionMode.FIXED,
        request_id=trace.request_id,
        generation_id=trace.generation_id,
        intent=router_result.intent,
        fallback_reason_code=router_result.fallback_reason_code,
    )
    return DialogueResponse(segments=_text_to_segments(reply), trace=trace_info)
async def _execute_transfer_mode(
    tenant_id: str,
    request: DialogueRequest,
    router_result: Any,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
) -> DialogueResponse:
    """Build a TRANSFER-mode (hand off to human) response.

    Uses the router's transfer message, falling back to a generic one.
    """
    message = router_result.transfer_message
    if not message:
        message = "正在为您转接人工客服,请稍候..."
    trace_info = TraceInfo(
        mode=ExecutionMode.TRANSFER,
        request_id=trace.request_id,
        generation_id=trace.generation_id,
        intent=router_result.intent,
        fallback_reason_code=router_result.fallback_reason_code,
    )
    return DialogueResponse(segments=_text_to_segments(message), trace=trace_info)
def _text_to_segments(text: str) -> list[Segment]:
    """Split text on blank lines into reply segments.

    Every segment but the last gets a 100ms delay_after; the last gets 0.
    Whitespace-only input yields a single segment with the original text.
    """
    parts = [chunk.strip() for chunk in text.split("\n\n") if chunk.strip()]
    if not parts:
        parts = [text]
    last_index = len(parts) - 1
    return [
        Segment(text=part, delay_after=0 if index == last_index else 100)
        for index, part in enumerate(parts)
    ]