ai-robot-core/ai-service/app/api/mid/dialogue.py

1724 lines
62 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Dialogue Controller for Mid Platform.
[AC-MARH-01, AC-MARH-02, AC-MARH-03, AC-MARH-04, AC-MARH-05, AC-MARH-06,
AC-MARH-07, AC-MARH-08, AC-MARH-09, AC-MARH-10, AC-MARH-11, AC-MARH-12]
Core endpoint: POST /mid/dialogue/respond
"""
import logging
import time
import uuid
from typing import Annotated, Any
from fastapi import APIRouter, Depends, Request
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.tenant import get_tenant_id
from app.models.mid.schemas import (
DialogueRequest,
DialogueResponse,
ExecutionMode,
Segment,
TraceInfo,
)
from app.services.mid.agent_orchestrator import AgentOrchestrator
from app.services.mid.default_kb_tool_runner import DefaultKbToolRunner
from app.services.mid.feature_flags import FeatureFlagService
from app.services.mid.high_risk_handler import HighRiskHandler
from app.services.mid.interrupt_context_enricher import InterruptContextEnricher
from app.services.mid.kb_search_dynamic_tool import (
KbSearchDynamicConfig,
KbSearchDynamicTool,
)
from app.services.mid.high_risk_check_tool import (
HighRiskCheckConfig,
HighRiskCheckTool,
register_high_risk_check_tool,
)
from app.services.mid.intent_hint_tool import (
IntentHintConfig,
IntentHintTool,
register_intent_hint_tool,
)
from app.services.mid.memory_recall_tool import (
MemoryRecallConfig,
MemoryRecallTool,
register_memory_recall_tool,
)
from app.services.mid.metrics_collector import MetricsCollector
from app.services.mid.output_guardrail_executor import OutputGuardrailExecutor
from app.services.mid.policy_router import IntentMatch, PolicyRouter
from app.services.mid.runtime_observer import RuntimeObserver
from app.services.mid.segment_humanizer import HumanizeConfig, SegmentHumanizer
from app.services.mid.timeout_governor import TimeoutGovernor
from app.services.mid.tool_registry import ToolRegistry
from app.services.mid.trace_logger import TraceLogger
from app.services.prompt.template_service import PromptTemplateService
from app.services.prompt.variable_resolver import VariableResolver
from app.services.intent.clarification import (
ClarificationEngine,
ClarifyReason,
ClarifySessionManager,
ClarifyState,
HybridIntentResult,
IntentCandidate,
T_HIGH,
T_LOW,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/mid", tags=["Mid Platform Dialogue"])
_mid_services: dict[str, Any] = {}
def get_policy_router() -> PolicyRouter:
    """Return the process-wide PolicyRouter, creating it on first use."""
    try:
        return _mid_services["policy_router"]
    except KeyError:
        instance = PolicyRouter()
        _mid_services["policy_router"] = instance
        return instance
def get_high_risk_handler() -> HighRiskHandler:
    """Return the process-wide HighRiskHandler, creating it on first use."""
    try:
        return _mid_services["high_risk_handler"]
    except KeyError:
        instance = HighRiskHandler()
        _mid_services["high_risk_handler"] = instance
        return instance
def get_timeout_governor() -> TimeoutGovernor:
    """Return the process-wide TimeoutGovernor, creating it on first use."""
    try:
        return _mid_services["timeout_governor"]
    except KeyError:
        instance = TimeoutGovernor()
        _mid_services["timeout_governor"] = instance
        return instance
def get_feature_flag_service() -> FeatureFlagService:
    """Return the process-wide FeatureFlagService, creating it on first use."""
    try:
        return _mid_services["feature_flag_service"]
    except KeyError:
        instance = FeatureFlagService()
        _mid_services["feature_flag_service"] = instance
        return instance
def get_trace_logger() -> TraceLogger:
    """Return the process-wide TraceLogger, creating it on first use."""
    try:
        return _mid_services["trace_logger"]
    except KeyError:
        instance = TraceLogger()
        _mid_services["trace_logger"] = instance
        return instance
def get_metrics_collector() -> MetricsCollector:
    """Return the process-wide MetricsCollector, creating it on first use."""
    try:
        return _mid_services["metrics_collector"]
    except KeyError:
        instance = MetricsCollector()
        _mid_services["metrics_collector"] = instance
        return instance
def get_tool_registry() -> ToolRegistry:
    """Return the shared ToolRegistry, building it lazily.

    The registry is wired to the shared TimeoutGovernor singleton so tool
    executions honor the same timeout profile as the rest of the pipeline.
    """
    try:
        return _mid_services["tool_registry"]
    except KeyError:
        registry = ToolRegistry(
            timeout_governor=get_timeout_governor()
        )
        _mid_services["tool_registry"] = registry
        return registry
# Process-wide, one-shot registration guards. Each flag flips to True the
# first time the corresponding ensure_*_registered() helper registers its
# tool, so per-request calls become cheap no-ops afterwards.
# NOTE(review): the first request's AsyncSession is captured by the tool at
# registration time — confirm the tool implementations do not hold it past
# that request's lifetime.
_kb_search_dynamic_registered: bool = False
_intent_hint_registered: bool = False
_high_risk_check_registered: bool = False
_memory_recall_registered: bool = False
_metadata_discovery_registered: bool = False
def ensure_kb_search_dynamic_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-MARH-05] Ensure kb_search_dynamic tool is registered.

    Idempotent: registration happens at most once per process, guarded by
    the module-level ``_kb_search_dynamic_registered`` flag.
    """
    global _kb_search_dynamic_registered
    if _kb_search_dynamic_registered:
        return
    # Imported lazily to avoid a module-load cycle.
    from app.services.mid.kb_search_dynamic_tool import register_kb_search_dynamic_tool
    register_kb_search_dynamic_tool(
        registry=registry,
        session=session,
        timeout_governor=get_timeout_governor(),
        config=KbSearchDynamicConfig(
            enabled=True,
            top_k=5,
            timeout_ms=10000,
            min_score_threshold=0.5,
        ),
    )
    _kb_search_dynamic_registered = True
    logger.info("[AC-MARH-05] kb_search_dynamic tool registered to registry")
def ensure_intent_hint_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-IDMP-02, AC-IDMP-16] Ensure intent_hint tool is registered.

    Idempotent: a module-level flag guarantees at-most-once registration.
    """
    global _intent_hint_registered
    if _intent_hint_registered:
        return
    register_intent_hint_tool(
        registry=registry,
        session=session,
        config=IntentHintConfig(
            enabled=True,
            timeout_ms=500,
            top_n=3,
            low_confidence_threshold=0.3,
        ),
    )
    _intent_hint_registered = True
    logger.info("[AC-IDMP-02] intent_hint tool registered to registry")
def ensure_high_risk_check_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-IDMP-05, AC-IDMP-20] Ensure high_risk_check tool is registered.

    Idempotent: a module-level flag guarantees at-most-once registration.
    """
    global _high_risk_check_registered
    if _high_risk_check_registered:
        return
    register_high_risk_check_tool(
        registry=registry,
        session=session,
        config=HighRiskCheckConfig(
            enabled=True,
            timeout_ms=500,
            default_confidence=0.9,
        ),
    )
    _high_risk_check_registered = True
    logger.info("[AC-IDMP-05] high_risk_check tool registered to registry")
def ensure_memory_recall_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """[AC-IDMP-13] Ensure memory_recall tool is registered.

    Idempotent: a module-level flag guarantees at-most-once registration.
    """
    global _memory_recall_registered
    if _memory_recall_registered:
        return
    register_memory_recall_tool(
        registry=registry,
        session=session,
        timeout_governor=get_timeout_governor(),
        config=MemoryRecallConfig(
            enabled=True,
            timeout_ms=1000,
            max_recent_messages=8,
        ),
    )
    _memory_recall_registered = True
    logger.info("[AC-IDMP-13] memory_recall tool registered to registry")
def ensure_metadata_discovery_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """Ensure metadata_discovery tool is registered.

    Idempotent: a module-level flag guarantees at-most-once registration.
    """
    global _metadata_discovery_registered
    if _metadata_discovery_registered:
        return
    # Imported lazily to avoid a module-load cycle.
    from app.services.mid.metadata_discovery_tool import register_metadata_discovery_tool
    register_metadata_discovery_tool(
        registry=registry,
        session=session,
        timeout_governor=get_timeout_governor(),
    )
    _metadata_discovery_registered = True
    logger.info("[MetadataDiscovery] metadata_discovery tool registered to registry")
