""" Dialogue Controller for Mid Platform. [AC-MARH-01, AC-MARH-02, AC-MARH-03, AC-MARH-04, AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-MARH-08, AC-MARH-09, AC-MARH-10, AC-MARH-11, AC-MARH-12] Core endpoint: POST /mid/dialogue/respond """ import logging import time import uuid from typing import Annotated, Any from fastapi import APIRouter, Depends, Request from sqlalchemy.ext.asyncio import AsyncSession from app.core.database import get_session from app.core.tenant import get_tenant_id from app.models.mid.schemas import ( DialogueRequest, DialogueResponse, ExecutionMode, Segment, TraceInfo, ) from app.services.mid.agent_orchestrator import AgentOrchestrator from app.services.mid.default_kb_tool_runner import DefaultKbToolRunner from app.services.mid.feature_flags import FeatureFlagService from app.services.mid.high_risk_handler import HighRiskHandler from app.services.mid.interrupt_context_enricher import InterruptContextEnricher from app.services.mid.kb_search_dynamic_tool import ( KbSearchDynamicConfig, KbSearchDynamicTool, ) from app.services.mid.high_risk_check_tool import ( HighRiskCheckConfig, HighRiskCheckTool, register_high_risk_check_tool, ) from app.services.mid.intent_hint_tool import ( IntentHintConfig, IntentHintTool, register_intent_hint_tool, ) from app.services.mid.memory_recall_tool import ( MemoryRecallConfig, MemoryRecallTool, register_memory_recall_tool, ) from app.services.mid.metrics_collector import MetricsCollector from app.services.mid.output_guardrail_executor import OutputGuardrailExecutor from app.services.mid.policy_router import IntentMatch, PolicyRouter from app.services.mid.runtime_observer import RuntimeObserver from app.services.mid.segment_humanizer import HumanizeConfig, SegmentHumanizer from app.services.mid.timeout_governor import TimeoutGovernor from app.services.mid.tool_registry import ToolRegistry from app.services.mid.trace_logger import TraceLogger from app.services.prompt.template_service import PromptTemplateService from app.services.prompt.variable_resolver import VariableResolver from app.services.intent.clarification import ( ClarificationEngine, ClarifyReason, ClarifySessionManager, ClarifyState, HybridIntentResult, IntentCandidate, T_HIGH, T_LOW, ) logger = logging.getLogger(__name__) router = APIRouter(prefix="/mid", tags=["Mid Platform Dialogue"]) _mid_services: dict[str, Any] = {} def get_policy_router() -> PolicyRouter: """Get or create PolicyRouter instance.""" if "policy_router" not in _mid_services: _mid_services["policy_router"] = PolicyRouter() return _mid_services["policy_router"] def get_high_risk_handler() -> HighRiskHandler: """Get or create HighRiskHandler instance.""" if "high_risk_handler" not in _mid_services: _mid_services["high_risk_handler"] = HighRiskHandler() return _mid_services["high_risk_handler"] def get_timeout_governor() -> TimeoutGovernor: """Get or create TimeoutGovernor instance.""" if "timeout_governor" not in _mid_services: _mid_services["timeout_governor"] = TimeoutGovernor() return _mid_services["timeout_governor"] def get_feature_flag_service() -> FeatureFlagService: """Get or create FeatureFlagService instance.""" if "feature_flag_service" not in _mid_services: _mid_services["feature_flag_service"] = FeatureFlagService() return _mid_services["feature_flag_service"] def get_trace_logger() -> TraceLogger: """Get or create TraceLogger instance.""" if "trace_logger" not in _mid_services: _mid_services["trace_logger"] = TraceLogger() return _mid_services["trace_logger"] def get_metrics_collector() -> MetricsCollector: """Get or create MetricsCollector instance.""" if "metrics_collector" not in _mid_services: _mid_services["metrics_collector"] = MetricsCollector() return _mid_services["metrics_collector"] def get_tool_registry() -> ToolRegistry: """Get or create ToolRegistry instance.""" if "tool_registry" not in _mid_services: _mid_services["tool_registry"] = ToolRegistry( timeout_governor=get_timeout_governor() ) return _mid_services["tool_registry"] _kb_search_dynamic_registered: bool = False _intent_hint_registered: bool = False _high_risk_check_registered: bool = False _memory_recall_registered: bool = False _metadata_discovery_registered: bool = False def ensure_kb_search_dynamic_registered( registry: ToolRegistry, session: AsyncSession, ) -> None: """[AC-MARH-05] Ensure kb_search_dynamic tool is registered.""" global _kb_search_dynamic_registered if _kb_search_dynamic_registered: return from app.services.mid.kb_search_dynamic_tool import register_kb_search_dynamic_tool config = KbSearchDynamicConfig( enabled=True, top_k=5, timeout_ms=10000, min_score_threshold=0.5, ) register_kb_search_dynamic_tool( registry=registry, session=session, timeout_governor=get_timeout_governor(), config=config, ) _kb_search_dynamic_registered = True logger.info("[AC-MARH-05] kb_search_dynamic tool registered to registry") def ensure_intent_hint_registered( registry: ToolRegistry, session: AsyncSession, ) -> None: """[AC-IDMP-02, AC-IDMP-16] Ensure intent_hint tool is registered.""" global _intent_hint_registered if _intent_hint_registered: return config = IntentHintConfig( enabled=True, timeout_ms=500, top_n=3, low_confidence_threshold=0.3, ) register_intent_hint_tool( registry=registry, session=session, config=config, ) _intent_hint_registered = True logger.info("[AC-IDMP-02] intent_hint tool registered to registry") def ensure_high_risk_check_registered( registry: ToolRegistry, session: AsyncSession, ) -> None: """[AC-IDMP-05, AC-IDMP-20] Ensure high_risk_check tool is registered.""" global _high_risk_check_registered if _high_risk_check_registered: return config = HighRiskCheckConfig( enabled=True, timeout_ms=500, default_confidence=0.9, ) register_high_risk_check_tool( registry=registry, session=session, config=config, ) _high_risk_check_registered = True logger.info("[AC-IDMP-05] high_risk_check tool registered to registry") def ensure_memory_recall_registered( registry: ToolRegistry, session: AsyncSession, ) -> None: """[AC-IDMP-13] Ensure memory_recall tool is registered.""" global _memory_recall_registered if _memory_recall_registered: return config = MemoryRecallConfig( enabled=True, timeout_ms=1000, max_recent_messages=8, ) register_memory_recall_tool( registry=registry, session=session, timeout_governor=get_timeout_governor(), config=config, ) _memory_recall_registered = True logger.info("[AC-IDMP-13] memory_recall tool registered to registry") def ensure_metadata_discovery_registered( registry: ToolRegistry, session: AsyncSession, ) -> None: """Ensure metadata_discovery tool is registered.""" global _metadata_discovery_registered if _metadata_discovery_registered: return from app.services.mid.metadata_discovery_tool import register_metadata_discovery_tool register_metadata_discovery_tool( registry=registry, session=session, timeout_governor=get_timeout_governor(), ) _metadata_discovery_registered = True logger.info("[MetadataDiscovery] metadata_discovery tool registered to registry") def get_output_guardrail_executor() -> OutputGuardrailExecutor: """Get or create OutputGuardrailExecutor instance.""" if "output_guardrail_executor" not in _mid_services: _mid_services["output_guardrail_executor"] = OutputGuardrailExecutor() return _mid_services["output_guardrail_executor"] def get_interrupt_context_enricher() -> InterruptContextEnricher: """Get or create InterruptContextEnricher instance.""" if "interrupt_context_enricher" not in _mid_services: _mid_services["interrupt_context_enricher"] = InterruptContextEnricher() return _mid_services["interrupt_context_enricher"] def get_default_kb_tool_runner() -> DefaultKbToolRunner: """Get or create DefaultKbToolRunner instance.""" if "default_kb_tool_runner" not in _mid_services: _mid_services["default_kb_tool_runner"] = DefaultKbToolRunner( timeout_governor=get_timeout_governor() ) return _mid_services["default_kb_tool_runner"] def get_segment_humanizer() -> SegmentHumanizer: """Get or create SegmentHumanizer instance.""" if "segment_humanizer" not in _mid_services: _mid_services["segment_humanizer"] = SegmentHumanizer() return _mid_services["segment_humanizer"] def get_runtime_observer() -> RuntimeObserver: """Get or create RuntimeObserver instance.""" if "runtime_observer" not in _mid_services: _mid_services["runtime_observer"] = RuntimeObserver() return _mid_services["runtime_observer"] def get_clarification_engine() -> ClarificationEngine: """Get or create ClarificationEngine instance.""" if "clarification_engine" not in _mid_services: _mid_services["clarification_engine"] = ClarificationEngine( t_high=T_HIGH, t_low=T_LOW, ) return _mid_services["clarification_engine"] @router.post( "/dialogue/respond", operation_id="respondDialogue", summary="Generate mid platform dialogue response", description=""" [AC-MARH-01~12] Core dialogue response endpoint for mid platform. Returns segments[] with trace info including: - guardrail_triggered, guardrail_rule_id - interrupt_consumed - kb_tool_called, kb_hit - timeout_profile, segment_stats """, ) async def respond_dialogue( request: Request, dialogue_request: DialogueRequest, session: Annotated[AsyncSession, Depends(get_session)], policy_router: PolicyRouter = Depends(get_policy_router), high_risk_handler: HighRiskHandler = Depends(get_high_risk_handler), timeout_governor: TimeoutGovernor = Depends(get_timeout_governor), feature_flag_service: FeatureFlagService = Depends(get_feature_flag_service), trace_logger: TraceLogger = Depends(get_trace_logger), metrics_collector: MetricsCollector = Depends(get_metrics_collector), output_guardrail_executor: OutputGuardrailExecutor = Depends(get_output_guardrail_executor), interrupt_context_enricher: InterruptContextEnricher = Depends(get_interrupt_context_enricher), default_kb_tool_runner: DefaultKbToolRunner = Depends(get_default_kb_tool_runner), segment_humanizer: SegmentHumanizer = Depends(get_segment_humanizer), runtime_observer: RuntimeObserver = Depends(get_runtime_observer), clarification_engine: ClarificationEngine = Depends(get_clarification_engine), ) -> DialogueResponse: """ [AC-MARH-01~12] Generate dialogue response with segments and trace. Flow: 1. Validate request and get tenant context 2. Start runtime observation 3. Process interrupted segments (AC-MARH-03/04) 4. Check feature flags for grayscale/rollback 5. Detect high-risk scenarios 6. Route to appropriate execution mode 7. For Agent mode: call KB tool (AC-MARH-05/06) 8. Execute output guardrail (AC-MARH-01/02) 9. Apply segment humanizer (AC-MARH-10/11) 10. Collect trace and return (AC-MARH-12) """ start_time = time.time() tenant_id = get_tenant_id() if not tenant_id: from app.core.exceptions import MissingTenantIdException raise MissingTenantIdException() request_id = str(uuid.uuid4()) generation_id = str(uuid.uuid4()) logger.info( f"[AC-MARH-01] Dialogue request: tenant={tenant_id}, " f"session={dialogue_request.session_id}, request_id={request_id}, " f"user_message={dialogue_request.user_message[:100] if dialogue_request.user_message else 'None'}..." ) runtime_ctx = runtime_observer.start_observation( tenant_id=tenant_id, session_id=dialogue_request.session_id, request_id=request_id, generation_id=generation_id, ) trace = trace_logger.start_trace( tenant_id=tenant_id, session_id=dialogue_request.session_id, request_id=request_id, generation_id=generation_id, ) metrics_collector.start_session(dialogue_request.session_id) tool_registry = get_tool_registry() ensure_kb_search_dynamic_registered(tool_registry, session) ensure_intent_hint_registered(tool_registry, session) ensure_high_risk_check_registered(tool_registry, session) ensure_memory_recall_registered(tool_registry, session) ensure_metadata_discovery_registered(tool_registry, session) try: interrupt_ctx = interrupt_context_enricher.enrich( dialogue_request.interrupted_segments, generation_id, ) runtime_observer.record_interrupt(request_id, interrupt_ctx.consumed) feature_flags = dialogue_request.feature_flags or feature_flag_service.get_flags( dialogue_request.session_id ) if feature_flags.rollback_to_legacy: logger.info(f"[AC-MARH-17] Rollback to legacy for session: {dialogue_request.session_id}") return await _handle_legacy_response( tenant_id=tenant_id, request=dialogue_request, trace=trace, trace_logger=trace_logger, start_time=start_time, ) high_risk_check_tool = HighRiskCheckTool( session=session, config=HighRiskCheckConfig(enabled=True, timeout_ms=500), ) high_risk_result = await high_risk_check_tool.execute( message=dialogue_request.user_message, tenant_id=tenant_id, ) if high_risk_result.duration_ms > 0: hr_trace = high_risk_check_tool.create_trace(high_risk_result, tenant_id) trace_logger.update_trace( request_id=request_id, tool_calls=[hr_trace], ) logger.info( f"[AC-IDMP-05, AC-IDMP-20] High risk check result: " f"matched={high_risk_result.matched}, scenario={high_risk_result.risk_scenario}, " f"duration_ms={high_risk_result.duration_ms}" ) if high_risk_result.matched and high_risk_result.risk_scenario: logger.info( f"[AC-IDMP-05] High-risk matched from tool: {high_risk_result.risk_scenario.value}" ) return await _handle_high_risk_check_response( tenant_id=tenant_id, request=dialogue_request, high_risk_result=high_risk_result, trace=trace, trace_logger=trace_logger, start_time=start_time, session=session, ) intent_hint_tool = IntentHintTool( session=session, config=IntentHintConfig(enabled=True, timeout_ms=500), ) intent_hint = await intent_hint_tool.execute( message=dialogue_request.user_message, tenant_id=tenant_id, history=[h.model_dump() for h in dialogue_request.history] if dialogue_request.history else None, ) if intent_hint.duration_ms > 0: hint_trace = intent_hint_tool.create_trace(intent_hint) trace_logger.update_trace( request_id=request_id, tool_calls=[hint_trace], ) logger.info( f"[AC-IDMP-02] Intent hint result: intent={intent_hint.intent}, " f"confidence={intent_hint.confidence}, suggested_mode={intent_hint.suggested_mode}" ) intent_match = await _match_intent(tenant_id, dialogue_request, session) router_result = policy_router.route( user_message=dialogue_request.user_message, session_mode="BOT_ACTIVE", feature_flags=feature_flags, intent_match=intent_match, intent_hint=intent_hint, ) runtime_observer.update_mode(request_id, router_result.mode, router_result.intent) runtime_observer.record_timeout_profile(request_id, timeout_governor.profile) trace_logger.update_trace( request_id=request_id, mode=router_result.mode, intent=router_result.intent, fallback_reason_code=router_result.fallback_reason_code, ) if router_result.mode == ExecutionMode.AGENT: response = await _execute_agent_mode( tenant_id=tenant_id, request=dialogue_request, request_id=request_id, trace=trace, trace_logger=trace_logger, timeout_governor=timeout_governor, metrics_collector=metrics_collector, default_kb_tool_runner=default_kb_tool_runner, runtime_observer=runtime_observer, interrupt_ctx=interrupt_ctx, start_time=start_time, session=session, tool_registry=tool_registry, ) elif router_result.mode == ExecutionMode.MICRO_FLOW: response = await _execute_micro_flow_mode( tenant_id=tenant_id, request=dialogue_request, router_result=router_result, trace=trace, trace_logger=trace_logger, session=session, start_time=start_time, ) elif router_result.mode == ExecutionMode.FIXED: response = await _execute_fixed_mode( tenant_id=tenant_id, request=dialogue_request, router_result=router_result, trace=trace, trace_logger=trace_logger, start_time=start_time, ) else: response = await _execute_transfer_mode( tenant_id=tenant_id, request=dialogue_request, router_result=router_result, trace=trace, trace_logger=trace_logger, start_time=start_time, ) filtered_segments, guardrail_result = await output_guardrail_executor.filter_segments( response.segments, tenant_id ) runtime_observer.record_guardrail( request_id, guardrail_result.triggered, guardrail_result.rule_id ) humanize_config = None if dialogue_request.humanize_config: humanize_config = HumanizeConfig( enabled=dialogue_request.humanize_config.enabled or True, min_delay_ms=dialogue_request.humanize_config.min_delay_ms or 50, max_delay_ms=dialogue_request.humanize_config.max_delay_ms or 500, length_bucket_strategy=dialogue_request.humanize_config.length_bucket_strategy or "simple", ) final_segments, segment_stats = segment_humanizer.humanize( "\n\n".join(s.text for s in filtered_segments), humanize_config, ) runtime_observer.record_segment_stats(request_id, segment_stats) final_trace = runtime_observer.end_observation(request_id) final_trace.segment_stats = segment_stats final_trace.guardrail_triggered = guardrail_result.triggered final_trace.guardrail_rule_id = guardrail_result.rule_id latency_ms = int((time.time() - start_time) * 1000) metrics_collector.record_turn( session_id=dialogue_request.session_id, tenant_id=tenant_id, latency_ms=latency_ms, task_completed=True, ) audit = trace_logger.end_trace( request_id=request_id, tenant_id=tenant_id, session_id=dialogue_request.session_id, latency_ms=latency_ms, ) logger.info( f"[AC-MARH-12] Audit record: request_id={request_id}, " f"mode={final_trace.mode.value}, latency_ms={latency_ms}, " f"guardrail={guardrail_result.triggered}, kb_hit={final_trace.kb_hit}" ) return DialogueResponse( segments=final_segments, trace=final_trace, ) except Exception as e: latency_ms = int((time.time() - start_time) * 1000) import traceback logger.error( f"[AC-IDMP-06] Dialogue error: {e}\n" f"Exception type: {type(e).__name__}\n" f"Request details: session_id={dialogue_request.session_id}, " f"user_message={dialogue_request.user_message[:200] if dialogue_request.user_message else 'None'}, " f"scene={dialogue_request.scene}, user_id={dialogue_request.user_id}\n" f"Traceback:\n{traceback.format_exc()}" ) trace_logger.update_trace( request_id=request_id, mode=ExecutionMode.FIXED, fallback_reason_code=f"error: {str(e)[:50]}", ) trace_logger.end_trace( request_id=request_id, tenant_id=tenant_id, session_id=dialogue_request.session_id, latency_ms=latency_ms, ) return DialogueResponse( segments=[Segment( text="抱歉,服务暂时不可用,请稍后重试或联系人工客服。", delay_after=0, )], trace=TraceInfo( mode=ExecutionMode.FIXED, request_id=request_id, generation_id=generation_id, fallback_reason_code="service_error", ), ) async def _match_intent( tenant_id: str, request: DialogueRequest, session: AsyncSession, ) -> IntentMatch | None: """Match intent from user message.""" try: from app.services.intent.router import IntentRouter from app.services.intent.rule_service import IntentRuleService rule_service = IntentRuleService(session) rules = await rule_service.get_enabled_rules_for_matching(tenant_id) if not rules: return None router = IntentRouter() result = router.match(request.user_message, rules) if result: return IntentMatch( intent_id=str(result.rule.id), intent_name=result.rule.name, confidence=1.0, response_type=result.rule.response_type, target_kb_ids=result.rule.target_kb_ids, flow_id=str(result.rule.flow_id) if result.rule.flow_id else None, fixed_reply=result.rule.fixed_reply, transfer_message=result.rule.transfer_message, ) return None except Exception as e: logger.warning(f"[AC-IDMP-02] Intent match failed: {e}") return None async def _match_intent_hybrid( tenant_id: str, request: DialogueRequest, session: AsyncSession, clarification_engine: ClarificationEngine, ) -> HybridIntentResult: """ [AC-CLARIFY] Hybrid intent matching with clarification support. Returns HybridIntentResult with unified confidence and candidates. """ try: from app.services.intent.router import IntentRouter from app.services.intent.rule_service import IntentRuleService rule_service = IntentRuleService(session) rules = await rule_service.get_enabled_rules_for_matching(tenant_id) if not rules: return HybridIntentResult( intent=None, confidence=0.0, candidates=[], ) router = IntentRouter() try: fusion_result = await router.match_hybrid( message=request.user_message, rules=rules, tenant_id=tenant_id, ) hybrid_result = HybridIntentResult.from_fusion_result(fusion_result) logger.info( f"[AC-CLARIFY] Hybrid intent match: " f"intent={hybrid_result.intent.intent_name if hybrid_result.intent else None}, " f"confidence={hybrid_result.confidence:.3f}, " f"need_clarify={hybrid_result.need_clarify}" ) return hybrid_result except Exception as e: logger.warning(f"[AC-CLARIFY] Hybrid match failed, fallback to rule: {e}") result = router.match(request.user_message, rules) if result: confidence = clarification_engine.compute_confidence( rule_score=1.0, semantic_score=0.0, llm_score=0.0, ) candidate = IntentCandidate( intent_id=str(result.rule.id), intent_name=result.rule.name, confidence=confidence, response_type=result.rule.response_type, target_kb_ids=result.rule.target_kb_ids, flow_id=str(result.rule.flow_id) if result.rule.flow_id else None, fixed_reply=result.rule.fixed_reply, transfer_message=result.rule.transfer_message, ) return HybridIntentResult( intent=candidate, confidence=confidence, candidates=[candidate], ) return HybridIntentResult( intent=None, confidence=0.0, candidates=[], ) except Exception as e: logger.warning(f"[AC-CLARIFY] Intent match failed: {e}") return HybridIntentResult( intent=None, confidence=0.0, candidates=[], ) async def _handle_clarification( tenant_id: str, request: DialogueRequest, hybrid_result: HybridIntentResult, clarification_engine: ClarificationEngine, trace: TraceInfo, session: AsyncSession, required_slots: list[str] | None = None, filled_slots: dict[str, Any] | None = None, ) -> DialogueResponse | None: """ [AC-CLARIFY] Handle clarification logic. Returns DialogueResponse if clarification is needed, None otherwise. """ existing_state = ClarifySessionManager.get_session(request.session_id) if existing_state and not existing_state.is_max_retry(): logger.info( f"[AC-CLARIFY] Processing clarify response: " f"session={request.session_id}, retry={existing_state.retry_count}" ) new_result = clarification_engine.process_clarify_response( user_message=request.user_message, state=existing_state, ) if not new_result.need_clarify: ClarifySessionManager.clear_session(request.session_id) if new_result.intent: return None clarify_prompt = clarification_engine.generate_clarify_prompt(existing_state) return DialogueResponse( segments=[Segment(text=clarify_prompt, delay_after=0)], trace=TraceInfo( mode=ExecutionMode.FIXED, request_id=trace.request_id, generation_id=trace.generation_id, fallback_reason_code=f"clarify_retry_{existing_state.retry_count}", intent=existing_state.candidates[0].intent_name if existing_state.candidates else None, ), ) should_clarify, clarify_state = clarification_engine.should_trigger_clarify( result=hybrid_result, required_slots=required_slots, filled_slots=filled_slots, ) if not should_clarify or not clarify_state: return None logger.info( f"[AC-CLARIFY] Clarification triggered: " f"session={request.session_id}, reason={clarify_state.reason}, " f"confidence={hybrid_result.confidence:.3f}" ) ClarifySessionManager.set_session(request.session_id, clarify_state) clarify_prompt = clarification_engine.generate_clarify_prompt(clarify_state) return DialogueResponse( segments=[Segment(text=clarify_prompt, delay_after=0)], trace=TraceInfo( mode=ExecutionMode.FIXED, request_id=trace.request_id, generation_id=trace.generation_id, fallback_reason_code=f"clarify_{clarify_state.reason.value}", intent=hybrid_result.intent.intent_name if hybrid_result.intent else None, ), ) async def _check_hard_block( tenant_id: str, request: DialogueRequest, hybrid_result: HybridIntentResult, clarification_engine: ClarificationEngine, trace: TraceInfo, required_slots: list[str] | None = None, filled_slots: dict[str, Any] | None = None, ) -> DialogueResponse | None: """ [AC-CLARIFY] Check hard block conditions. Hard blocks: 1. confidence < T_high: block entering new flow 2. required_slots missing: block flow progression Returns DialogueResponse if blocked, None otherwise. """ is_blocked, block_reason = clarification_engine.check_hard_block( result=hybrid_result, required_slots=required_slots, filled_slots=filled_slots, ) if not is_blocked: return None logger.info( f"[AC-CLARIFY] Hard block triggered: " f"session={request.session_id}, reason={block_reason}, " f"confidence={hybrid_result.confidence:.3f}" ) clarify_state = ClarifyState( reason=block_reason, asked_slot=required_slots[0] if required_slots and block_reason == ClarifyReason.MISSING_SLOT else None, candidates=hybrid_result.candidates, ) ClarifySessionManager.set_session(request.session_id, clarify_state) clarification_engine._metrics.record_clarify_trigger() clarify_prompt = clarification_engine.generate_clarify_prompt(clarify_state) return DialogueResponse( segments=[Segment(text=clarify_prompt, delay_after=0)], trace=TraceInfo( mode=ExecutionMode.FIXED, request_id=trace.request_id, generation_id=trace.generation_id, fallback_reason_code=f"hard_block_{block_reason.value}", intent=hybrid_result.intent.intent_name if hybrid_result.intent else None, ), ) async def _handle_legacy_response( tenant_id: str, request: DialogueRequest, trace: TraceInfo, trace_logger: TraceLogger, start_time: float, ) -> DialogueResponse: """Handle rollback to legacy pipeline.""" latency_ms = int((time.time() - start_time) * 1000) return DialogueResponse( segments=[Segment( text="正在使用传统模式处理您的请求...", delay_after=0, )], trace=TraceInfo( mode=ExecutionMode.FIXED, request_id=trace.request_id, generation_id=trace.generation_id, fallback_reason_code="rollback_to_legacy", ), ) async def _handle_high_risk_check_response( tenant_id: str, request: DialogueRequest, high_risk_result: Any, trace: TraceInfo, trace_logger: TraceLogger, start_time: float, session: AsyncSession, ) -> DialogueResponse: """ [AC-IDMP-05, AC-IDMP-20] Handle high-risk scenario from high_risk_check tool. 高风险优先于普通意图路由。 """ from app.models.mid.schemas import HighRiskCheckResult if not isinstance(high_risk_result, HighRiskCheckResult): high_risk_result = HighRiskCheckResult(**high_risk_result) latency_ms = int((time.time() - start_time) * 1000) recommended_mode = high_risk_result.recommended_mode or ExecutionMode.MICRO_FLOW risk_scenario = high_risk_result.risk_scenario trace_logger.update_trace( request_id=trace.request_id or "", mode=recommended_mode, fallback_reason_code=f"high_risk_{risk_scenario.value if risk_scenario else 'unknown'}", ) if recommended_mode == ExecutionMode.TRANSFER: transfer_msg = "正在为您转接人工客服..." if risk_scenario: if risk_scenario.value == "complaint_escalation": transfer_msg = "检测到您可能需要投诉处理,正在为您转接人工客服..." elif risk_scenario.value == "refund": transfer_msg = "您的退款请求需要人工处理,正在为您转接..." return DialogueResponse( segments=[Segment( text=transfer_msg, delay_after=0, )], trace=TraceInfo( mode=ExecutionMode.TRANSFER, request_id=trace.request_id, generation_id=trace.generation_id, high_risk_policy_set=[risk_scenario] if risk_scenario else None, fallback_reason_code=high_risk_result.rule_id, ), ) if high_risk_result.rule_id: try: from sqlalchemy import select from app.models.entities import HighRiskPolicy import uuid stmt = select(HighRiskPolicy).where( HighRiskPolicy.id == uuid.UUID(high_risk_result.rule_id) ) result = await session.execute(stmt) policy = result.scalar_one_or_none() if policy and policy.flow_id: return DialogueResponse( segments=[Segment( text="检测到您的请求需要特殊处理,正在为您安排...", delay_after=0, )], trace=TraceInfo( mode=ExecutionMode.MICRO_FLOW, request_id=trace.request_id, generation_id=trace.generation_id, high_risk_policy_set=[risk_scenario] if risk_scenario else None, fallback_reason_code=high_risk_result.rule_id, ), ) except Exception as e: logger.warning(f"[AC-IDMP-05] Failed to load high risk policy: {e}") return DialogueResponse( segments=[Segment( text="检测到您的请求需要特殊处理,正在为您安排...", delay_after=0, )], trace=TraceInfo( mode=ExecutionMode.MICRO_FLOW, request_id=trace.request_id, generation_id=trace.generation_id, high_risk_policy_set=[risk_scenario] if risk_scenario else None, fallback_reason_code=high_risk_result.rule_id or f"high_risk_{risk_scenario.value if risk_scenario else 'unknown'}", ), ) async def _handle_high_risk_response( tenant_id: str, request: DialogueRequest, high_risk_match: Any, high_risk_handler: HighRiskHandler, trace: TraceInfo, trace_logger: TraceLogger, start_time: float, ) -> DialogueResponse: """Handle high-risk scenario response.""" router_result = high_risk_handler.handle(high_risk_match) latency_ms = int((time.time() - start_time) * 1000) trace_logger.update_trace( request_id=trace.request_id or "", mode=router_result.mode, fallback_reason_code=router_result.fallback_reason_code, ) if router_result.mode == ExecutionMode.TRANSFER: return DialogueResponse( segments=[Segment( text=router_result.transfer_message or "正在为您转接人工客服...", delay_after=0, )], trace=TraceInfo( mode=ExecutionMode.TRANSFER, request_id=trace.request_id, generation_id=trace.generation_id, high_risk_policy_set=[high_risk_match.scenario], fallback_reason_code=router_result.fallback_reason_code, ), ) return DialogueResponse( segments=[Segment( text="检测到您的请求需要特殊处理,正在为您安排...", delay_after=0, )], trace=TraceInfo( mode=ExecutionMode.MICRO_FLOW, request_id=trace.request_id, generation_id=trace.generation_id, high_risk_policy_set=[high_risk_match.scenario], fallback_reason_code=router_result.fallback_reason_code, ), ) async def _execute_agent_mode( tenant_id: str, request: DialogueRequest, request_id: str, trace: TraceInfo, trace_logger: TraceLogger, timeout_governor: TimeoutGovernor, metrics_collector: MetricsCollector, default_kb_tool_runner: DefaultKbToolRunner, runtime_observer: RuntimeObserver, interrupt_ctx: Any = None, start_time: float = 0, session: AsyncSession | None = None, tool_registry: ToolRegistry | None = None, ) -> DialogueResponse: """ [AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-IDMP-13, AC-MRS-SLOT-META-03] Execute agent mode with ReAct loop, KB tool, memory recall, and slot state aggregation. """ from app.services.llm.factory import get_llm_config_manager from app.services.mid.slot_state_aggregator import SlotStateAggregator logger.info( f"[DEBUG-AGENT] Starting _execute_agent_mode: tenant={tenant_id}, " f"session={request.session_id}, scene={request.scene}, user_id={request.user_id}, " f"user_message={request.user_message[:100] if request.user_message else 'None'}..." ) try: llm_manager = get_llm_config_manager() llm_client = llm_manager.get_client() logger.info(f"[DEBUG-AGENT] LLM client obtained successfully") except Exception as e: logger.warning(f"[AC-MARH-07] Failed to get LLM client: {e}") llm_client = None base_context = {"history": [h.model_dump() for h in request.history]} if request.history else {} if request.scene: base_context["scene"] = request.scene if interrupt_ctx and interrupt_ctx.consumed: base_context["interrupted_content"] = interrupt_ctx.interrupted_content base_context["interrupted_segment_ids"] = interrupt_ctx.interrupted_segment_ids logger.info( f"[AC-MARH-03] Agent context enriched with interrupt: " f"{len(interrupt_ctx.interrupted_content or '')} chars" ) # [AC-MRS-SLOT-META-03] 初始化槽位状态聚合器 slot_state = None memory_context = "" memory_missing_slots: list[str] = [] scene_slot_context = None flow_status = None is_flow_runtime = False if session: from app.services.flow.engine import FlowEngine flow_engine = FlowEngine(session) flow_status = await flow_engine.get_flow_status(tenant_id, request.session_id) is_flow_runtime = bool(flow_status and flow_status.get("status") == "active") logger.info( f"[AC-IDMP-02] 运行轨道判定: session={request.session_id}, " f"is_flow_runtime={is_flow_runtime}, flow_id={flow_status.get('flow_id') if flow_status else None}, " f"current_step={flow_status.get('current_step') if flow_status else None}" ) # [AC-SCENE-SLOT-02] 加载场景槽位包(仅流程态加载) if is_flow_runtime and session and request.scene: from app.services.mid.scene_slot_bundle_loader import SceneSlotBundleLoader scene_loader = SceneSlotBundleLoader(session) scene_slot_context = await scene_loader.load_scene_context( tenant_id=tenant_id, scene_key=request.scene, ) if scene_slot_context: logger.info( f"[AC-SCENE-SLOT-02] 场景槽位包加载成功: " f"scene={request.scene}, required={len(scene_slot_context.required_slots)}, " f"optional={len(scene_slot_context.optional_slots)}, " f"threshold={scene_slot_context.completion_threshold}" ) base_context["scene_slot_context"] = { "scene_key": scene_slot_context.scene_key, "scene_name": scene_slot_context.scene_name, "required_slots": scene_slot_context.get_required_slot_keys(), "optional_slots": scene_slot_context.get_optional_slot_keys(), } if session and request.user_id: memory_recall_tool = MemoryRecallTool( session=session, timeout_governor=timeout_governor, config=MemoryRecallConfig( enabled=True, timeout_ms=1000, max_recent_messages=8, ), ) memory_result = await memory_recall_tool.execute( tenant_id=tenant_id, user_id=request.user_id, session_id=request.session_id, ) if memory_result.duration_ms > 0: memory_trace = memory_recall_tool.create_trace(memory_result, tenant_id) trace_logger.update_trace( request_id=request_id, tool_calls=[memory_trace], ) memory_context = memory_result.get_context_for_prompt() memory_missing_slots = memory_result.missing_slots if memory_context: base_context["memory_context"] = memory_context logger.info( f"[AC-IDMP-13] Memory recall succeeded: " f"profile={len(memory_result.profile)}, facts={len(memory_result.facts)}, " f"slots={len(memory_result.slots)}, missing_slots={len(memory_missing_slots)}, " f"duration_ms={memory_result.duration_ms}" ) elif memory_result.fallback_reason_code: logger.warning( f"[AC-IDMP-13] Memory recall fallback: reason={memory_result.fallback_reason_code}" ) # [AC-MRS-SLOT-META-03] 仅流程态执行槽位聚合/提取 if is_flow_runtime: slot_aggregator = SlotStateAggregator( session=session, tenant_id=tenant_id, session_id=request.session_id, ) slot_state = await slot_aggregator.aggregate( memory_slots=memory_result.slots, current_input_slots=None, # 可从 request 中解析 context=base_context, scene_slot_context=scene_slot_context, # [AC-SCENE-SLOT-02] 传入场景槽位上下文 ) logger.info( f"[AC-MRS-SLOT-META-03] 流程态槽位聚合完成: " f"filled={len(slot_state.filled_slots)}, " f"missing={len(slot_state.missing_required_slots)}, " f"mappings={slot_state.slot_to_field_map}" ) # [AC-MRS-SLOT-EXTRACT-01] 自动提取槽位 if slot_state and slot_state.missing_required_slots: from app.services.mid.slot_extraction_integration import SlotExtractionIntegration extraction_integration = SlotExtractionIntegration( session=session, tenant_id=tenant_id, session_id=request.session_id, ) extraction_result = await extraction_integration.extract_and_fill( user_input=request.user_message, slot_state=slot_state, ) if extraction_result.extracted_slots: slot_state = await slot_aggregator.aggregate( memory_slots=memory_result.slots, current_input_slots=extraction_result.extracted_slots, context=base_context, ) logger.info( f"[AC-MRS-SLOT-EXTRACT-01] 流程态自动提槽完成: " f"extracted={list(extraction_result.extracted_slots.keys())}, " f"time_ms={extraction_result.total_execution_time_ms:.2f}" ) else: logger.info( f"[AC-MRS-SLOT-META-03] 当前为通用问答轨道,跳过槽位聚合与缺槽追问: " f"session={request.session_id}" ) kb_hits = [] kb_success = False kb_fallback_reason = None kb_applied_filter = {} kb_missing_slots = [] kb_dynamic_result = None step_kb_binding_trace: dict[str, Any] | None = None # [Step-KB-Binding] 获取当前流程步骤的 KB 配置(仅流程态) step_kb_config = None if is_flow_runtime and session and flow_status: current_step_no = flow_status.get("current_step") flow_id = flow_status.get("flow_id") if flow_id and current_step_no: from app.models.entities import ScriptFlow from sqlalchemy import select stmt = select(ScriptFlow).where(ScriptFlow.id == flow_id) result = await session.execute(stmt) flow = result.scalar_one_or_none() if flow and flow.steps: step_idx = current_step_no - 1 if 0 <= step_idx < len(flow.steps): current_step = flow.steps[step_idx] # 构建 StepKbConfig from app.services.mid.kb_search_dynamic_tool import StepKbConfig step_kb_config = StepKbConfig( allowed_kb_ids=current_step.get("allowed_kb_ids"), preferred_kb_ids=current_step.get("preferred_kb_ids"), kb_query_hint=current_step.get("kb_query_hint"), max_kb_calls=current_step.get("max_kb_calls_per_step", 1), step_id=f"{flow_id}_step_{current_step_no}", ) step_kb_binding_trace = { "flow_id": str(flow_id), "flow_name": flow_status.get("flow_name"), "current_step": current_step_no, "step_id": step_kb_config.step_id, "allowed_kb_ids": step_kb_config.allowed_kb_ids, "preferred_kb_ids": step_kb_config.preferred_kb_ids, } logger.info( f"[Step-KB-Binding] 步骤知识库配置加载成功: " f"flow={flow_status.get('flow_name')}, step={current_step_no}, " f"allowed_kb_ids={step_kb_config.allowed_kb_ids}" ) if session and tool_registry: kb_tool = KbSearchDynamicTool( session=session, timeout_governor=timeout_governor, config=KbSearchDynamicConfig( enabled=True, top_k=5, timeout_ms=10000, min_score_threshold=0.5, ), ) # [AC-MRS-SLOT-META-03] 传入 slot_state 进行 KB 检索 # [Step-KB-Binding] 传入 step_kb_config 进行步骤级别的 KB 约束 kb_dynamic_result = await kb_tool.execute( query=request.user_message, tenant_id=tenant_id, top_k=5, context=base_context, slot_state=slot_state, step_kb_config=step_kb_config, slot_policy="flow_strict" if is_flow_runtime else "agent_relaxed", ) kb_success = kb_dynamic_result.success kb_hits = kb_dynamic_result.hits kb_fallback_reason = kb_dynamic_result.fallback_reason_code kb_applied_filter = kb_dynamic_result.applied_filter kb_missing_slots = kb_dynamic_result.missing_required_slots if kb_dynamic_result.tool_trace: trace_logger.update_trace( request_id=request_id, tool_calls=[kb_dynamic_result.tool_trace], ) logger.info( f"[AC-MARH-05] KB动态检索完成: success={kb_success}, " f"hits={len(kb_hits)}, filter={kb_applied_filter}, " f"missing_slots={kb_missing_slots}, track={'flow' if is_flow_runtime else 'agent'}" ) # [AC-MRS-SLOT-META-03] 处理缺失必填槽位 -> 追问闭环(仅流程态) if is_flow_runtime and kb_fallback_reason == "MISSING_REQUIRED_SLOTS" and kb_missing_slots: ask_back_text = await _generate_ask_back_for_missing_slots( slot_state=slot_state, missing_slots=kb_missing_slots, session=session, tenant_id=tenant_id, session_id=request.session_id, scene_slot_context=scene_slot_context, # [AC-SCENE-SLOT-02] 传入场景槽位上下文 ) logger.info( f"[AC-MRS-SLOT-META-03] 流程态缺槽追问文案已生成: " f"{ask_back_text[:50]}..." ) return DialogueResponse( segments=[Segment(text=ask_back_text, delay_after=0)], trace=TraceInfo( mode=ExecutionMode.AGENT, request_id=trace.request_id, generation_id=trace.generation_id, fallback_reason_code="missing_required_slots", kb_tool_called=True, kb_hit=False, # [AC-SCENE-SLOT-02] 场景槽位追踪 scene=request.scene, scene_slot_context=base_context.get("scene_slot_context"), missing_slots=[s.get("slot_key") for s in kb_missing_slots] if kb_missing_slots else None, ask_back_triggered=True, slot_sources=slot_state.slot_sources if slot_state else None, ), ) else: kb_result = await default_kb_tool_runner.execute( tenant_id=tenant_id, query=request.user_message, ) kb_success = kb_result.success kb_hits = kb_result.hits kb_fallback_reason = kb_result.fallback_reason_code runtime_observer.record_kb( request_id, tool_called=True, hit=kb_success and len(kb_hits) > 0, fallback_reason=kb_fallback_reason, ) if kb_success and kb_hits: kb_context = "\n".join([ f"[知识库] {hit.get('content', '')[:200]}" for hit in kb_hits[:3] ]) base_context["kb_context"] = kb_context logger.info( f"[AC-MARH-05] KB retrieval succeeded: hits={len(kb_hits)}" ) elif kb_fallback_reason: logger.warning( f"[AC-MARH-06] KB retrieval fallback: reason={kb_fallback_reason}" ) orchestrator = AgentOrchestrator( max_iterations=5, timeout_governor=timeout_governor, llm_client=llm_client, tool_registry=tool_registry, template_service=PromptTemplateService, variable_resolver=VariableResolver(), tenant_id=tenant_id, user_id=request.user_id, session_id=request.session_id, scene=request.scene, ) logger.info( f"[DEBUG-AGENT] Calling orchestrator.execute with: " f"user_message={request.user_message[:100] if request.user_message else 'None'}, " f"context_keys={list(base_context.keys())}, llm_client={llm_client is not None}" ) final_answer, react_ctx, agent_trace = await orchestrator.execute( user_message=request.user_message, context=base_context, ) logger.info( f"[DEBUG-AGENT] orchestrator.execute completed: " f"final_answer={final_answer[:200] if final_answer else 'None'}, " f"iterations={react_ctx.iteration}, tool_calls={len(react_ctx.tool_calls) if react_ctx.tool_calls else 0}" ) runtime_observer.record_react(request_id, react_ctx.iteration, react_ctx.tool_calls) trace_logger.update_trace( request_id=request_id, react_iterations=react_ctx.iteration, tool_calls=react_ctx.tool_calls, ) segments = _text_to_segments(final_answer) return DialogueResponse( segments=segments, trace=TraceInfo( mode=ExecutionMode.AGENT, request_id=trace.request_id, generation_id=trace.generation_id, react_iterations=react_ctx.iteration, tools_used=[tc.tool_name for tc in react_ctx.tool_calls] if react_ctx.tool_calls else None, tool_calls=react_ctx.tool_calls, timeout_profile=timeout_governor.profile, kb_tool_called=True, kb_hit=kb_success and len(kb_hits) > 0, fallback_reason_code=kb_fallback_reason, # [AC-SCENE-SLOT-02] 场景槽位追踪 scene=request.scene, scene_slot_context=base_context.get("scene_slot_context"), missing_slots=[s.get("slot_key") for s in kb_missing_slots] if kb_missing_slots else None, ask_back_triggered=False, slot_sources=slot_state.slot_sources if slot_state else None, kb_filter_sources=kb_dynamic_result.filter_sources if kb_dynamic_result else None, # [Step-KB-Binding] 步骤知识库绑定追踪 step_kb_binding=kb_dynamic_result.step_kb_binding if kb_dynamic_result else step_kb_binding_trace, ), ) async def _generate_ask_back_for_missing_slots( slot_state: Any, missing_slots: list[dict[str, str]], session: AsyncSession, tenant_id: str, session_id: str | None = None, max_ask_back_slots: int = 2, scene_slot_context: Any = None, # [AC-SCENE-SLOT-02] 场景槽位上下文 ) -> str: """ [AC-MRS-SLOT-META-03, AC-MRS-SLOT-ASKBACK-01] 为缺失的必填槽位生成追问响应 [AC-SCENE-SLOT-02] 支持场景槽位包配置的追问策略 支持批量追问多个缺失槽位,优先追问必填槽位 """ if not missing_slots: return "请提供更多信息以便我更好地帮助您。" # [AC-SCENE-SLOT-02] 如果有场景槽位上下文,使用场景配置的追问策略 if scene_slot_context: ask_back_order = getattr(scene_slot_context, 'ask_back_order', 'priority') if ask_back_order == "parallel": prompts = [] for missing in missing_slots[:max_ask_back_slots]: ask_back_prompt = missing.get("ask_back_prompt") if ask_back_prompt: prompts.append(ask_back_prompt) else: slot_key = missing.get("slot_key", "相关信息") prompts.append(f"您的{slot_key}") if len(prompts) == 1: return prompts[0] elif len(prompts) == 2: return f"为了更好地为您服务,请告诉我{prompts[0]}和{prompts[1]}。" else: all_but_last = "、".join(prompts[:-1]) return f"为了更好地为您服务,请告诉我{all_but_last},以及{prompts[-1]}。" if len(missing_slots) == 1: first_missing = missing_slots[0] ask_back_prompt = first_missing.get("ask_back_prompt") if ask_back_prompt: return ask_back_prompt slot_key = first_missing.get("slot_key", "相关信息") label = first_missing.get("label", slot_key) return f"为了更好地为您提供帮助,请告诉我您的{label}。" try: from app.services.mid.batch_ask_back_service import ( BatchAskBackConfig, BatchAskBackService, ) config = BatchAskBackConfig( max_ask_back_slots_per_turn=max_ask_back_slots, prefer_required=True, merge_prompts=True, ) ask_back_service = BatchAskBackService( session=session, tenant_id=tenant_id, session_id=session_id or "", config=config, ) result = await ask_back_service.generate_batch_ask_back( missing_slots=missing_slots, ) if result.has_ask_back(): return result.get_prompt() except Exception as e: logger.warning(f"[AC-MRS-SLOT-ASKBACK-01] Batch ask-back failed, fallback to single: {e}") prompts = [] for missing in missing_slots[:max_ask_back_slots]: ask_back_prompt = missing.get("ask_back_prompt") if ask_back_prompt: prompts.append(ask_back_prompt) else: label = missing.get("label", missing.get("slot_key", "相关信息")) prompts.append(f"您的{label}") if len(prompts) == 1: return prompts[0] elif len(prompts) == 2: return f"为了更好地为您服务,请告诉我{prompts[0]}和{prompts[1]}。" else: all_but_last = "、".join(prompts[:-1]) return f"为了更好地为您服务,请告诉我{all_but_last},以及{prompts[-1]}。" async def _execute_micro_flow_mode( tenant_id: str, request: DialogueRequest, router_result: Any, trace: TraceInfo, trace_logger: TraceLogger, session: AsyncSession, start_time: float, ) -> DialogueResponse: """Execute micro flow mode.""" if router_result.target_flow_id: try: from app.services.flow.engine import FlowEngine flow_engine = FlowEngine(session) instance, first_step = await flow_engine.start( tenant_id=tenant_id, session_id=request.session_id, flow_id=router_result.target_flow_id, ) if first_step: return DialogueResponse( segments=_text_to_segments(first_step), trace=TraceInfo( mode=ExecutionMode.MICRO_FLOW, request_id=trace.request_id, generation_id=trace.generation_id, intent=router_result.intent, ), ) except Exception as e: logger.warning(f"[AC-IDMP-05] Micro flow start failed: {e}") return DialogueResponse( segments=[Segment( text="正在为您处理,请稍候...", delay_after=0, )], trace=TraceInfo( mode=ExecutionMode.MICRO_FLOW, request_id=trace.request_id, generation_id=trace.generation_id, intent=router_result.intent, fallback_reason_code=router_result.fallback_reason_code, ), ) async def _execute_fixed_mode( tenant_id: str, request: DialogueRequest, router_result: Any, trace: TraceInfo, trace_logger: TraceLogger, start_time: float, ) -> DialogueResponse: """Execute fixed reply mode.""" text = router_result.fixed_reply or "收到您的消息,我们会尽快处理。" return DialogueResponse( segments=_text_to_segments(text), trace=TraceInfo( mode=ExecutionMode.FIXED, request_id=trace.request_id, generation_id=trace.generation_id, intent=router_result.intent, fallback_reason_code=router_result.fallback_reason_code, ), ) async def _execute_transfer_mode( tenant_id: str, request: DialogueRequest, router_result: Any, trace: TraceInfo, trace_logger: TraceLogger, start_time: float, ) -> DialogueResponse: """Execute transfer to human mode.""" text = router_result.transfer_message or "正在为您转接人工客服,请稍候..." return DialogueResponse( segments=_text_to_segments(text), trace=TraceInfo( mode=ExecutionMode.TRANSFER, request_id=trace.request_id, generation_id=trace.generation_id, intent=router_result.intent, fallback_reason_code=router_result.fallback_reason_code, ), ) def _text_to_segments(text: str) -> list[Segment]: """Convert text to segments.""" paragraphs = [p.strip() for p in text.split("\n\n") if p.strip()] if not paragraphs: paragraphs = [text] return [ Segment(text=p, delay_after=100 if i < len(paragraphs) - 1 else 0) for i, p in enumerate(paragraphs) ]