def get_output_guardrail_executor() -> OutputGuardrailExecutor:
    """Return the process-wide OutputGuardrailExecutor, creating it on first use."""
    try:
        return _mid_services["output_guardrail_executor"]
    except KeyError:
        instance = OutputGuardrailExecutor()
        _mid_services["output_guardrail_executor"] = instance
        return instance
def get_interrupt_context_enricher() -> InterruptContextEnricher:
    """Return the process-wide InterruptContextEnricher, creating it on first use."""
    try:
        return _mid_services["interrupt_context_enricher"]
    except KeyError:
        instance = InterruptContextEnricher()
        _mid_services["interrupt_context_enricher"] = instance
        return instance
def get_default_kb_tool_runner() -> DefaultKbToolRunner:
    """Return the shared DefaultKbToolRunner, building it lazily.

    The runner shares the singleton TimeoutGovernor so KB tool calls use
    the same timeout profile as the rest of the pipeline.
    """
    try:
        return _mid_services["default_kb_tool_runner"]
    except KeyError:
        runner = DefaultKbToolRunner(
            timeout_governor=get_timeout_governor()
        )
        _mid_services["default_kb_tool_runner"] = runner
        return runner
def get_segment_humanizer() -> SegmentHumanizer:
    """Return the process-wide SegmentHumanizer, creating it on first use."""
    try:
        return _mid_services["segment_humanizer"]
    except KeyError:
        instance = SegmentHumanizer()
        _mid_services["segment_humanizer"] = instance
        return instance
def get_runtime_observer() -> RuntimeObserver:
    """Return the process-wide RuntimeObserver, creating it on first use."""
    try:
        return _mid_services["runtime_observer"]
    except KeyError:
        instance = RuntimeObserver()
        _mid_services["runtime_observer"] = instance
        return instance
def get_clarification_engine() -> ClarificationEngine:
    """Return the shared ClarificationEngine, built with the module-level
    confidence thresholds T_HIGH / T_LOW on first use."""
    try:
        return _mid_services["clarification_engine"]
    except KeyError:
        engine = ClarificationEngine(
            t_high=T_HIGH,
            t_low=T_LOW,
        )
        _mid_services["clarification_engine"] = engine
        return engine
@router.post(
    "/dialogue/respond",
    operation_id="respondDialogue",
    summary="Generate mid platform dialogue response",
    description="""
[AC-MARH-01~12] Core dialogue response endpoint for mid platform.
Returns segments[] with trace info including:
- guardrail_triggered, guardrail_rule_id
- interrupt_consumed
- kb_tool_called, kb_hit
- timeout_profile, segment_stats
""",
)
async def respond_dialogue(
    request: Request,
    dialogue_request: DialogueRequest,
    session: Annotated[AsyncSession, Depends(get_session)],
    policy_router: PolicyRouter = Depends(get_policy_router),
    high_risk_handler: HighRiskHandler = Depends(get_high_risk_handler),
    timeout_governor: TimeoutGovernor = Depends(get_timeout_governor),
    feature_flag_service: FeatureFlagService = Depends(get_feature_flag_service),
    trace_logger: TraceLogger = Depends(get_trace_logger),
    metrics_collector: MetricsCollector = Depends(get_metrics_collector),
    output_guardrail_executor: OutputGuardrailExecutor = Depends(get_output_guardrail_executor),
    interrupt_context_enricher: InterruptContextEnricher = Depends(get_interrupt_context_enricher),
    default_kb_tool_runner: DefaultKbToolRunner = Depends(get_default_kb_tool_runner),
    segment_humanizer: SegmentHumanizer = Depends(get_segment_humanizer),
    runtime_observer: RuntimeObserver = Depends(get_runtime_observer),
    clarification_engine: ClarificationEngine = Depends(get_clarification_engine),
) -> DialogueResponse:
    """
    [AC-MARH-01~12] Generate dialogue response with segments and trace.

    Flow:
    1. Validate request and get tenant context
    2. Start runtime observation
    3. Process interrupted segments (AC-MARH-03/04)
    4. Check feature flags for grayscale/rollback
    5. Detect high-risk scenarios
    6. Route to appropriate execution mode
    7. For Agent mode: call KB tool (AC-MARH-05/06)
    8. Execute output guardrail (AC-MARH-01/02)
    9. Apply segment humanizer (AC-MARH-10/11)
    10. Collect trace and return (AC-MARH-12)

    On any unhandled error, returns a FIXED-mode apology response with
    ``fallback_reason_code="service_error"`` rather than raising.
    """
    start_time = time.time()
    tenant_id = get_tenant_id()
    if not tenant_id:
        # Tenant context is mandatory; abort before any side effects.
        from app.core.exceptions import MissingTenantIdException
        raise MissingTenantIdException()
    # Correlation IDs: request_id ties traces/metrics together;
    # generation_id identifies this generation attempt for interrupts.
    request_id = str(uuid.uuid4())
    generation_id = str(uuid.uuid4())
    logger.info(
        f"[AC-MARH-01] Dialogue request: tenant={tenant_id}, "
        f"session={dialogue_request.session_id}, request_id={request_id}, "
        f"user_message={dialogue_request.user_message[:100] if dialogue_request.user_message else 'None'}..."
    )
    runtime_ctx = runtime_observer.start_observation(
        tenant_id=tenant_id,
        session_id=dialogue_request.session_id,
        request_id=request_id,
        generation_id=generation_id,
    )
    trace = trace_logger.start_trace(
        tenant_id=tenant_id,
        session_id=dialogue_request.session_id,
        request_id=request_id,
        generation_id=generation_id,
    )
    metrics_collector.start_session(dialogue_request.session_id)
    # Tool registration is idempotent (one-shot per process); cheap here.
    tool_registry = get_tool_registry()
    ensure_kb_search_dynamic_registered(tool_registry, session)
    ensure_intent_hint_registered(tool_registry, session)
    ensure_high_risk_check_registered(tool_registry, session)
    ensure_memory_recall_registered(tool_registry, session)
    ensure_metadata_discovery_registered(tool_registry, session)
    try:
        # [AC-MARH-03/04] Fold any user-interrupted segments into context.
        interrupt_ctx = interrupt_context_enricher.enrich(
            dialogue_request.interrupted_segments,
            generation_id,
        )
        runtime_observer.record_interrupt(request_id, interrupt_ctx.consumed)
        # Request-supplied flags win over the service-resolved flags.
        feature_flags = dialogue_request.feature_flags or feature_flag_service.get_flags(
            dialogue_request.session_id
        )
        if feature_flags.rollback_to_legacy:
            logger.info(f"[AC-MARH-17] Rollback to legacy for session: {dialogue_request.session_id}")
            return await _handle_legacy_response(
                tenant_id=tenant_id,
                request=dialogue_request,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
            )
        # [AC-IDMP-05/20] High-risk detection takes priority over routing.
        high_risk_check_tool = HighRiskCheckTool(
            session=session,
            config=HighRiskCheckConfig(enabled=True, timeout_ms=500),
        )
        high_risk_result = await high_risk_check_tool.execute(
            message=dialogue_request.user_message,
            tenant_id=tenant_id,
        )
        if high_risk_result.duration_ms > 0:
            # duration_ms > 0 means the tool actually ran; record its trace.
            hr_trace = high_risk_check_tool.create_trace(high_risk_result, tenant_id)
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[hr_trace],
            )
        logger.info(
            f"[AC-IDMP-05, AC-IDMP-20] High risk check result: "
            f"matched={high_risk_result.matched}, scenario={high_risk_result.risk_scenario}, "
            f"duration_ms={high_risk_result.duration_ms}"
        )
        if high_risk_result.matched and high_risk_result.risk_scenario:
            logger.info(
                f"[AC-IDMP-05] High-risk matched from tool: {high_risk_result.risk_scenario.value}"
            )
            return await _handle_high_risk_check_response(
                tenant_id=tenant_id,
                request=dialogue_request,
                high_risk_result=high_risk_result,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
                session=session,
            )
        # [AC-IDMP-02] Fast intent hint to bias mode routing.
        intent_hint_tool = IntentHintTool(
            session=session,
            config=IntentHintConfig(enabled=True, timeout_ms=500),
        )
        intent_hint = await intent_hint_tool.execute(
            message=dialogue_request.user_message,
            tenant_id=tenant_id,
            history=[h.model_dump() for h in dialogue_request.history] if dialogue_request.history else None,
        )
        if intent_hint.duration_ms > 0:
            hint_trace = intent_hint_tool.create_trace(intent_hint)
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[hint_trace],
            )
        logger.info(
            f"[AC-IDMP-02] Intent hint result: intent={intent_hint.intent}, "
            f"confidence={intent_hint.confidence}, suggested_mode={intent_hint.suggested_mode}"
        )
        # Rule-based intent match feeds the policy router alongside the hint.
        intent_match = await _match_intent(tenant_id, dialogue_request, session)
        router_result = policy_router.route(
            user_message=dialogue_request.user_message,
            session_mode="BOT_ACTIVE",
            feature_flags=feature_flags,
            intent_match=intent_match,
            intent_hint=intent_hint,
        )
        runtime_observer.update_mode(request_id, router_result.mode, router_result.intent)
        runtime_observer.record_timeout_profile(request_id, timeout_governor.profile)
        trace_logger.update_trace(
            request_id=request_id,
            mode=router_result.mode,
            intent=router_result.intent,
            fallback_reason_code=router_result.fallback_reason_code,
        )
        # Dispatch to the execution mode chosen by the router.
        if router_result.mode == ExecutionMode.AGENT:
            response = await _execute_agent_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                request_id=request_id,
                trace=trace,
                trace_logger=trace_logger,
                timeout_governor=timeout_governor,
                metrics_collector=metrics_collector,
                default_kb_tool_runner=default_kb_tool_runner,
                runtime_observer=runtime_observer,
                interrupt_ctx=interrupt_ctx,
                start_time=start_time,
                session=session,
                tool_registry=tool_registry,
            )
        elif router_result.mode == ExecutionMode.MICRO_FLOW:
            response = await _execute_micro_flow_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                router_result=router_result,
                trace=trace,
                trace_logger=trace_logger,
                session=session,
                start_time=start_time,
            )
        elif router_result.mode == ExecutionMode.FIXED:
            response = await _execute_fixed_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                router_result=router_result,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
            )
        else:
            # Any remaining mode is treated as a transfer to a human agent.
            response = await _execute_transfer_mode(
                tenant_id=tenant_id,
                request=dialogue_request,
                router_result=router_result,
                trace=trace,
                trace_logger=trace_logger,
                start_time=start_time,
            )
        # [AC-MARH-01/02] Output guardrail runs over whatever mode produced.
        filtered_segments, guardrail_result = await output_guardrail_executor.filter_segments(
            response.segments, tenant_id
        )
        runtime_observer.record_guardrail(
            request_id, guardrail_result.triggered, guardrail_result.rule_id
        )
        # [AC-MARH-10/11] Build humanizer config from the request, with
        # per-field defaults for unset values.
        humanize_config = None
        if dialogue_request.humanize_config:
            hc = dialogue_request.humanize_config
            humanize_config = HumanizeConfig(
                # BUG FIX: was `hc.enabled or True`, which always evaluated
                # to True and made it impossible to disable humanization.
                # Default to True only when the field is unset (None).
                enabled=True if hc.enabled is None else hc.enabled,
                min_delay_ms=hc.min_delay_ms or 50,
                max_delay_ms=hc.max_delay_ms or 500,
                length_bucket_strategy=hc.length_bucket_strategy or "simple",
            )
        # Humanizer re-segments the concatenated filtered text.
        final_segments, segment_stats = segment_humanizer.humanize(
            "\n\n".join(s.text for s in filtered_segments),
            humanize_config,
        )
        runtime_observer.record_segment_stats(request_id, segment_stats)
        final_trace = runtime_observer.end_observation(request_id)
        final_trace.segment_stats = segment_stats
        final_trace.guardrail_triggered = guardrail_result.triggered
        final_trace.guardrail_rule_id = guardrail_result.rule_id
        latency_ms = int((time.time() - start_time) * 1000)
        metrics_collector.record_turn(
            session_id=dialogue_request.session_id,
            tenant_id=tenant_id,
            latency_ms=latency_ms,
            task_completed=True,
        )
        # [AC-MARH-12] Close out the audit trail for this turn.
        audit = trace_logger.end_trace(
            request_id=request_id,
            tenant_id=tenant_id,
            session_id=dialogue_request.session_id,
            latency_ms=latency_ms,
        )
        logger.info(
            f"[AC-MARH-12] Audit record: request_id={request_id}, "
            f"mode={final_trace.mode.value}, latency_ms={latency_ms}, "
            f"guardrail={guardrail_result.triggered}, kb_hit={final_trace.kb_hit}"
        )
        # [AC-IDMP-14] Best-effort memory update; never fails the response.
        if dialogue_request.user_id:
            try:
                from app.services.mid.memory_adapter import MemoryAdapter
                from app.services.mid.memory_summary_generator import MemorySummaryGenerator
                memory_adapter = MemoryAdapter(session=session)
                summary_generator = MemorySummaryGenerator()
                history_messages = [
                    {"role": h.role, "content": h.content}
                    for h in (dialogue_request.history or [])
                ]
                assistant_reply = "\n".join(s.text for s in final_segments)
                update_messages = history_messages + [
                    {"role": "user", "content": dialogue_request.user_message},
                    {"role": "assistant", "content": assistant_reply},
                ]
                await memory_adapter.update_with_summary_generation(
                    user_id=dialogue_request.user_id,
                    session_id=dialogue_request.session_id,
                    messages=update_messages,
                    tenant_id=tenant_id,
                    summary_generator=summary_generator,
                    recent_turns=8,
                )
            except Exception as e:
                logger.warning(f"[AC-IDMP-14] Memory update trigger failed: {e}")
        return DialogueResponse(
            segments=final_segments,
            trace=final_trace,
        )
    except Exception as e:
        # [AC-IDMP-06] Last-resort handler: log full context, close the
        # trace, and degrade to a fixed apology instead of raising.
        latency_ms = int((time.time() - start_time) * 1000)
        import traceback
        logger.error(
            f"[AC-IDMP-06] Dialogue error: {e}\n"
            f"Exception type: {type(e).__name__}\n"
            f"Request details: session_id={dialogue_request.session_id}, "
            f"user_message={dialogue_request.user_message[:200] if dialogue_request.user_message else 'None'}, "
            f"scene={dialogue_request.scene}, user_id={dialogue_request.user_id}\n"
            f"Traceback:\n{traceback.format_exc()}"
        )
        trace_logger.update_trace(
            request_id=request_id,
            mode=ExecutionMode.FIXED,
            fallback_reason_code=f"error: {str(e)[:50]}",
        )
        trace_logger.end_trace(
            request_id=request_id,
            tenant_id=tenant_id,
            session_id=dialogue_request.session_id,
            latency_ms=latency_ms,
        )
        return DialogueResponse(
            segments=[Segment(
                text="抱歉,服务暂时不可用,请稍后重试或联系人工客服。",
                delay_after=0,
            )],
            trace=TraceInfo(
                mode=ExecutionMode.FIXED,
                request_id=request_id,
                generation_id=generation_id,
                fallback_reason_code="service_error",
            ),
        )
async def _match_intent(
    tenant_id: str,
    request: DialogueRequest,
    session: AsyncSession,
) -> IntentMatch | None:
    """Match the user message against the tenant's enabled intent rules.

    Returns an IntentMatch on a rule hit; returns None when the tenant has
    no rules, nothing matches, or matching raises (failures are logged and
    swallowed so routing can proceed without an intent).
    """
    try:
        # Lazy imports keep module load free of the intent package.
        from app.services.intent.router import IntentRouter
        from app.services.intent.rule_service import IntentRuleService
        enabled_rules = await IntentRuleService(session).get_enabled_rules_for_matching(tenant_id)
        if not enabled_rules:
            return None
        matched = IntentRouter().match(request.user_message, enabled_rules)
        if not matched:
            return None
        rule = matched.rule
        return IntentMatch(
            intent_id=str(rule.id),
            intent_name=rule.name,
            confidence=1.0,  # rule hits are treated as fully confident
            response_type=rule.response_type,
            target_kb_ids=rule.target_kb_ids,
            flow_id=str(rule.flow_id) if rule.flow_id else None,
            fixed_reply=rule.fixed_reply,
            transfer_message=rule.transfer_message,
        )
    except Exception as e:
        logger.warning(f"[AC-IDMP-02] Intent match failed: {e}")
        return None
async def _match_intent_hybrid(
    tenant_id: str,
    request: DialogueRequest,
    session: AsyncSession,
    clarification_engine: ClarificationEngine,
) -> HybridIntentResult:
    """
    [AC-CLARIFY] Hybrid intent matching with clarification support.
    Returns HybridIntentResult with unified confidence and candidates.

    Fallback ladder:
    1. Hybrid (fusion) matching via IntentRouter.match_hybrid.
    2. On hybrid failure, plain rule matching with a confidence computed
       from rule_score=1.0 only.
    3. On any failure or no match, an empty result (intent=None, 0.0).
    Never raises: every failure path returns an empty HybridIntentResult.
    """
    try:
        # Lazy imports to avoid loading the intent package at module import.
        from app.services.intent.router import IntentRouter
        from app.services.intent.rule_service import IntentRuleService
        rule_service = IntentRuleService(session)
        rules = await rule_service.get_enabled_rules_for_matching(tenant_id)
        if not rules:
            # No rules configured for this tenant: nothing to match against.
            return HybridIntentResult(
                intent=None,
                confidence=0.0,
                candidates=[],
            )
        router = IntentRouter()
        try:
            # Preferred path: fused rule + semantic (+ LLM) matching.
            fusion_result = await router.match_hybrid(
                message=request.user_message,
                rules=rules,
                tenant_id=tenant_id,
            )
            hybrid_result = HybridIntentResult.from_fusion_result(fusion_result)
            logger.info(
                f"[AC-CLARIFY] Hybrid intent match: "
                f"intent={hybrid_result.intent.intent_name if hybrid_result.intent else None}, "
                f"confidence={hybrid_result.confidence:.3f}, "
                f"need_clarify={hybrid_result.need_clarify}"
            )
            return hybrid_result
        except Exception as e:
            # Hybrid path failed: degrade to rule-only matching.
            logger.warning(f"[AC-CLARIFY] Hybrid match failed, fallback to rule: {e}")
            result = router.match(request.user_message, rules)
            if result:
                # A rule hit counts as rule_score=1.0; the other signals are
                # unavailable in this fallback, so they contribute 0.0.
                confidence = clarification_engine.compute_confidence(
                    rule_score=1.0,
                    semantic_score=0.0,
                    llm_score=0.0,
                )
                candidate = IntentCandidate(
                    intent_id=str(result.rule.id),
                    intent_name=result.rule.name,
                    confidence=confidence,
                    response_type=result.rule.response_type,
                    target_kb_ids=result.rule.target_kb_ids,
                    flow_id=str(result.rule.flow_id) if result.rule.flow_id else None,
                    fixed_reply=result.rule.fixed_reply,
                    transfer_message=result.rule.transfer_message,
                )
                return HybridIntentResult(
                    intent=candidate,
                    confidence=confidence,
                    candidates=[candidate],
                )
            # Rule fallback also found nothing.
            return HybridIntentResult(
                intent=None,
                confidence=0.0,
                candidates=[],
            )
    except Exception as e:
        # Outer guard: rule loading or anything else failed entirely.
        logger.warning(f"[AC-CLARIFY] Intent match failed: {e}")
        return HybridIntentResult(
            intent=None,
            confidence=0.0,
            candidates=[],
        )
async def _handle_clarification(
    tenant_id: str,
    request: DialogueRequest,
    hybrid_result: HybridIntentResult,
    clarification_engine: ClarificationEngine,
    trace: TraceInfo,
    session: AsyncSession,
    required_slots: list[str] | None = None,
    filled_slots: dict[str, Any] | None = None,
) -> DialogueResponse | None:
    """
    [AC-CLARIFY] Handle clarification logic.
    Returns DialogueResponse if clarification is needed, None otherwise.

    Two phases:
    1. An active clarify session (with retries left) means the incoming
       message is the user's answer to the previous clarify prompt.
    2. Otherwise, decide whether this turn should start a new clarification.
    """
    existing_state = ClarifySessionManager.get_session(request.session_id)
    if existing_state and not existing_state.is_max_retry():
        logger.info(
            f"[AC-CLARIFY] Processing clarify response: "
            f"session={request.session_id}, retry={existing_state.retry_count}"
        )
        new_result = clarification_engine.process_clarify_response(
            user_message=request.user_message,
            state=existing_state,
        )
        if not new_result.need_clarify:
            # Ambiguity resolved: clear the session so normal routing resumes.
            ClarifySessionManager.clear_session(request.session_id)
            if new_result.intent:
                # Resolved to a concrete intent: no clarify response needed.
                return None
        # Still unresolved (or resolved without an intent): re-ask with the
        # existing state. NOTE(review): this path is also taken when
        # need_clarify is False but no intent was produced — confirm that
        # re-prompting after clearing the session is intended.
        clarify_prompt = clarification_engine.generate_clarify_prompt(existing_state)
        return DialogueResponse(
            segments=[Segment(text=clarify_prompt, delay_after=0)],
            trace=TraceInfo(
                mode=ExecutionMode.FIXED,
                request_id=trace.request_id,
                generation_id=trace.generation_id,
                fallback_reason_code=f"clarify_retry_{existing_state.retry_count}",
                intent=existing_state.candidates[0].intent_name if existing_state.candidates else None,
            ),
        )
    # Phase 2: no usable clarify session — decide whether to start one.
    should_clarify, clarify_state = clarification_engine.should_trigger_clarify(
        result=hybrid_result,
        required_slots=required_slots,
        filled_slots=filled_slots,
    )
    if not should_clarify or not clarify_state:
        return None
    logger.info(
        f"[AC-CLARIFY] Clarification triggered: "
        f"session={request.session_id}, reason={clarify_state.reason}, "
        f"confidence={hybrid_result.confidence:.3f}"
    )
    # Persist the new clarify session, then prompt the user.
    ClarifySessionManager.set_session(request.session_id, clarify_state)
    clarify_prompt = clarification_engine.generate_clarify_prompt(clarify_state)
    return DialogueResponse(
        segments=[Segment(text=clarify_prompt, delay_after=0)],
        trace=TraceInfo(
            mode=ExecutionMode.FIXED,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            fallback_reason_code=f"clarify_{clarify_state.reason.value}",
            intent=hybrid_result.intent.intent_name if hybrid_result.intent else None,
        ),
    )
async def _check_hard_block(
    tenant_id: str,
    request: DialogueRequest,
    hybrid_result: HybridIntentResult,
    clarification_engine: ClarificationEngine,
    trace: TraceInfo,
    required_slots: list[str] | None = None,
    filled_slots: dict[str, Any] | None = None,
) -> DialogueResponse | None:
    """
    [AC-CLARIFY] Check hard block conditions.
    Hard blocks:
    1. confidence < T_high: block entering new flow
    2. required_slots missing: block flow progression
    Returns DialogueResponse if blocked, None otherwise.
    """
    is_blocked, block_reason = clarification_engine.check_hard_block(
        result=hybrid_result,
        required_slots=required_slots,
        filled_slots=filled_slots,
    )
    if not is_blocked:
        return None
    logger.info(
        f"[AC-CLARIFY] Hard block triggered: "
        f"session={request.session_id}, reason={block_reason}, "
        f"confidence={hybrid_result.confidence:.3f}"
    )
    # A hard block opens a clarify session; for MISSING_SLOT the first
    # required slot becomes the one we ask about.
    clarify_state = ClarifyState(
        reason=block_reason,
        asked_slot=required_slots[0] if required_slots and block_reason == ClarifyReason.MISSING_SLOT else None,
        candidates=hybrid_result.candidates,
    )
    ClarifySessionManager.set_session(request.session_id, clarify_state)
    # NOTE(review): reaches into the engine's private ``_metrics`` — consider
    # exposing a public record method on ClarificationEngine.
    clarification_engine._metrics.record_clarify_trigger()
    clarify_prompt = clarification_engine.generate_clarify_prompt(clarify_state)
    return DialogueResponse(
        segments=[Segment(text=clarify_prompt, delay_after=0)],
        trace=TraceInfo(
            mode=ExecutionMode.FIXED,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            fallback_reason_code=f"hard_block_{block_reason.value}",
            intent=hybrid_result.intent.intent_name if hybrid_result.intent else None,
        ),
    )
async def _handle_legacy_response(
    tenant_id: str,
    request: DialogueRequest,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
) -> DialogueResponse:
    """[AC-MARH-17] Handle rollback to legacy pipeline.

    Returns a fixed placeholder response tagged with the
    ``rollback_to_legacy`` fallback reason. Parameters other than ``trace``
    are currently unused but kept so the signature mirrors the other mode
    handlers. (A previously computed, never-used ``latency_ms`` local was
    removed.)
    """
    return DialogueResponse(
        segments=[Segment(
            text="正在使用传统模式处理您的请求...",
            delay_after=0,
        )],
        trace=TraceInfo(
            mode=ExecutionMode.FIXED,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            fallback_reason_code="rollback_to_legacy",
        ),
    )
async def _handle_high_risk_check_response(
    tenant_id: str,
    request: DialogueRequest,
    high_risk_result: Any,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
    session: AsyncSession,
) -> DialogueResponse:
    """
    [AC-IDMP-05, AC-IDMP-20] Handle high-risk scenario from high_risk_check tool.
    High risk takes priority over ordinary intent routing.

    Behavior:
    - TRANSFER recommendation: return a transfer message (scenario-specific
      wording for complaint escalation and refunds).
    - Otherwise: MICRO_FLOW placeholder; if the matched policy carries a
      flow_id it is looked up (best-effort) before responding.

    Cleanups vs. the previous version: removed an unused ``latency_ms``
    local and a redundant local ``import uuid`` (uuid is imported at module
    level).
    """
    from app.models.mid.schemas import HighRiskCheckResult
    # The caller may pass a raw dict (e.g. from a tool payload); normalize.
    if not isinstance(high_risk_result, HighRiskCheckResult):
        high_risk_result = HighRiskCheckResult(**high_risk_result)
    recommended_mode = high_risk_result.recommended_mode or ExecutionMode.MICRO_FLOW
    risk_scenario = high_risk_result.risk_scenario
    trace_logger.update_trace(
        request_id=trace.request_id or "",
        mode=recommended_mode,
        fallback_reason_code=f"high_risk_{risk_scenario.value if risk_scenario else 'unknown'}",
    )
    if recommended_mode == ExecutionMode.TRANSFER:
        # Scenario-specific transfer wording, with a generic default.
        transfer_msg = "正在为您转接人工客服..."
        if risk_scenario:
            if risk_scenario.value == "complaint_escalation":
                transfer_msg = "检测到您可能需要投诉处理,正在为您转接人工客服..."
            elif risk_scenario.value == "refund":
                transfer_msg = "您的退款请求需要人工处理,正在为您转接..."
        return DialogueResponse(
            segments=[Segment(
                text=transfer_msg,
                delay_after=0,
            )],
            trace=TraceInfo(
                mode=ExecutionMode.TRANSFER,
                request_id=trace.request_id,
                generation_id=trace.generation_id,
                high_risk_policy_set=[risk_scenario] if risk_scenario else None,
                fallback_reason_code=high_risk_result.rule_id,
            ),
        )
    if high_risk_result.rule_id:
        # Best-effort: if the matched policy names a flow, acknowledge it.
        # Lookup failures are logged and fall through to the default reply.
        try:
            from sqlalchemy import select
            from app.models.entities import HighRiskPolicy
            stmt = select(HighRiskPolicy).where(
                HighRiskPolicy.id == uuid.UUID(high_risk_result.rule_id)
            )
            result = await session.execute(stmt)
            policy = result.scalar_one_or_none()
            if policy and policy.flow_id:
                return DialogueResponse(
                    segments=[Segment(
                        text="检测到您的请求需要特殊处理,正在为您安排...",
                        delay_after=0,
                    )],
                    trace=TraceInfo(
                        mode=ExecutionMode.MICRO_FLOW,
                        request_id=trace.request_id,
                        generation_id=trace.generation_id,
                        high_risk_policy_set=[risk_scenario] if risk_scenario else None,
                        fallback_reason_code=high_risk_result.rule_id,
                    ),
                )
        except Exception as e:
            logger.warning(f"[AC-IDMP-05] Failed to load high risk policy: {e}")
    # Default: MICRO_FLOW placeholder response.
    return DialogueResponse(
        segments=[Segment(
            text="检测到您的请求需要特殊处理,正在为您安排...",
            delay_after=0,
        )],
        trace=TraceInfo(
            mode=ExecutionMode.MICRO_FLOW,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            high_risk_policy_set=[risk_scenario] if risk_scenario else None,
            fallback_reason_code=high_risk_result.rule_id or f"high_risk_{risk_scenario.value if risk_scenario else 'unknown'}",
        ),
    )
async def _handle_high_risk_response(
    tenant_id: str,
    request: DialogueRequest,
    high_risk_match: Any,
    high_risk_handler: HighRiskHandler,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
) -> DialogueResponse:
    """Handle high-risk scenario response.

    Delegates mode selection to ``high_risk_handler``: TRANSFER yields a
    transfer message (the handler's own message when provided), anything
    else yields a MICRO_FLOW placeholder. (An unused ``latency_ms`` local
    was removed; ``tenant_id``/``request``/``start_time`` are kept for
    signature parity with the other handlers.)
    """
    router_result = high_risk_handler.handle(high_risk_match)
    trace_logger.update_trace(
        request_id=trace.request_id or "",
        mode=router_result.mode,
        fallback_reason_code=router_result.fallback_reason_code,
    )
    if router_result.mode == ExecutionMode.TRANSFER:
        return DialogueResponse(
            segments=[Segment(
                text=router_result.transfer_message or "正在为您转接人工客服...",
                delay_after=0,
            )],
            trace=TraceInfo(
                mode=ExecutionMode.TRANSFER,
                request_id=trace.request_id,
                generation_id=trace.generation_id,
                high_risk_policy_set=[high_risk_match.scenario],
                fallback_reason_code=router_result.fallback_reason_code,
            ),
        )
    return DialogueResponse(
        segments=[Segment(
            text="检测到您的请求需要特殊处理,正在为您安排...",
            delay_after=0,
        )],
        trace=TraceInfo(
            mode=ExecutionMode.MICRO_FLOW,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            high_risk_policy_set=[high_risk_match.scenario],
            fallback_reason_code=router_result.fallback_reason_code,
        ),
    )
async def _execute_agent_mode(
    tenant_id: str,
    request: DialogueRequest,
    request_id: str,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    timeout_governor: TimeoutGovernor,
    metrics_collector: MetricsCollector,
    default_kb_tool_runner: DefaultKbToolRunner,
    runtime_observer: RuntimeObserver,
    interrupt_ctx: Any = None,
    start_time: float = 0,
    session: AsyncSession | None = None,
    tool_registry: ToolRegistry | None = None,
) -> DialogueResponse:
    """
    [AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-IDMP-13, AC-MRS-SLOT-META-03]
    Execute agent mode with ReAct loop, KB tool, memory recall, and slot state aggregation.

    High-level pipeline (in order):
      1. Acquire an LLM client; failure degrades to ``llm_client = None``
         ([AC-MARH-07]) and the orchestrator must cope downstream.
      2. Build ``base_context`` from history, scene and any consumed
         interrupt context ([AC-MARH-03]).
      3. Determine the execution track (flow runtime vs. generic Q&A,
         [AC-IDMP-02]); on the flow track, load the scene slot bundle
         ([AC-SCENE-SLOT-02]).
      4. Recall user memory ([AC-IDMP-13]); flow track only: aggregate and
         auto-extract slots ([AC-MRS-SLOT-META-03], [AC-MRS-SLOT-EXTRACT-01]).
      5. Retrieve KB context via the dynamic KB tool when both ``session``
         and ``tool_registry`` exist, otherwise via the default runner.
         On the flow track a MISSING_REQUIRED_SLOTS fallback short-circuits
         into an ask-back response (early return).
      6. Run the ReAct orchestrator, merge tool-call traces (the KB tool's
         internal trace wins over the orchestrator's), assemble the response.

    Args:
        tenant_id: Tenant scope for every downstream lookup.
        request: Incoming dialogue request (message, history, scene, ids).
        request_id: Correlation id for incremental trace updates.
        trace: Pre-built trace skeleton; its request/generation ids are reused.
        trace_logger: Sink for incremental trace updates.
        timeout_governor: Shared time-budget manager handed to tools.
        metrics_collector: Not read in this function — TODO confirm whether
            metrics should be recorded on this path.
        default_kb_tool_runner: Fallback KB retrieval when no tool registry.
        runtime_observer: Records KB and ReAct runtime observations.
        interrupt_ctx: Optional interrupt context; only used once ``consumed``.
        start_time: Request start timestamp; not read in this function.
        session: Optional DB session; gates flow/memory/slot/dynamic-KB features.
        tool_registry: Optional tool registry; gates the dynamic KB tool.

    Returns:
        DialogueResponse with reply segments and a populated TraceInfo.
    """
    from app.services.llm.factory import get_llm_config_manager
    from app.services.mid.slot_state_aggregator import SlotStateAggregator
    logger.info(
        f"[DEBUG-AGENT] Starting _execute_agent_mode: tenant={tenant_id}, "
        f"session={request.session_id}, scene={request.scene}, user_id={request.user_id}, "
        f"user_message={request.user_message[:100] if request.user_message else 'None'}..."
    )
    # [AC-MARH-07] LLM acquisition is best-effort; a failure is logged and
    # the pipeline continues with llm_client=None.
    try:
        llm_manager = get_llm_config_manager()
        llm_client = llm_manager.get_client()
        logger.info(f"[DEBUG-AGENT] LLM client obtained successfully")
    except Exception as e:
        logger.warning(f"[AC-MARH-07] Failed to get LLM client: {e}")
        llm_client = None
    # Seed the agent context with serialized history (if any) and the scene.
    base_context = {"history": [h.model_dump() for h in request.history]} if request.history else {}
    if request.scene:
        base_context["scene"] = request.scene
    # [AC-MARH-03] Enrich the context with interrupt data once it was consumed.
    if interrupt_ctx and interrupt_ctx.consumed:
        base_context["interrupted_content"] = interrupt_ctx.interrupted_content
        base_context["interrupted_segment_ids"] = interrupt_ctx.interrupted_segment_ids
        logger.info(
            f"[AC-MARH-03] Agent context enriched with interrupt: "
            f"{len(interrupt_ctx.interrupted_content or '')} chars"
        )
    # [AC-MRS-SLOT-META-03] Defaults for the slot-state aggregation inputs.
    slot_state = None
    memory_context = ""
    memory_missing_slots: list[str] = []
    scene_slot_context = None
    flow_status = None
    is_flow_runtime = False
    # [AC-IDMP-02] Track detection: an "active" flow status means flow runtime.
    if session:
        from app.services.flow.engine import FlowEngine
        flow_engine = FlowEngine(session)
        flow_status = await flow_engine.get_flow_status(tenant_id, request.session_id)
        is_flow_runtime = bool(flow_status and flow_status.get("status") == "active")
        logger.info(
            f"[AC-IDMP-02] 运行轨道判定: session={request.session_id}, "
            f"is_flow_runtime={is_flow_runtime}, flow_id={flow_status.get('flow_id') if flow_status else None}, "
            f"current_step={flow_status.get('current_step') if flow_status else None}"
        )
    # [AC-SCENE-SLOT-02] Load the scene slot bundle (flow runtime only).
    if is_flow_runtime and session and request.scene:
        from app.services.mid.scene_slot_bundle_loader import SceneSlotBundleLoader
        scene_loader = SceneSlotBundleLoader(session)
        scene_slot_context = await scene_loader.load_scene_context(
            tenant_id=tenant_id,
            scene_key=request.scene,
        )
        if scene_slot_context:
            logger.info(
                f"[AC-SCENE-SLOT-02] 场景槽位包加载成功: "
                f"scene={request.scene}, required={len(scene_slot_context.required_slots)}, "
                f"optional={len(scene_slot_context.optional_slots)}, "
                f"threshold={scene_slot_context.completion_threshold}"
            )
            # Expose only slot keys/names to the agent context, not full configs.
            base_context["scene_slot_context"] = {
                "scene_key": scene_slot_context.scene_key,
                "scene_name": scene_slot_context.scene_name,
                "required_slots": scene_slot_context.get_required_slot_keys(),
                "optional_slots": scene_slot_context.get_optional_slot_keys(),
            }
    # [AC-IDMP-13] Memory recall requires both a DB session and a user id.
    if session and request.user_id:
        memory_recall_tool = MemoryRecallTool(
            session=session,
            timeout_governor=timeout_governor,
            config=MemoryRecallConfig(
                enabled=True,
                timeout_ms=1000,
                max_recent_messages=8,
            ),
        )
        memory_result = await memory_recall_tool.execute(
            tenant_id=tenant_id,
            user_id=request.user_id,
            session_id=request.session_id,
        )
        # duration_ms > 0 means the tool actually ran; record its trace.
        if memory_result.duration_ms > 0:
            memory_trace = memory_recall_tool.create_trace(memory_result, tenant_id)
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[memory_trace],
            )
        memory_context = memory_result.get_context_for_prompt()
        memory_missing_slots = memory_result.missing_slots
        if memory_context:
            base_context["memory_context"] = memory_context
            logger.info(
                f"[AC-IDMP-13] Memory recall succeeded: "
                f"profile={len(memory_result.profile)}, facts={len(memory_result.facts)}, "
                f"slots={len(memory_result.slots)}, missing_slots={len(memory_missing_slots)}, "
                f"duration_ms={memory_result.duration_ms}"
            )
        elif memory_result.fallback_reason_code:
            logger.warning(
                f"[AC-IDMP-13] Memory recall fallback: reason={memory_result.fallback_reason_code}"
            )
        # [AC-MRS-SLOT-META-03] Slot aggregation/extraction runs only on the flow track.
        if is_flow_runtime:
            slot_aggregator = SlotStateAggregator(
                session=session,
                tenant_id=tenant_id,
                session_id=request.session_id,
            )
            slot_state = await slot_aggregator.aggregate(
                memory_slots=memory_result.slots,
                current_input_slots=None,  # could be parsed from the request in the future
                context=base_context,
                scene_slot_context=scene_slot_context,  # [AC-SCENE-SLOT-02] pass scene slot context
            )
            logger.info(
                f"[AC-MRS-SLOT-META-03] 流程态槽位聚合完成: "
                f"filled={len(slot_state.filled_slots)}, "
                f"missing={len(slot_state.missing_required_slots)}, "
                f"mappings={slot_state.slot_to_field_map}"
            )
            # [AC-MRS-SLOT-EXTRACT-01] Auto-extract slots from the user message.
            if slot_state and slot_state.missing_required_slots:
                from app.services.mid.slot_extraction_integration import SlotExtractionIntegration
                extraction_integration = SlotExtractionIntegration(
                    session=session,
                    tenant_id=tenant_id,
                    session_id=request.session_id,
                )
                extraction_result = await extraction_integration.extract_and_fill(
                    user_input=request.user_message,
                    slot_state=slot_state,
                )
                if extraction_result.extracted_slots:
                    # Re-aggregate with the freshly extracted slots.
                    # NOTE(review): scene_slot_context is not forwarded here,
                    # unlike the first aggregate() call — confirm intended.
                    slot_state = await slot_aggregator.aggregate(
                        memory_slots=memory_result.slots,
                        current_input_slots=extraction_result.extracted_slots,
                        context=base_context,
                    )
                    logger.info(
                        f"[AC-MRS-SLOT-EXTRACT-01] 流程态自动提槽完成: "
                        f"extracted={list(extraction_result.extracted_slots.keys())}, "
                        f"time_ms={extraction_result.total_execution_time_ms:.2f}"
                    )
        else:
            logger.info(
                f"[AC-MRS-SLOT-META-03] 当前为通用问答轨道,跳过槽位聚合与缺槽追问: "
                f"session={request.session_id}"
            )
    # KB retrieval bookkeeping defaults (overwritten by whichever branch runs).
    kb_hits = []
    kb_success = False
    kb_fallback_reason = None
    kb_applied_filter = {}
    kb_missing_slots = []
    kb_dynamic_result = None
    step_kb_binding_trace: dict[str, Any] | None = None
    # [Step-KB-Binding] Resolve the current flow step's KB config (flow runtime only).
    step_kb_config = None
    if is_flow_runtime and session and flow_status:
        current_step_no = flow_status.get("current_step")
        flow_id = flow_status.get("flow_id")
        if flow_id and current_step_no:
            from app.models.entities import ScriptFlow
            from sqlalchemy import select
            stmt = select(ScriptFlow).where(ScriptFlow.id == flow_id)
            result = await session.execute(stmt)
            flow = result.scalar_one_or_none()
            if flow and flow.steps:
                # Step numbers are 1-based in flow_status; convert to list index.
                step_idx = current_step_no - 1
                if 0 <= step_idx < len(flow.steps):
                    current_step = flow.steps[step_idx]
                    # Build the StepKbConfig from the step definition.
                    from app.services.mid.kb_search_dynamic_tool import StepKbConfig
                    step_kb_config = StepKbConfig(
                        allowed_kb_ids=current_step.get("allowed_kb_ids"),
                        preferred_kb_ids=current_step.get("preferred_kb_ids"),
                        kb_query_hint=current_step.get("kb_query_hint"),
                        max_kb_calls=current_step.get("max_kb_calls_per_step", 1),
                        step_id=f"{flow_id}_step_{current_step_no}",
                    )
                    step_kb_binding_trace = {
                        "flow_id": str(flow_id),
                        "flow_name": flow_status.get("flow_name"),
                        "current_step": current_step_no,
                        "step_id": step_kb_config.step_id,
                        "allowed_kb_ids": step_kb_config.allowed_kb_ids,
                        "preferred_kb_ids": step_kb_config.preferred_kb_ids,
                    }
                    logger.info(
                        f"[Step-KB-Binding] 步骤知识库配置加载成功: "
                        f"flow={flow_status.get('flow_name')}, step={current_step_no}, "
                        f"allowed_kb_ids={step_kb_config.allowed_kb_ids}"
                    )
    if session and tool_registry:
        kb_tool = KbSearchDynamicTool(
            session=session,
            timeout_governor=timeout_governor,
            config=KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=10000,
                min_score_threshold=0.5,
            ),
        )
        # [AC-MRS-SLOT-META-03] Pass slot_state into KB retrieval.
        # [Step-KB-Binding] Pass step_kb_config for step-level KB constraints.
        kb_dynamic_result = await kb_tool.execute(
            query=request.user_message,
            tenant_id=tenant_id,
            top_k=5,
            context=base_context,
            slot_state=slot_state,
            step_kb_config=step_kb_config,
            slot_policy="flow_strict" if is_flow_runtime else "agent_relaxed",
        )
        kb_success = kb_dynamic_result.success
        kb_hits = kb_dynamic_result.hits
        kb_fallback_reason = kb_dynamic_result.fallback_reason_code
        kb_applied_filter = kb_dynamic_result.applied_filter
        kb_missing_slots = kb_dynamic_result.missing_required_slots
        if kb_dynamic_result.tool_trace:
            trace_logger.update_trace(
                request_id=request_id,
                tool_calls=[kb_dynamic_result.tool_trace],
            )
        logger.info(
            f"[AC-MARH-05] KB动态检索完成: success={kb_success}, "
            f"hits={len(kb_hits)}, filter={kb_applied_filter}, "
            f"missing_slots={kb_missing_slots}, track={'flow' if is_flow_runtime else 'agent'}"
        )
        # [AC-MRS-SLOT-META-03] Missing required slots -> ask-back loop
        # (flow runtime only). This returns early and skips the ReAct phase.
        if is_flow_runtime and kb_fallback_reason == "MISSING_REQUIRED_SLOTS" and kb_missing_slots:
            ask_back_text = await _generate_ask_back_for_missing_slots(
                slot_state=slot_state,
                missing_slots=kb_missing_slots,
                session=session,
                tenant_id=tenant_id,
                session_id=request.session_id,
                scene_slot_context=scene_slot_context,  # [AC-SCENE-SLOT-02] pass scene slot context
            )
            logger.info(
                f"[AC-MRS-SLOT-META-03] 流程态缺槽追问文案已生成: "
                f"{ask_back_text[:50]}..."
            )
            return DialogueResponse(
                segments=[Segment(text=ask_back_text, delay_after=0)],
                trace=TraceInfo(
                    mode=ExecutionMode.AGENT,
                    request_id=trace.request_id,
                    generation_id=trace.generation_id,
                    fallback_reason_code="missing_required_slots",
                    kb_tool_called=True,
                    kb_hit=False,
                    # [AC-SCENE-SLOT-02] Scene slot tracing.
                    scene=request.scene,
                    scene_slot_context=base_context.get("scene_slot_context"),
                    missing_slots=[s.get("slot_key") for s in kb_missing_slots] if kb_missing_slots else None,
                    ask_back_triggered=True,
                    slot_sources=slot_state.slot_sources if slot_state else None,
                ),
            )
    else:
        # No session/tool registry: fall back to the default KB runner.
        kb_result = await default_kb_tool_runner.execute(
            tenant_id=tenant_id,
            query=request.user_message,
        )
        kb_success = kb_result.success
        kb_hits = kb_result.hits
        kb_fallback_reason = kb_result.fallback_reason_code
    runtime_observer.record_kb(
        request_id,
        tool_called=True,
        hit=kb_success and len(kb_hits) > 0,
        fallback_reason=kb_fallback_reason,
    )
    if kb_success and kb_hits:
        # Inject up to 3 truncated hits into the prompt context.
        kb_context = "\n".join([
            f"[知识库] {hit.get('content', '')[:200]}"
            for hit in kb_hits[:3]
        ])
        base_context["kb_context"] = kb_context
        logger.info(
            f"[AC-MARH-05] KB retrieval succeeded: hits={len(kb_hits)}"
        )
    elif kb_fallback_reason:
        logger.warning(
            f"[AC-MARH-06] KB retrieval fallback: reason={kb_fallback_reason}"
        )
    orchestrator = AgentOrchestrator(
        max_iterations=5,
        timeout_governor=timeout_governor,
        llm_client=llm_client,
        tool_registry=tool_registry,
        template_service=PromptTemplateService,
        variable_resolver=VariableResolver(),
        tenant_id=tenant_id,
        user_id=request.user_id,
        session_id=request.session_id,
        scene=request.scene,
    )
    logger.info(
        f"[DEBUG-AGENT] Calling orchestrator.execute with: "
        f"user_message={request.user_message[:100] if request.user_message else 'None'}, "
        f"context_keys={list(base_context.keys())}, llm_client={llm_client is not None}"
    )
    # agent_trace is currently unused after this call.
    final_answer, react_ctx, agent_trace = await orchestrator.execute(
        user_message=request.user_message,
        context=base_context,
    )
    logger.info(
        f"[DEBUG-AGENT] orchestrator.execute completed: "
        f"final_answer={final_answer[:200] if final_answer else 'None'}, "
        f"iterations={react_ctx.iteration}, tool_calls={len(react_ctx.tool_calls) if react_ctx.tool_calls else 0}"
    )
    runtime_observer.record_react(request_id, react_ctx.iteration, react_ctx.tool_calls)
    # Merge tool_calls: the KB tool's internal trace wins because it
    # contains the post-injection arguments.
    final_tool_calls = list(react_ctx.tool_calls) if react_ctx.tool_calls else []
    logger.info(
        f"[TRACE-MERGE] Before merge: final_tool_calls count={len(final_tool_calls)}, "
        f"kb_dynamic_result exists={kb_dynamic_result is not None}, "
        f"kb_dynamic_result.tool_trace exists={kb_dynamic_result.tool_trace if kb_dynamic_result else None}"
    )
    if kb_dynamic_result and kb_dynamic_result.tool_trace:
        kb_trace = kb_dynamic_result.tool_trace
        logger.info(
            f"[TRACE-MERGE] KB trace arguments: {kb_trace.arguments}"
        )
        for i, tc in enumerate(final_tool_calls):
            logger.info(
                f"[TRACE-MERGE] Checking tool_call[{i}]: tool_name={tc.tool_name}"
            )
            if tc.tool_name == "kb_search_dynamic":
                logger.info(
                    f"[TRACE-MERGE] Replacing trace at index {i}: old_args={tc.arguments}, new_args={kb_trace.arguments}"
                )
                final_tool_calls[i] = kb_trace
                break
    else:
        logger.info(
            f"[TRACE-MERGE] Skipped merge: kb_dynamic_result={kb_dynamic_result is not None}, "
            f"tool_trace={kb_dynamic_result.tool_trace if kb_dynamic_result else 'N/A'}"
        )
    trace_logger.update_trace(
        request_id=request_id,
        react_iterations=react_ctx.iteration,
        tool_calls=final_tool_calls,
    )
    segments = _text_to_segments(final_answer)
    return DialogueResponse(
        segments=segments,
        trace=TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            react_iterations=react_ctx.iteration,
            tools_used=[tc.tool_name for tc in final_tool_calls] if final_tool_calls else None,
            tool_calls=final_tool_calls,
            timeout_profile=timeout_governor.profile,
            kb_tool_called=True,
            kb_hit=kb_success and len(kb_hits) > 0,
            fallback_reason_code=kb_fallback_reason,
            # [AC-SCENE-SLOT-02] Scene slot tracing.
            scene=request.scene,
            scene_slot_context=base_context.get("scene_slot_context"),
            missing_slots=[s.get("slot_key") for s in kb_missing_slots] if kb_missing_slots else None,
            ask_back_triggered=False,
            slot_sources=slot_state.slot_sources if slot_state else None,
            kb_filter_sources=kb_dynamic_result.filter_sources if kb_dynamic_result else None,
            # [Step-KB-Binding] Step KB binding trace.
            step_kb_binding=kb_dynamic_result.step_kb_binding if kb_dynamic_result else step_kb_binding_trace,
        ),
    )
async def _generate_ask_back_for_missing_slots(
slot_state: Any,
missing_slots: list[dict[str, str]],
session: AsyncSession,
tenant_id: str,
session_id: str | None = None,
max_ask_back_slots: int = 2,
scene_slot_context: Any = None, # [AC-SCENE-SLOT-02] 场景槽位上下文
) -> str:
"""
[AC-MRS-SLOT-META-03, AC-MRS-SLOT-ASKBACK-01] 为缺失的必填槽位生成追问响应
[AC-SCENE-SLOT-02] 支持场景槽位包配置的追问策略
支持批量追问多个缺失槽位,优先追问必填槽位
"""
if not missing_slots:
return "请提供更多信息以便我更好地帮助您。"
# [AC-SCENE-SLOT-02] 如果有场景槽位上下文,使用场景配置的追问策略
if scene_slot_context:
ask_back_order = getattr(scene_slot_context, 'ask_back_order', 'priority')
if ask_back_order == "parallel":
prompts = []
for missing in missing_slots[:max_ask_back_slots]:
ask_back_prompt = missing.get("ask_back_prompt")
if ask_back_prompt:
prompts.append(ask_back_prompt)
else:
slot_key = missing.get("slot_key", "相关信息")
prompts.append(f"您的{slot_key}")
if len(prompts) == 1:
return prompts[0]
elif len(prompts) == 2:
return f"为了更好地为您服务,请告诉我{prompts[0]}{prompts[1]}"
else:
all_but_last = "".join(prompts[:-1])
return f"为了更好地为您服务,请告诉我{all_but_last},以及{prompts[-1]}"
if len(missing_slots) == 1:
first_missing = missing_slots[0]
ask_back_prompt = first_missing.get("ask_back_prompt")
if ask_back_prompt:
return ask_back_prompt
slot_key = first_missing.get("slot_key", "相关信息")
label = first_missing.get("label", slot_key)
return f"为了更好地为您提供帮助,请告诉我您的{label}"
try:
from app.services.mid.batch_ask_back_service import (
BatchAskBackConfig,
BatchAskBackService,
)
config = BatchAskBackConfig(
max_ask_back_slots_per_turn=max_ask_back_slots,
prefer_required=True,
merge_prompts=True,
)
ask_back_service = BatchAskBackService(
session=session,
tenant_id=tenant_id,
session_id=session_id or "",
config=config,
)
result = await ask_back_service.generate_batch_ask_back(
missing_slots=missing_slots,
)
if result.has_ask_back():
return result.get_prompt()
except Exception as e:
logger.warning(f"[AC-MRS-SLOT-ASKBACK-01] Batch ask-back failed, fallback to single: {e}")
prompts = []
for missing in missing_slots[:max_ask_back_slots]:
ask_back_prompt = missing.get("ask_back_prompt")
if ask_back_prompt:
prompts.append(ask_back_prompt)
else:
label = missing.get("label", missing.get("slot_key", "相关信息"))
prompts.append(f"您的{label}")
if len(prompts) == 1:
return prompts[0]
elif len(prompts) == 2:
return f"为了更好地为您服务,请告诉我{prompts[0]}{prompts[1]}"
else:
all_but_last = "".join(prompts[:-1])
return f"为了更好地为您服务,请告诉我{all_but_last},以及{prompts[-1]}"
async def _execute_micro_flow_mode(
    tenant_id: str,
    request: DialogueRequest,
    router_result: Any,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    session: AsyncSession,
    start_time: float,
) -> DialogueResponse:
    """Execute micro flow mode.

    Starts the routed flow (when a target flow id exists) and replies with
    the flow's opening step; any start failure falls back to a generic
    "processing" placeholder response.
    """
    flow_id = router_result.target_flow_id
    if flow_id:
        try:
            from app.services.flow.engine import FlowEngine

            engine = FlowEngine(session)
            _instance, opening_text = await engine.start(
                tenant_id=tenant_id,
                session_id=request.session_id,
                flow_id=flow_id,
            )
            if opening_text:
                # Successful start: reply with the first step of the flow.
                return DialogueResponse(
                    segments=_text_to_segments(opening_text),
                    trace=TraceInfo(
                        mode=ExecutionMode.MICRO_FLOW,
                        request_id=trace.request_id,
                        generation_id=trace.generation_id,
                        intent=router_result.intent,
                    ),
                )
        except Exception as e:
            logger.warning(f"[AC-IDMP-05] Micro flow start failed: {e}")
    # Fallback: no flow id, no opening step, or the start raised.
    placeholder = Segment(
        text="正在为您处理,请稍候...",
        delay_after=0,
    )
    return DialogueResponse(
        segments=[placeholder],
        trace=TraceInfo(
            mode=ExecutionMode.MICRO_FLOW,
            request_id=trace.request_id,
            generation_id=trace.generation_id,
            intent=router_result.intent,
            fallback_reason_code=router_result.fallback_reason_code,
        ),
    )
async def _execute_fixed_mode(
    tenant_id: str,
    request: DialogueRequest,
    router_result: Any,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
) -> DialogueResponse:
    """Execute fixed reply mode: emit the router's canned reply as segments."""
    reply_text = router_result.fixed_reply
    if not reply_text:
        # Router supplied no canned text; use the generic acknowledgement.
        reply_text = "收到您的消息,我们会尽快处理。"
    fixed_trace = TraceInfo(
        mode=ExecutionMode.FIXED,
        request_id=trace.request_id,
        generation_id=trace.generation_id,
        intent=router_result.intent,
        fallback_reason_code=router_result.fallback_reason_code,
    )
    return DialogueResponse(segments=_text_to_segments(reply_text), trace=fixed_trace)
async def _execute_transfer_mode(
    tenant_id: str,
    request: DialogueRequest,
    router_result: Any,
    trace: TraceInfo,
    trace_logger: TraceLogger,
    start_time: float,
) -> DialogueResponse:
    """Execute transfer-to-human mode: emit the handoff message as segments."""
    handoff_text = router_result.transfer_message
    if not handoff_text:
        # No custom handoff message configured; use the default one.
        handoff_text = "正在为您转接人工客服,请稍候..."
    handoff_trace = TraceInfo(
        mode=ExecutionMode.TRANSFER,
        request_id=trace.request_id,
        generation_id=trace.generation_id,
        intent=router_result.intent,
        fallback_reason_code=router_result.fallback_reason_code,
    )
    return DialogueResponse(segments=_text_to_segments(handoff_text), trace=handoff_trace)
def _text_to_segments(text: str) -> list[Segment]:
    """Convert text to segments.

    Splits on blank lines ("\\n\\n"); every segment but the last gets a
    100 ms trailing delay. Whitespace-only text yields one segment with
    the original text unchanged.
    """
    paragraphs: list[str] = []
    for raw in text.split("\n\n"):
        stripped = raw.strip()
        if stripped:
            paragraphs.append(stripped)
    if not paragraphs:
        # Nothing survived stripping; keep the original text as one segment.
        paragraphs = [text]
    last_index = len(paragraphs) - 1
    segments: list[Segment] = []
    for index, paragraph in enumerate(paragraphs):
        pause = 100 if index < last_index else 0
        segments.append(Segment(text=paragraph, delay_after=pause))
    return segments