From 248a22543612fb900ef9649e8af249ff414886ff Mon Sep 17 00:00:00 2001 From: MerCry Date: Tue, 10 Mar 2026 12:06:57 +0800 Subject: [PATCH] feat: implement mid-platform dialogue and session management with memory recall and KB search tools [AC-IDMP-01~20] --- ai-service/app/api/mid/dialogue.py | 608 +++++++++++++++++- ai-service/app/api/mid/sessions.py | 77 +++ ai-service/app/models/mid/__init__.py | 2 + ai-service/app/models/mid/schemas.py | 16 +- ai-service/app/services/mid/__init__.py | 20 + .../services/mid/kb_search_dynamic_tool.py | 550 ++++++++++++++-- .../app/services/mid/memory_recall_tool.py | 12 +- .../services/mid/metadata_filter_builder.py | 66 +- .../services/mid/role_based_field_provider.py | 4 + 9 files changed, 1277 insertions(+), 78 deletions(-) diff --git a/ai-service/app/api/mid/dialogue.py b/ai-service/app/api/mid/dialogue.py index 5c1163c..8cbc6a3 100644 --- a/ai-service/app/api/mid/dialogue.py +++ b/ai-service/app/api/mid/dialogue.py @@ -55,6 +55,18 @@ from app.services.mid.segment_humanizer import HumanizeConfig, SegmentHumanizer from app.services.mid.timeout_governor import TimeoutGovernor from app.services.mid.tool_registry import ToolRegistry from app.services.mid.trace_logger import TraceLogger +from app.services.prompt.template_service import PromptTemplateService +from app.services.prompt.variable_resolver import VariableResolver +from app.services.intent.clarification import ( + ClarificationEngine, + ClarifyReason, + ClarifySessionManager, + ClarifyState, + HybridIntentResult, + IntentCandidate, + T_HIGH, + T_LOW, +) logger = logging.getLogger(__name__) @@ -118,6 +130,7 @@ _kb_search_dynamic_registered: bool = False _intent_hint_registered: bool = False _high_risk_check_registered: bool = False _memory_recall_registered: bool = False +_metadata_discovery_registered: bool = False def ensure_kb_search_dynamic_registered( @@ -134,7 +147,7 @@ def ensure_kb_search_dynamic_registered( config = KbSearchDynamicConfig( enabled=True, top_k=5, - timeout_ms=2000, + timeout_ms=10000, min_score_threshold=0.5, ) @@ -222,6 +235,26 @@ def ensure_memory_recall_registered( logger.info("[AC-IDMP-13] memory_recall tool registered to registry") +def ensure_metadata_discovery_registered( + registry: ToolRegistry, + session: AsyncSession, +) -> None: + """Ensure metadata_discovery tool is registered.""" + global _metadata_discovery_registered + if _metadata_discovery_registered: + return + + from app.services.mid.metadata_discovery_tool import register_metadata_discovery_tool + + register_metadata_discovery_tool( + registry=registry, + session=session, + timeout_governor=get_timeout_governor(), + ) + _metadata_discovery_registered = True + logger.info("[MetadataDiscovery] metadata_discovery tool registered to registry") + + def get_output_guardrail_executor() -> OutputGuardrailExecutor: """Get or create OutputGuardrailExecutor instance.""" if "output_guardrail_executor" not in _mid_services: @@ -259,6 +292,16 @@ def get_runtime_observer() -> RuntimeObserver: return _mid_services["runtime_observer"] +def get_clarification_engine() -> ClarificationEngine: + """Get or create ClarificationEngine instance.""" + if "clarification_engine" not in _mid_services: + _mid_services["clarification_engine"] = ClarificationEngine( + t_high=T_HIGH, + t_low=T_LOW, + ) + return _mid_services["clarification_engine"] + + @router.post( "/dialogue/respond", operation_id="respondDialogue", @@ -288,6 +331,7 @@ async def respond_dialogue( default_kb_tool_runner: DefaultKbToolRunner = Depends(get_default_kb_tool_runner), segment_humanizer: SegmentHumanizer = Depends(get_segment_humanizer), runtime_observer: RuntimeObserver = Depends(get_runtime_observer), + clarification_engine: ClarificationEngine = Depends(get_clarification_engine), ) -> DialogueResponse: """ [AC-MARH-01~12] Generate dialogue response with segments and trace. @@ -316,7 +360,8 @@ async def respond_dialogue( logger.info( f"[AC-MARH-01] Dialogue request: tenant={tenant_id}, " - f"session={dialogue_request.session_id}, request_id={request_id}" + f"session={dialogue_request.session_id}, request_id={request_id}, " + f"user_message={dialogue_request.user_message[:100] if dialogue_request.user_message else 'None'}..." ) runtime_ctx = runtime_observer.start_observation( @@ -340,6 +385,7 @@ async def respond_dialogue( ensure_intent_hint_registered(tool_registry, session) ensure_high_risk_check_registered(tool_registry, session) ensure_memory_recall_registered(tool_registry, session) + ensure_metadata_discovery_registered(tool_registry, session) try: interrupt_ctx = interrupt_context_enricher.enrich( @@ -541,7 +587,15 @@ async def respond_dialogue( except Exception as e: latency_ms = int((time.time() - start_time) * 1000) - logger.error(f"[AC-IDMP-06] Dialogue error: {e}") + import traceback + logger.error( + f"[AC-IDMP-06] Dialogue error: {e}\n" + f"Exception type: {type(e).__name__}\n" + f"Request details: session_id={dialogue_request.session_id}, " + f"user_message={dialogue_request.user_message[:200] if dialogue_request.user_message else 'None'}, " + f"scene={dialogue_request.scene}, user_id={dialogue_request.user_id}\n" + f"Traceback:\n{traceback.format_exc()}" + ) trace_logger.update_trace( request_id=request_id, @@ -593,7 +647,7 @@ async def _match_intent( return IntentMatch( intent_id=str(result.rule.id), intent_name=result.rule.name, - confidence=0.8, + confidence=1.0, response_type=result.rule.response_type, target_kb_ids=result.rule.target_kb_ids, flow_id=str(result.rule.flow_id) if result.rule.flow_id else None, @@ -608,6 +662,227 @@ async def _match_intent( return None +async def _match_intent_hybrid( + tenant_id: str, + request: DialogueRequest, + session: AsyncSession, + clarification_engine: ClarificationEngine, +) -> HybridIntentResult: + """ + [AC-CLARIFY] Hybrid intent matching with clarification support. + + Returns HybridIntentResult with unified confidence and candidates. + """ + try: + from app.services.intent.router import IntentRouter + from app.services.intent.rule_service import IntentRuleService + + rule_service = IntentRuleService(session) + rules = await rule_service.get_enabled_rules_for_matching(tenant_id) + + if not rules: + return HybridIntentResult( + intent=None, + confidence=0.0, + candidates=[], + ) + + router = IntentRouter() + + try: + fusion_result = await router.match_hybrid( + message=request.user_message, + rules=rules, + tenant_id=tenant_id, + ) + + hybrid_result = HybridIntentResult.from_fusion_result(fusion_result) + + logger.info( + f"[AC-CLARIFY] Hybrid intent match: " + f"intent={hybrid_result.intent.intent_name if hybrid_result.intent else None}, " + f"confidence={hybrid_result.confidence:.3f}, " + f"need_clarify={hybrid_result.need_clarify}" + ) + + return hybrid_result + + except Exception as e: + logger.warning(f"[AC-CLARIFY] Hybrid match failed, fallback to rule: {e}") + + result = router.match(request.user_message, rules) + + if result: + confidence = clarification_engine.compute_confidence( + rule_score=1.0, + semantic_score=0.0, + llm_score=0.0, + ) + candidate = IntentCandidate( + intent_id=str(result.rule.id), + intent_name=result.rule.name, + confidence=confidence, + response_type=result.rule.response_type, + target_kb_ids=result.rule.target_kb_ids, + flow_id=str(result.rule.flow_id) if result.rule.flow_id else None, + fixed_reply=result.rule.fixed_reply, + transfer_message=result.rule.transfer_message, + ) + return HybridIntentResult( + intent=candidate, + confidence=confidence, + candidates=[candidate], + ) + + return HybridIntentResult( + intent=None, + confidence=0.0, + candidates=[], + ) + + except Exception as e: + logger.warning(f"[AC-CLARIFY] Intent match failed: {e}") + return HybridIntentResult( + intent=None, + confidence=0.0, + candidates=[], + ) + + +async def _handle_clarification( + tenant_id: str, + request: DialogueRequest, + hybrid_result: HybridIntentResult, + clarification_engine: ClarificationEngine, + trace: TraceInfo, + session: AsyncSession, + required_slots: list[str] | None = None, + filled_slots: dict[str, Any] | None = None, +) -> DialogueResponse | None: + """ + [AC-CLARIFY] Handle clarification logic. + + Returns DialogueResponse if clarification is needed, None otherwise. + """ + existing_state = ClarifySessionManager.get_session(request.session_id) + + if existing_state and not existing_state.is_max_retry(): + logger.info( + f"[AC-CLARIFY] Processing clarify response: " + f"session={request.session_id}, retry={existing_state.retry_count}" + ) + + new_result = clarification_engine.process_clarify_response( + user_message=request.user_message, + state=existing_state, + ) + + if not new_result.need_clarify: + ClarifySessionManager.clear_session(request.session_id) + + if new_result.intent: + return None + + clarify_prompt = clarification_engine.generate_clarify_prompt(existing_state) + + return DialogueResponse( + segments=[Segment(text=clarify_prompt, delay_after=0)], + trace=TraceInfo( + mode=ExecutionMode.FIXED, + request_id=trace.request_id, + generation_id=trace.generation_id, + fallback_reason_code=f"clarify_retry_{existing_state.retry_count}", + intent=existing_state.candidates[0].intent_name if existing_state.candidates else None, + ), + ) + + should_clarify, clarify_state = clarification_engine.should_trigger_clarify( + result=hybrid_result, + required_slots=required_slots, + filled_slots=filled_slots, + ) + + if not should_clarify or not clarify_state: + return None + + logger.info( + f"[AC-CLARIFY] Clarification triggered: " + f"session={request.session_id}, reason={clarify_state.reason}, " + f"confidence={hybrid_result.confidence:.3f}" + ) + + ClarifySessionManager.set_session(request.session_id, clarify_state) + + clarify_prompt = clarification_engine.generate_clarify_prompt(clarify_state) + + return DialogueResponse( + segments=[Segment(text=clarify_prompt, delay_after=0)], + trace=TraceInfo( + mode=ExecutionMode.FIXED, + request_id=trace.request_id, + generation_id=trace.generation_id, + fallback_reason_code=f"clarify_{clarify_state.reason.value}", + intent=hybrid_result.intent.intent_name if hybrid_result.intent else None, + ), + ) + + +async def _check_hard_block( + tenant_id: str, + request: DialogueRequest, + hybrid_result: HybridIntentResult, + clarification_engine: ClarificationEngine, + trace: TraceInfo, + required_slots: list[str] | None = None, + filled_slots: dict[str, Any] | None = None, +) -> DialogueResponse | None: + """ + [AC-CLARIFY] Check hard block conditions. + + Hard blocks: + 1. confidence < T_high: block entering new flow + 2. required_slots missing: block flow progression + + Returns DialogueResponse if blocked, None otherwise. + """ + is_blocked, block_reason = clarification_engine.check_hard_block( + result=hybrid_result, + required_slots=required_slots, + filled_slots=filled_slots, + ) + + if not is_blocked: + return None + + logger.info( + f"[AC-CLARIFY] Hard block triggered: " + f"session={request.session_id}, reason={block_reason}, " + f"confidence={hybrid_result.confidence:.3f}" + ) + + clarify_state = ClarifyState( + reason=block_reason, + asked_slot=required_slots[0] if required_slots and block_reason == ClarifyReason.MISSING_SLOT else None, + candidates=hybrid_result.candidates, + ) + + ClarifySessionManager.set_session(request.session_id, clarify_state) + clarification_engine._metrics.record_clarify_trigger() + + clarify_prompt = clarification_engine.generate_clarify_prompt(clarify_state) + + return DialogueResponse( + segments=[Segment(text=clarify_prompt, delay_after=0)], + trace=TraceInfo( + mode=ExecutionMode.FIXED, + request_id=trace.request_id, + generation_id=trace.generation_id, + fallback_reason_code=f"hard_block_{block_reason.value}", + intent=hybrid_result.intent.intent_name if hybrid_result.intent else None, + ), + ) + + async def _handle_legacy_response( tenant_id: str, request: DialogueRequest, @@ -793,18 +1068,32 @@ async def _execute_agent_mode( session: AsyncSession | None = None, tool_registry: ToolRegistry | None = None, ) -> DialogueResponse: - """[AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-IDMP-13] Execute agent mode with ReAct loop, KB tool, and memory recall.""" + """ + [AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-IDMP-13, AC-MRS-SLOT-META-03] + Execute agent mode with ReAct loop, KB tool, memory recall, and slot state aggregation. + """ from app.services.llm.factory import get_llm_config_manager + from app.services.mid.slot_state_aggregator import SlotStateAggregator + + logger.info( + f"[DEBUG-AGENT] Starting _execute_agent_mode: tenant={tenant_id}, " + f"session={request.session_id}, scene={request.scene}, user_id={request.user_id}, " + f"user_message={request.user_message[:100] if request.user_message else 'None'}..." + ) try: llm_manager = get_llm_config_manager() llm_client = llm_manager.get_client() + logger.info(f"[DEBUG-AGENT] LLM client obtained successfully") except Exception as e: logger.warning(f"[AC-MARH-07] Failed to get LLM client: {e}") llm_client = None base_context = {"history": [h.model_dump() for h in request.history]} if request.history else {} + if request.scene: + base_context["scene"] = request.scene + if interrupt_ctx and interrupt_ctx.consumed: base_context["interrupted_content"] = interrupt_ctx.interrupted_content base_context["interrupted_segment_ids"] = interrupt_ctx.interrupted_segment_ids @@ -813,8 +1102,50 @@ async def _execute_agent_mode( f"{len(interrupt_ctx.interrupted_content or '')} chars" ) + # [AC-MRS-SLOT-META-03] 初始化槽位状态聚合器 + slot_state = None memory_context = "" memory_missing_slots: list[str] = [] + scene_slot_context = None + + flow_status = None + is_flow_runtime = False + if session: + from app.services.flow.engine import FlowEngine + flow_engine = FlowEngine(session) + flow_status = await flow_engine.get_flow_status(tenant_id, request.session_id) + is_flow_runtime = bool(flow_status and flow_status.get("status") == "active") + + logger.info( + f"[AC-IDMP-02] 运行轨道判定: session={request.session_id}, " + f"is_flow_runtime={is_flow_runtime}, flow_id={flow_status.get('flow_id') if flow_status else None}, " + f"current_step={flow_status.get('current_step') if flow_status else None}" + ) + + # [AC-SCENE-SLOT-02] 加载场景槽位包(仅流程态加载) + if is_flow_runtime and session and request.scene: + from app.services.mid.scene_slot_bundle_loader import SceneSlotBundleLoader + + scene_loader = SceneSlotBundleLoader(session) + scene_slot_context = await scene_loader.load_scene_context( + tenant_id=tenant_id, + scene_key=request.scene, + ) + + if scene_slot_context: + logger.info( + f"[AC-SCENE-SLOT-02] 场景槽位包加载成功: " + f"scene={request.scene}, required={len(scene_slot_context.required_slots)}, " + f"optional={len(scene_slot_context.optional_slots)}, " + f"threshold={scene_slot_context.completion_threshold}" + ) + base_context["scene_slot_context"] = { + "scene_key": scene_slot_context.scene_key, + "scene_name": scene_slot_context.scene_name, + "required_slots": scene_slot_context.get_required_slot_keys(), + "optional_slots": scene_slot_context.get_optional_slot_keys(), + } + if session and request.user_id: memory_recall_tool = MemoryRecallTool( session=session, @@ -855,11 +1186,111 @@ async def _execute_agent_mode( f"[AC-IDMP-13] Memory recall fallback: reason={memory_result.fallback_reason_code}" ) + # [AC-MRS-SLOT-META-03] 仅流程态执行槽位聚合/提取 + if is_flow_runtime: + slot_aggregator = SlotStateAggregator( + session=session, + tenant_id=tenant_id, + session_id=request.session_id, + ) + slot_state = await slot_aggregator.aggregate( + memory_slots=memory_result.slots, + current_input_slots=None, # 可从 request 中解析 + context=base_context, + scene_slot_context=scene_slot_context, # [AC-SCENE-SLOT-02] 传入场景槽位上下文 + ) + + logger.info( + f"[AC-MRS-SLOT-META-03] 流程态槽位聚合完成: " + f"filled={len(slot_state.filled_slots)}, " + f"missing={len(slot_state.missing_required_slots)}, " + f"mappings={slot_state.slot_to_field_map}" + ) + + # [AC-MRS-SLOT-EXTRACT-01] 自动提取槽位 + if slot_state and slot_state.missing_required_slots: + from app.services.mid.slot_extraction_integration import SlotExtractionIntegration + + extraction_integration = SlotExtractionIntegration( + session=session, + tenant_id=tenant_id, + session_id=request.session_id, + ) + + extraction_result = await extraction_integration.extract_and_fill( + user_input=request.user_message, + slot_state=slot_state, + ) + + if extraction_result.extracted_slots: + slot_state = await slot_aggregator.aggregate( + memory_slots=memory_result.slots, + current_input_slots=extraction_result.extracted_slots, + context=base_context, + ) + + logger.info( + f"[AC-MRS-SLOT-EXTRACT-01] 流程态自动提槽完成: " + f"extracted={list(extraction_result.extracted_slots.keys())}, " + f"time_ms={extraction_result.total_execution_time_ms:.2f}" + ) + else: + logger.info( + f"[AC-MRS-SLOT-META-03] 当前为通用问答轨道,跳过槽位聚合与缺槽追问: " + f"session={request.session_id}" + ) + kb_hits = [] kb_success = False kb_fallback_reason = None kb_applied_filter = {} kb_missing_slots = [] + kb_dynamic_result = None + step_kb_binding_trace: dict[str, Any] | None = None + + # [Step-KB-Binding] 获取当前流程步骤的 KB 配置(仅流程态) + step_kb_config = None + if is_flow_runtime and session and flow_status: + current_step_no = flow_status.get("current_step") + flow_id = flow_status.get("flow_id") + + if flow_id and current_step_no: + from app.models.entities import ScriptFlow + from sqlalchemy import select + + stmt = select(ScriptFlow).where(ScriptFlow.id == flow_id) + result = await session.execute(stmt) + flow = result.scalar_one_or_none() + + if flow and flow.steps: + step_idx = current_step_no - 1 + if 0 <= step_idx < len(flow.steps): + current_step = flow.steps[step_idx] + + # 构建 StepKbConfig + from app.services.mid.kb_search_dynamic_tool import StepKbConfig + step_kb_config = StepKbConfig( + allowed_kb_ids=current_step.get("allowed_kb_ids"), + preferred_kb_ids=current_step.get("preferred_kb_ids"), + kb_query_hint=current_step.get("kb_query_hint"), + max_kb_calls=current_step.get("max_kb_calls_per_step", 1), + step_id=f"{flow_id}_step_{current_step_no}", + ) + + step_kb_binding_trace = { + "flow_id": str(flow_id), + "flow_name": flow_status.get("flow_name"), + "current_step": current_step_no, + "step_id": step_kb_config.step_id, + "allowed_kb_ids": step_kb_config.allowed_kb_ids, + "preferred_kb_ids": step_kb_config.preferred_kb_ids, + } + + logger.info( + f"[Step-KB-Binding] 步骤知识库配置加载成功: " + f"flow={flow_status.get('flow_name')}, step={current_step_no}, " + f"allowed_kb_ids={step_kb_config.allowed_kb_ids}" + ) if session and tool_registry: kb_tool = KbSearchDynamicTool( @@ -868,17 +1299,21 @@ async def _execute_agent_mode( config=KbSearchDynamicConfig( enabled=True, top_k=5, - timeout_ms=2000, + timeout_ms=10000, min_score_threshold=0.5, ), ) + # [AC-MRS-SLOT-META-03] 传入 slot_state 进行 KB 检索 + # [Step-KB-Binding] 传入 step_kb_config 进行步骤级别的 KB 约束 kb_dynamic_result = await kb_tool.execute( query=request.user_message, tenant_id=tenant_id, - scene="open_consult", top_k=5, context=base_context, + slot_state=slot_state, + step_kb_config=step_kb_config, + slot_policy="flow_strict" if is_flow_runtime else "agent_relaxed", ) kb_success = kb_dynamic_result.success @@ -894,10 +1329,44 @@ async def _execute_agent_mode( ) logger.info( - f"[AC-MARH-05] KB dynamic search: success={kb_success}, " + f"[AC-MARH-05] KB动态检索完成: success={kb_success}, " f"hits={len(kb_hits)}, filter={kb_applied_filter}, " - f"missing_slots={kb_missing_slots}" + f"missing_slots={kb_missing_slots}, track={'flow' if is_flow_runtime else 'agent'}" ) + + # [AC-MRS-SLOT-META-03] 处理缺失必填槽位 -> 追问闭环(仅流程态) + if is_flow_runtime and kb_fallback_reason == "MISSING_REQUIRED_SLOTS" and kb_missing_slots: + ask_back_text = await _generate_ask_back_for_missing_slots( + slot_state=slot_state, + missing_slots=kb_missing_slots, + session=session, + tenant_id=tenant_id, + session_id=request.session_id, + scene_slot_context=scene_slot_context, # [AC-SCENE-SLOT-02] 传入场景槽位上下文 + ) + + logger.info( + f"[AC-MRS-SLOT-META-03] 流程态缺槽追问文案已生成: " + f"{ask_back_text[:50]}..." + ) + + return DialogueResponse( + segments=[Segment(text=ask_back_text, delay_after=0)], + trace=TraceInfo( + mode=ExecutionMode.AGENT, + request_id=trace.request_id, + generation_id=trace.generation_id, + fallback_reason_code="missing_required_slots", + kb_tool_called=True, + kb_hit=False, + # [AC-SCENE-SLOT-02] 场景槽位追踪 + scene=request.scene, + scene_slot_context=base_context.get("scene_slot_context"), + missing_slots=[s.get("slot_key") for s in kb_missing_slots] if kb_missing_slots else None, + ask_back_triggered=True, + slot_sources=slot_state.slot_sources if slot_state else None, + ), + ) else: kb_result = await default_kb_tool_runner.execute( tenant_id=tenant_id, @@ -933,7 +1402,18 @@ async def _execute_agent_mode( timeout_governor=timeout_governor, llm_client=llm_client, tool_registry=tool_registry, + template_service=PromptTemplateService, + variable_resolver=VariableResolver(), tenant_id=tenant_id, + user_id=request.user_id, + session_id=request.session_id, + scene=request.scene, + ) + + logger.info( + f"[DEBUG-AGENT] Calling orchestrator.execute with: " + f"user_message={request.user_message[:100] if request.user_message else 'None'}, " + f"context_keys={list(base_context.keys())}, llm_client={llm_client is not None}" ) final_answer, react_ctx, agent_trace = await orchestrator.execute( @@ -941,6 +1421,12 @@ async def _execute_agent_mode( context=base_context, ) + logger.info( + f"[DEBUG-AGENT] orchestrator.execute completed: " + f"final_answer={final_answer[:200] if final_answer else 'None'}, " + f"iterations={react_ctx.iteration}, tool_calls={len(react_ctx.tool_calls) if react_ctx.tool_calls else 0}" + ) + runtime_observer.record_react(request_id, react_ctx.iteration, react_ctx.tool_calls) trace_logger.update_trace( @@ -964,10 +1450,114 @@ async def _execute_agent_mode( kb_tool_called=True, kb_hit=kb_success and len(kb_hits) > 0, fallback_reason_code=kb_fallback_reason, + # [AC-SCENE-SLOT-02] 场景槽位追踪 + scene=request.scene, + scene_slot_context=base_context.get("scene_slot_context"), + missing_slots=[s.get("slot_key") for s in kb_missing_slots] if kb_missing_slots else None, + ask_back_triggered=False, + slot_sources=slot_state.slot_sources if slot_state else None, + kb_filter_sources=kb_dynamic_result.filter_sources if kb_dynamic_result else None, + # [Step-KB-Binding] 步骤知识库绑定追踪 + step_kb_binding=kb_dynamic_result.step_kb_binding if kb_dynamic_result else step_kb_binding_trace, ), ) +async def _generate_ask_back_for_missing_slots( + slot_state: Any, + missing_slots: list[dict[str, str]], + session: AsyncSession, + tenant_id: str, + session_id: str | None = None, + max_ask_back_slots: int = 2, + scene_slot_context: Any = None, # [AC-SCENE-SLOT-02] 场景槽位上下文 +) -> str: + """ + [AC-MRS-SLOT-META-03, AC-MRS-SLOT-ASKBACK-01] 为缺失的必填槽位生成追问响应 + [AC-SCENE-SLOT-02] 支持场景槽位包配置的追问策略 + + 支持批量追问多个缺失槽位,优先追问必填槽位 + """ + if not missing_slots: + return "请提供更多信息以便我更好地帮助您。" + + # [AC-SCENE-SLOT-02] 如果有场景槽位上下文,使用场景配置的追问策略 + if scene_slot_context: + ask_back_order = getattr(scene_slot_context, 'ask_back_order', 'priority') + + if ask_back_order == "parallel": + prompts = [] + for missing in missing_slots[:max_ask_back_slots]: + ask_back_prompt = missing.get("ask_back_prompt") + if ask_back_prompt: + prompts.append(ask_back_prompt) + else: + slot_key = missing.get("slot_key", "相关信息") + prompts.append(f"您的{slot_key}") + + if len(prompts) == 1: + return prompts[0] + elif len(prompts) == 2: + return f"为了更好地为您服务,请告诉我{prompts[0]}和{prompts[1]}。" + else: + all_but_last = "、".join(prompts[:-1]) + return f"为了更好地为您服务,请告诉我{all_but_last},以及{prompts[-1]}。" + + if len(missing_slots) == 1: + first_missing = missing_slots[0] + ask_back_prompt = first_missing.get("ask_back_prompt") + if ask_back_prompt: + return ask_back_prompt + slot_key = first_missing.get("slot_key", "相关信息") + label = first_missing.get("label", slot_key) + return f"为了更好地为您提供帮助,请告诉我您的{label}。" + + try: + from app.services.mid.batch_ask_back_service import ( + BatchAskBackConfig, + BatchAskBackService, + ) + + config = BatchAskBackConfig( + max_ask_back_slots_per_turn=max_ask_back_slots, + prefer_required=True, + merge_prompts=True, + ) + + ask_back_service = BatchAskBackService( + session=session, + tenant_id=tenant_id, + session_id=session_id or "", + config=config, + ) + + result = await ask_back_service.generate_batch_ask_back( + missing_slots=missing_slots, + ) + + if result.has_ask_back(): + return result.get_prompt() + except Exception as e: + logger.warning(f"[AC-MRS-SLOT-ASKBACK-01] Batch ask-back failed, fallback to single: {e}") + + prompts = [] + for missing in missing_slots[:max_ask_back_slots]: + ask_back_prompt = missing.get("ask_back_prompt") + if ask_back_prompt: + prompts.append(ask_back_prompt) + else: + label = missing.get("label", missing.get("slot_key", "相关信息")) + prompts.append(f"您的{label}") + + if len(prompts) == 1: + return prompts[0] + elif len(prompts) == 2: + return f"为了更好地为您服务,请告诉我{prompts[0]}和{prompts[1]}。" + else: + all_but_last = "、".join(prompts[:-1]) + return f"为了更好地为您服务,请告诉我{all_but_last},以及{prompts[-1]}。" + + async def _execute_micro_flow_mode( tenant_id: str, request: DialogueRequest, diff --git a/ai-service/app/api/mid/sessions.py b/ai-service/app/api/mid/sessions.py index c2d2007..610b5a4 100644 --- a/ai-service/app/api/mid/sessions.py +++ b/ai-service/app/api/mid/sessions.py @@ -8,6 +8,7 @@ from typing import Annotated from fastapi import APIRouter, Depends, Path from fastapi.responses import JSONResponse +from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncSession from app.core.database import get_session @@ -26,6 +27,82 @@ router = APIRouter(prefix="/mid", tags=["Mid Platform Sessions"]) _session_modes: dict[str, SessionMode] = {} +class CancelFlowResponse(BaseModel): + """Response for cancel flow operation.""" + + success: bool + message: str + session_id: str + + +@router.post( + "/sessions/{sessionId}/cancel-flow", + operation_id="cancelActiveFlow", + summary="Cancel active flow", + description=""" + Cancel the active flow for a session. + + Use this when you encounter "Session already has an active flow" error. + """, + responses={ + 200: {"description": "Flow cancelled successfully", "model": CancelFlowResponse}, + 404: {"description": "No active flow found"}, + }, +) +async def cancel_active_flow( + sessionId: Annotated[str, Path(description="Session ID")], + session: Annotated[AsyncSession, Depends(get_session)], +) -> CancelFlowResponse: + """ + Cancel the active flow for a session. + + This endpoint allows you to cancel any active flow instance + so that a new flow can be started. + """ + tenant_id = get_tenant_id() + + if not tenant_id: + from app.core.exceptions import MissingTenantIdException + raise MissingTenantIdException() + + logger.info( + f"[Cancel Flow] Cancelling active flow: tenant={tenant_id}, session={sessionId}" + ) + + try: + from app.services.flow.engine import FlowEngine + + flow_engine = FlowEngine(session) + cancelled = await flow_engine.cancel_flow( + tenant_id=tenant_id, + session_id=sessionId, + reason="User requested cancellation via API", + ) + + if cancelled: + logger.info(f"[Cancel Flow] Flow cancelled: session={sessionId}") + return CancelFlowResponse( + success=True, + message="Active flow cancelled successfully", + session_id=sessionId, + ) + else: + logger.info(f"[Cancel Flow] No active flow found: session={sessionId}") + return CancelFlowResponse( + success=True, + message="No active flow found for this session", + session_id=sessionId, + ) + + except Exception as e: + logger.error(f"[Cancel Flow] Failed to cancel flow: {e}") + return CancelFlowResponse( + success=False, + message=f"Failed to cancel flow: {str(e)}", + session_id=sessionId, + ) + + @router.post( "/sessions/{sessionId}/mode", operation_id="switchSessionMode", diff --git a/ai-service/app/models/mid/__init__.py b/ai-service/app/models/mid/__init__.py index 642077a..ae6393e 100644 --- a/ai-service/app/models/mid/__init__.py +++ b/ai-service/app/models/mid/__init__.py @@ -73,6 +73,7 @@ class DialogueRequest(BaseModel): history: list[HistoryMessage] = Field(default_factory=list, description="已送达历史") interrupted_segments: list[InterruptedSegment] | None = Field(default=None, description="打断的分段") feature_flags: FeatureFlags | None = Field(default=None, description="特性开关") + scene: str | None = Field(default=None, description="场景标识,用于KB过滤,如 'open_consult', 'after_sale'") class Segment(BaseModel): @@ -127,6 +128,7 @@ class TraceInfo(BaseModel): ) tools_used: list[str] | None = Field(default=None, description="使用的工具列表") tool_calls: list[ToolCallTraceModel] | None = Field(default=None, description="工具调用追踪") + step_kb_binding: dict[str, Any] | None = Field(default=None, description="[Step-KB-Binding] 步骤知识库绑定信息") class DialogueResponse(BaseModel): diff --git a/ai-service/app/models/mid/schemas.py b/ai-service/app/models/mid/schemas.py index beb4e41..03f3743 100644 --- a/ai-service/app/models/mid/schemas.py +++ b/ai-service/app/models/mid/schemas.py @@ -86,6 +86,7 @@ class DialogueRequest(BaseModel): humanize_config: HumanizeConfigRequest | None = Field( default=None, description="Humanize config for segment delay" ) + scene: str | None = Field(default=None, description="Scene identifier for KB filtering, e.g., 'open_consult', 'after_sale'") class Segment(BaseModel): @@ -122,6 +123,8 @@ class ToolCallTrace(BaseModel): error_code: str | None = Field(default=None, description="Error code if failed") args_digest: str | None = Field(default=None, description="Arguments digest for logging") result_digest: str | None = Field(default=None, description="Result digest for logging") + arguments: dict[str, Any] | None = Field(default=None, description="Full tool call arguments") + result: Any = Field(default=None, description="Full tool call result") class SegmentStats(BaseModel): @@ -133,7 +136,7 @@ class SegmentStats(BaseModel): class TraceInfo(BaseModel): """[AC-MARH-02, AC-MARH-03, AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-MARH-11, - AC-MARH-12, AC-MARH-18, AC-MARH-19, AC-MARH-20] Trace info for observability.""" + AC-MARH-12, AC-MARH-18, AC-MARH-19, AC-MARH-20, AC-SCENE-SLOT-02] Trace info for observability.""" mode: ExecutionMode = Field(..., description="Execution mode") intent: str | None = Field(default=None, description="Matched intent") request_id: str | None = Field( @@ -156,6 +159,17 @@ class TraceInfo(BaseModel): high_risk_policy_set: list[HighRiskScenario] | None = Field(default=None, description="Active high-risk policy set") tools_used: list[str] | None = Field(default=None, description="Tools used in this request") tool_calls: list[ToolCallTrace] | None = Field(default=None, description="Tool call traces") + duration_ms: int = Field(default=0, ge=0, description="Execution duration in milliseconds") + created_at: str | None = Field(default=None, description="Creation timestamp") + # [AC-SCENE-SLOT-02] 场景槽位追踪字段 + scene: str | None = Field(default=None, description="当前场景标识") + scene_slot_context: dict[str, Any] | None = Field(default=None, description="场景槽位上下文信息") + missing_slots: list[str] | None = Field(default=None, description="缺失的必填槽位列表") + ask_back_triggered: bool | None = Field(default=False, description="是否触发了追问") + slot_sources: dict[str, str] | None = Field(default=None, description="槽位值来源映射") + kb_filter_sources: dict[str, str] | None = Field(default=None, description="KB 过滤条件来源映射") + # [Step-KB-Binding] 步骤知识库绑定追踪 + step_kb_binding: dict[str, Any] | None = Field(default=None, description="步骤知识库绑定信息,包含 step_id, allowed_kb_ids, used_kb_ids 等") class DialogueResponse(BaseModel): diff --git a/ai-service/app/services/mid/__init__.py b/ai-service/app/services/mid/__init__.py index 41391aa..87976d6 100644 --- a/ai-service/app/services/mid/__init__.py +++ b/ai-service/app/services/mid/__init__.py @@ -17,6 +17,18 @@ from .memory_adapter import MemoryAdapter, UserMemory from .default_kb_tool_runner import DefaultKbToolRunner, KbToolResult, KbToolConfig, get_default_kb_tool_runner from .segment_humanizer import SegmentHumanizer, HumanizeConfig, LengthBucket, get_segment_humanizer from .runtime_observer import RuntimeObserver, RuntimeContext, get_runtime_observer +from .slot_validation_service import ( + SlotValidationService, + ValidationResult, + SlotValidationError, + BatchValidationResult, + SlotValidationErrorCode, +) +from .slot_manager import ( + SlotManager, + SlotWriteResult, + create_slot_manager, +) __all__ = [ "PolicyRouter", @@ -54,4 +66,12 @@ __all__ = [ "RuntimeObserver", "RuntimeContext", "get_runtime_observer", + "SlotValidationService", + "ValidationResult", + "SlotValidationError", + "BatchValidationResult", + "SlotValidationErrorCode", + "SlotManager", + "SlotWriteResult", + "create_slot_manager", ] diff --git a/ai-service/app/services/mid/kb_search_dynamic_tool.py b/ai-service/app/services/mid/kb_search_dynamic_tool.py index ec8542b..5acc588 100644 --- a/ai-service/app/services/mid/kb_search_dynamic_tool.py +++ b/ai-service/app/services/mid/kb_search_dynamic_tool.py @@ -17,15 +17,20 @@ import logging import time import uuid from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from sqlalchemy.ext.asyncio import AsyncSession from app.models.mid.schemas import ToolCallStatus, ToolCallTrace, ToolType from app.services.mid.metadata_filter_builder import ( FilterBuildResult, + FilterFieldInfo, MetadataFilterBuilder, ) +from app.services.mid.slot_state_aggregator import ( + SlotState, + SlotStateAggregator, +) from app.services.mid.timeout_governor import TimeoutGovernor if TYPE_CHECKING: @@ -37,6 +42,9 @@ DEFAULT_TOP_K = 5 DEFAULT_TIMEOUT_MS = 2000 KB_SEARCH_DYNAMIC_TOOL_NAME = "kb_search_dynamic" +_TOOL_SCHEMA_CACHE: dict[str, tuple[float, dict[str, Any]]] = {} +_TOOL_SCHEMA_CACHE_TTL_SECONDS = 300 # 5 minutes + @dataclass class KbSearchDynamicResult: @@ -46,9 +54,21 @@ class KbSearchDynamicResult: applied_filter: dict[str, Any] = field(default_factory=dict) missing_required_slots: list[dict[str, str]] = field(default_factory=list) filter_debug: dict[str, Any] = field(default_factory=dict) + filter_sources: dict[str, str] = field(default_factory=dict) # [AC-SCENE-SLOT-02] 过滤条件来源 fallback_reason_code: str | None = None duration_ms: int = 0 tool_trace: ToolCallTrace | None = None + step_kb_binding: dict[str, Any] | None = None # [Step-KB-Binding] 步骤知识库绑定信息 + + +@dataclass +class StepKbConfig: + """[Step-KB-Binding] 步骤级别的知识库配置。""" + allowed_kb_ids: list[str] | None = None + preferred_kb_ids: list[str] | None = None + kb_query_hint: str | None = None + max_kb_calls: int = 1 + step_id: str | None = None @dataclass @@ -86,12 +106,14 @@ class KbSearchDynamicTool: session: AsyncSession, timeout_governor: TimeoutGovernor | None = None, config: KbSearchDynamicConfig | None = None, + slot_state_aggregator: SlotStateAggregator | None = None, ): self._session = session self._timeout_governor = timeout_governor or TimeoutGovernor() self._config = config or KbSearchDynamicConfig() self._vector_retriever = None self._filter_builder: MetadataFilterBuilder | None = None + self._slot_state_aggregator = slot_state_aggregator @property def name(self) -> str: @@ -140,6 +162,128 @@ class KbSearchDynamicTool: }, } + async def get_dynamic_tool_schema(self, tenant_id: str) -> dict[str, Any]: + """ + 获取动态生成的工具 Schema,包含租户的元数据过滤字段。 + + 使用缓存机制,避免每次都查询数据库。 + 只显示关联知识库的元数据过滤字段(field_roles 包含 resource_filter)。 + + Args: + tenant_id: 租户 ID + + Returns: + 动态生成的工具 Schema + """ + import time + current_time = time.time() + + cache_key = f"tool_schema:{tenant_id}" + if cache_key in _TOOL_SCHEMA_CACHE: + cached_time, cached_schema = _TOOL_SCHEMA_CACHE[cache_key] + if current_time - cached_time < _TOOL_SCHEMA_CACHE_TTL_SECONDS: + logger.debug(f"[AC-MARH-05] Tool schema cache hit for tenant={tenant_id}") + return cached_schema + + logger.info(f"[AC-MARH-05] Building dynamic tool schema for tenant={tenant_id}") + + base_properties = { + "query": { + "type": "string", + "description": "检索查询文本", + }, + "top_k": { + "type": "integer", + "description": "返回结果数量", + "default": DEFAULT_TOP_K, + }, + } + + required_fields = ["query"] + context_properties = {} + + try: + if self._filter_builder is None: + self._filter_builder = MetadataFilterBuilder(self._session) + + filterable_fields = await self._filter_builder._get_filterable_fields(tenant_id) + + for field_info in filterable_fields: + field_schema = self._build_field_schema(field_info) + context_properties[field_info.field_key] = field_schema + + if field_info.required: + required_fields.append(field_info.field_key) + + logger.info( + f"[AC-MARH-05] Dynamic schema built: tenant={tenant_id}, " + f"context_fields={len(context_properties)}, required={required_fields}" + ) + except Exception as e: + logger.warning(f"[AC-MARH-05] Failed to get filterable fields: {e}, using base schema") + + if context_properties: + base_properties["context"] = { + "type": "object", + "description": "过滤条件,根据用户意图选择合适的字段传递", + "properties": context_properties, + } + else: + base_properties["context"] = { + "type": "object", + "description": "过滤条件(当前租户未配置元数据字段)", + } + + schema = { + "name": self.name, + "description": self.description, + "parameters": { + "type": "object", + "properties": base_properties, + "required": required_fields, + }, + } + + _TOOL_SCHEMA_CACHE[cache_key] = (current_time, schema) + + return schema + + def _build_field_schema(self, field_info: "FilterFieldInfo") -> dict[str, Any]: + """ + 根据字段信息构建 JSON Schema。 + + Args: + field_info: 字段信息 + + Returns: + 字段的 JSON Schema + """ + schema: dict[str, Any] = { + "description": field_info.label or field_info.field_key, + } + + if field_info.required: + schema["description"] += "(必填)" + + field_type = field_info.field_type.lower() if field_info.field_type else "string" + + if field_type in ("enum", "select", "array_enum", "multi_select"): + schema["type"] = "string" + if field_info.options: + schema["enum"] = field_info.options + schema["description"] += f",可选值:{', '.join(field_info.options)}" + elif field_type in ("number", "integer", "float"): + schema["type"] = "number" + elif field_type == "boolean": + schema["type"] = "boolean" + else: + schema["type"] = "string" + + if field_info.default_value is not None: + schema["default"] = field_info.default_value + + return schema + async def execute( self, query: str, @@ -147,16 +291,24 @@ class KbSearchDynamicTool: scene: str = "open_consult", top_k: int | None = None, context: dict[str, Any] | None = None, + slot_state: SlotState | None = None, + step_kb_config: StepKbConfig | None = None, + slot_policy: Literal["flow_strict", "agent_relaxed"] = "flow_strict", ) -> KbSearchDynamicResult: """ [AC-MARH-05] 执行 KB 动态检索。 + [AC-MRS-SLOT-META-02] 支持槽位状态聚合和过滤构建优先级 + [Step-KB-Binding] 支持步骤级别的知识库约束 Args: query: 检索查询 tenant_id: 租户 ID - scene: 场景标识 + scene: 场景标识(默认值,会被 context 中的 scene 覆盖) top_k: 返回数量 - context: 上下文(包含动态过滤值) + context: 上下文(包含动态过滤值,包括 scene) + slot_state: 预聚合的槽位状态(可选,优先使用) + step_kb_config: 步骤级别的知识库配置(可选) + slot_policy: 槽位策略(flow_strict=流程严格模式,agent_relaxed=通用问答宽松模式) Returns: KbSearchDynamicResult 包含检索结果和追踪信息 @@ -171,82 +323,150 @@ class KbSearchDynamicTool: start_time = time.time() top_k = top_k or self._config.top_k - logger.info( - f"[AC-MARH-05] Starting KB dynamic search: tenant={tenant_id}, " - f"query={query[:50]}..., scene={scene}, top_k={top_k}" - ) + effective_context = dict(context) if context else {} + effective_scene = effective_context.get("scene", scene) - filter_result: FilterBuildResult | None = None - - try: - if self._filter_builder is None: - self._filter_builder = MetadataFilterBuilder(self._session) - - filter_result = await self._filter_builder.build_filter( - tenant_id=tenant_id, - context=context, + # [Step-KB-Binding] 记录步骤知识库约束 + step_kb_binding_info: dict[str, Any] = {} + if step_kb_config: + step_kb_binding_info = { + "step_id": step_kb_config.step_id, + "allowed_kb_ids": step_kb_config.allowed_kb_ids, + "preferred_kb_ids": step_kb_config.preferred_kb_ids, + "kb_query_hint": step_kb_config.kb_query_hint, + "max_kb_calls": step_kb_config.max_kb_calls, + } + logger.info( + f"[Step-KB-Binding] Step KB config applied: " + f"step_id={step_kb_config.step_id}, " + f"allowed_kb_ids={step_kb_config.allowed_kb_ids}, " + f"preferred_kb_ids={step_kb_config.preferred_kb_ids}" ) - if filter_result.missing_required_slots: - logger.warning( - f"[AC-MARH-05] Missing required slots: " - f"{filter_result.missing_required_slots}" + logger.info( + f"[AC-MARH-05] 开始执行KB动态检索: tenant={tenant_id}, " + f"query={query[:50]}..., scene={effective_scene}, top_k={top_k}, " + f"slot_policy={slot_policy}, context_keys={list(effective_context.keys())}" + ) + + # [AC-MRS-SLOT-META-02] 如果提供了 slot_state,优先使用 + if slot_state is not None: + logger.info( + f"[AC-MRS-SLOT-META-02] Using provided slot_state: " + f"filled={len(slot_state.filled_slots)}, " + f"missing={len(slot_state.missing_required_slots)}" + ) + + # 检查是否有缺失的必填槽位(仅在流程严格模式下阻断) + if slot_state.missing_required_slots: + if slot_policy == "flow_strict": + duration_ms = int((time.time() - start_time) * 1000) + + logger.info( + f"[AC-MRS-SLOT-META-03] 流程严格模式命中缺失必填槽位,触发追问: " + f"tenant={tenant_id}, missing={len(slot_state.missing_required_slots)}" + ) + + tool_trace = ToolCallTrace( + tool_name=self.name, + tool_type=ToolType.INTERNAL, + duration_ms=duration_ms, + status=ToolCallStatus.ERROR, + error_code="MISSING_REQUIRED_SLOTS", + args_digest=f"query={query[:50]}, scene={effective_scene}", + result_digest=f"missing={len(slot_state.missing_required_slots)}", + arguments={"query": query, "scene": effective_scene, "context": context}, + result={"missing_required_slots": slot_state.missing_required_slots}, + ) + + return KbSearchDynamicResult( + success=False, + applied_filter={}, + missing_required_slots=slot_state.missing_required_slots, + filter_debug={ + "source": "slot_state", + "filled_slots": slot_state.filled_slots, + "slot_to_field_map": slot_state.slot_to_field_map, + }, + fallback_reason_code="MISSING_REQUIRED_SLOTS", + duration_ms=duration_ms, + tool_trace=tool_trace, + ) + + logger.info( + f"[AC-MRS-SLOT-META-03] 通用问答宽松模式检测到缺失槽位但不阻断检索: " + f"tenant={tenant_id}, missing={len(slot_state.missing_required_slots)}" ) - duration_ms = int((time.time() - start_time) * 1000) - - tool_trace = ToolCallTrace( - tool_name=self.name, - tool_type=ToolType.INTERNAL, - duration_ms=duration_ms, - status=ToolCallStatus.ERROR, - error_code="MISSING_REQUIRED_SLOTS", - args_digest=f"query={query[:50]}, scene={scene}", - result_digest=f"missing={len(filter_result.missing_required_slots)}", - ) - - return KbSearchDynamicResult( - success=False, - applied_filter=filter_result.applied_filter, - missing_required_slots=filter_result.missing_required_slots, - filter_debug=filter_result.debug_info, - fallback_reason_code="MISSING_REQUIRED_SLOTS", - duration_ms=duration_ms, - tool_trace=tool_trace, - ) - - metadata_filter = filter_result.applied_filter if filter_result.success else None + + # 使用 slot_state 构建 filter + metadata_filter, filter_sources = await self._build_filter_from_slot_state( + tenant_id=tenant_id, + slot_state=slot_state, + context=effective_context, + scene_slot_context=effective_context.get("scene_slot_context"), # [AC-SCENE-SLOT-02] + ) + else: + # 原有逻辑:构建元数据 filter + # 如果 context 简单(只有键值对),直接构造 filter,跳过 MetadataFilterBuilder + metadata_filter = await self._build_filter_legacy( + tenant_id=tenant_id, + context=effective_context, + query=query, + effective_scene=effective_scene, + start_time=start_time, + ) + + if isinstance(metadata_filter, KbSearchDynamicResult): + # 有错误,直接返回 + return metadata_filter + try: hits = await self._retrieve_with_timeout( tenant_id=tenant_id, query=query, metadata_filter=metadata_filter, top_k=top_k, + step_kb_config=step_kb_config, ) duration_ms = int((time.time() - start_time) * 1000) kb_hit = len(hits) > 0 + # [Step-KB-Binding] 记录命中的知识库 + hit_kb_ids = list(set(hit.get("kb_id") for hit in hits if hit.get("kb_id"))) + if step_kb_binding_info: + step_kb_binding_info["used_kb_ids"] = hit_kb_ids + step_kb_binding_info["kb_hit"] = kb_hit + tool_trace = ToolCallTrace( tool_name=self.name, tool_type=ToolType.INTERNAL, duration_ms=duration_ms, status=ToolCallStatus.OK, - args_digest=f"query={query[:50]}, scene={scene}", + args_digest=f"query={query[:50]}, scene={effective_scene}", result_digest=f"hits={len(hits)}", + arguments={"query": query, "scene": effective_scene, "context": context}, + result={"hits_count": len(hits), "kb_hit": kb_hit}, ) logger.info( f"[AC-MARH-05] KB dynamic search completed: tenant={tenant_id}, " - f"hits={len(hits)}, duration_ms={duration_ms}, kb_hit={kb_hit}" + f"hits={len(hits)}, duration_ms={duration_ms}, kb_hit={kb_hit}, " + f"hit_kb_ids={hit_kb_ids}" ) + # 确定 filter 来源用于调试 + filter_source = "slot_state" if slot_state is not None else "builder" + kb_filter_sources = filter_sources if slot_state is not None else {} + return KbSearchDynamicResult( success=True, hits=hits, - applied_filter=filter_result.applied_filter if filter_result else {}, - filter_debug=filter_result.debug_info if filter_result else {}, + applied_filter=metadata_filter or {}, + filter_debug={"source": filter_source, "filter_sources": kb_filter_sources}, duration_ms=duration_ms, tool_trace=tool_trace, + step_kb_binding=step_kb_binding_info if step_kb_binding_info else None, ) except asyncio.TimeoutError: @@ -262,16 +482,18 @@ class KbSearchDynamicTool: duration_ms=duration_ms, status=ToolCallStatus.TIMEOUT, error_code="KB_TIMEOUT", + arguments={"query": query, "scene": effective_scene, "context": context}, ) return KbSearchDynamicResult( success=False, - applied_filter=filter_result.applied_filter if filter_result else {}, - missing_required_slots=filter_result.missing_required_slots if filter_result else [], - filter_debug=filter_result.debug_info if filter_result else {}, + applied_filter=metadata_filter or {}, + missing_required_slots=[], + filter_debug={"error": "timeout"}, fallback_reason_code="KB_TIMEOUT", duration_ms=duration_ms, tool_trace=tool_trace, + step_kb_binding=step_kb_binding_info if step_kb_binding_info else None, ) except Exception as e: @@ -287,31 +509,224 @@ class KbSearchDynamicTool: duration_ms=duration_ms, status=ToolCallStatus.ERROR, error_code="KB_ERROR", + arguments={"query": query, "scene": effective_scene, "context": context}, ) return KbSearchDynamicResult( success=False, - applied_filter=filter_result.applied_filter if filter_result else {}, - missing_required_slots=filter_result.missing_required_slots if filter_result else [], + applied_filter=metadata_filter or {}, + missing_required_slots=[], filter_debug={"error": str(e)}, fallback_reason_code="KB_ERROR", duration_ms=duration_ms, tool_trace=tool_trace, ) + async def _build_filter_legacy( + self, + tenant_id: str, + context: dict[str, Any], + query: str, + effective_scene: str, + start_time: float, + ) -> dict[str, Any] | KbSearchDynamicResult: + """ + [AC-MRS-SLOT-META-02] 原有逻辑:构建元数据 filter + + Returns: + dict: 构建成功的 filter + KbSearchDynamicResult: 构建失败时的错误结果 + """ + metadata_filter: dict[str, Any] | None = None + + # 简单 context:直接构造 filter(信任 AI 传入的值) + # 复杂场景:使用 MetadataFilterBuilder 进行严格验证 + is_simple_context = all( + isinstance(v, (str, int, float, bool)) + for v in context.values() + ) + + if is_simple_context: + # 直接构造 filter,不查询数据库 + metadata_filter = context + logger.info( + f"[AC-MARH-05] Using simple context as filter directly: " + f"{metadata_filter}" + ) + else: + # 复杂 context,使用 MetadataFilterBuilder + filter_result = await self._build_filter_with_builder( + tenant_id=tenant_id, + context=context, + query=query, + effective_scene=effective_scene, + start_time=start_time, + ) + if isinstance(filter_result, KbSearchDynamicResult): + # 有错误,直接返回 + return filter_result + metadata_filter = filter_result + + return metadata_filter or {} + + async def _build_filter_from_slot_state( + self, + tenant_id: str, + slot_state: SlotState, + context: dict[str, Any], + scene_slot_context: Any = None, # [AC-SCENE-SLOT-02] 场景槽位上下文 + ) -> tuple[dict[str, Any], dict[str, str]]: + """ + [AC-MRS-SLOT-META-02] 基于槽位状态构建过滤条件 + [AC-SCENE-SLOT-02] 支持场景槽位包配置的优先级 + + 过滤值来源优先级: + 1. 已确认槽位值(slot_state.filled_slots) + 2. 当前请求 context 显式值 + 3. 元数据默认值 + + Args: + tenant_id: 租户 ID + slot_state: 槽位状态 + context: 上下文 + scene_slot_context: 场景槽位上下文 + + Returns: + (过滤条件字典, 过滤来源映射) + """ + if self._filter_builder is None: + self._filter_builder = MetadataFilterBuilder(self._session) + + # 获取可过滤字段定义 + filterable_fields = await self._filter_builder._get_filterable_fields(tenant_id) + + applied_filter: dict[str, Any] = {} + filter_debug_sources: dict[str, str] = {} + + # [AC-SCENE-SLOT-02] 如果有场景槽位上下文,优先处理场景定义的槽位 + scene_slot_keys = set() + if scene_slot_context: + scene_slot_keys = set(scene_slot_context.get_all_slot_keys()) + logger.debug( + f"[AC-SCENE-SLOT-02] Processing scene slots: " + f"scene={scene_slot_context.scene_key}, slots={scene_slot_keys}" + ) + + for field_info in filterable_fields: + field_key = field_info.field_key + value = None + source = None + + # 优先级 1: 已确认槽位值(通过 slot_to_field_map 映射) + if slot_state.slot_to_field_map: + # 查找哪个 slot 映射到这个 field + for slot_key, mapped_field_key in slot_state.slot_to_field_map.items(): + if mapped_field_key == field_key and slot_key in slot_state.filled_slots: + value = slot_state.filled_slots[slot_key] + source = "slot" + break + + # 如果 slot 映射没有命中,直接检查 slot_key 是否等于 field_key + if value is None and field_key in slot_state.filled_slots: + value = slot_state.filled_slots[field_key] + source = "slot" + + # 优先级 2: 当前请求 context 显式值 + if value is None and field_key in context: + value = context[field_key] + source = "context" + + # 优先级 3: 元数据默认值 + if value is None and field_info.default_value is not None: + value = field_info.default_value + source = "default" + + # 构建过滤条件 + if value is not None: + filter_value = self._filter_builder._build_field_filter( + field_info, value + ) + if filter_value is not None: + applied_filter[field_key] = filter_value + filter_debug_sources[field_key] = source + + logger.info( + f"[AC-MRS-SLOT-META-02] Filter built from slot_state: " + f"fields={len(applied_filter)}, sources={filter_debug_sources}" + ) + + return applied_filter, filter_debug_sources + + async def _build_filter_with_builder( + self, + tenant_id: str, + context: dict[str, Any], + query: str, + effective_scene: str, + start_time: float, + ) -> dict[str, Any] | KbSearchDynamicResult: + """ + 使用 MetadataFilterBuilder 构建 filter(复杂场景)。 + + Returns: + dict: 构建成功的 filter + KbSearchDynamicResult: 构建失败时的错误结果 + """ + from app.services.mid.metadata_filter_builder import FilterBuildResult + + if self._filter_builder is None: + self._filter_builder = MetadataFilterBuilder(self._session) + + filter_result = await self._filter_builder.build_filter( + tenant_id=tenant_id, + context=context, + ) + + if filter_result.missing_required_slots: + logger.warning( + f"[AC-MARH-05] Missing required slots: " + f"{filter_result.missing_required_slots}" + ) + duration_ms = int((time.time() - start_time) * 1000) + + tool_trace = ToolCallTrace( + tool_name=self.name, + tool_type=ToolType.INTERNAL, + duration_ms=duration_ms, + status=ToolCallStatus.ERROR, + error_code="MISSING_REQUIRED_SLOTS", + args_digest=f"query={query[:50]}, scene={effective_scene}", + result_digest=f"missing={len(filter_result.missing_required_slots)}", + arguments={"query": query, "scene": effective_scene, "context": context}, + result={"missing_required_slots": filter_result.missing_required_slots}, + ) + + return KbSearchDynamicResult( + success=False, + applied_filter=filter_result.applied_filter, + missing_required_slots=filter_result.missing_required_slots, + filter_debug=filter_result.debug_info, + fallback_reason_code="MISSING_REQUIRED_SLOTS", + duration_ms=duration_ms, + tool_trace=tool_trace, + ) + + return filter_result.applied_filter if filter_result.success else {} + async def _retrieve_with_timeout( self, tenant_id: str, query: str, metadata_filter: dict[str, Any] | None = None, top_k: int = DEFAULT_TOP_K, + step_kb_config: StepKbConfig | None = None, ) -> list[dict[str, Any]]: """带超时控制的检索。""" timeout_seconds = self._config.timeout_ms / 1000.0 try: return await asyncio.wait_for( - self._do_retrieve(tenant_id, query, metadata_filter, top_k), + self._do_retrieve(tenant_id, query, metadata_filter, top_k, step_kb_config), timeout=timeout_seconds, ) except asyncio.TimeoutError: @@ -323,18 +738,36 @@ class KbSearchDynamicTool: query: str, metadata_filter: dict[str, Any] | None = None, top_k: int = DEFAULT_TOP_K, + step_kb_config: StepKbConfig | None = None, ) -> list[dict[str, Any]]: - """执行实际检索。""" + """执行实际检索。[Step-KB-Binding] 支持步骤级别的知识库约束。""" if self._vector_retriever is None: from app.services.retrieval.vector_retriever import get_vector_retriever self._vector_retriever = await get_vector_retriever() from app.services.retrieval.base import RetrievalContext + # [Step-KB-Binding] 确定要检索的知识库范围 + kb_ids = None + if step_kb_config: + # 如果配置了 allowed_kb_ids,则只检索这些知识库 + if step_kb_config.allowed_kb_ids: + kb_ids = step_kb_config.allowed_kb_ids + logger.info( + f"[Step-KB-Binding] Restricting KB search to: {kb_ids}" + ) + # 如果只配置了 preferred_kb_ids,优先检索这些知识库 + elif step_kb_config.preferred_kb_ids: + kb_ids = step_kb_config.preferred_kb_ids + logger.info( + f"[Step-KB-Binding] Preferring KB search in: {kb_ids}" + ) + ctx = RetrievalContext( tenant_id=tenant_id, query=query, - metadata=metadata_filter, + metadata_filter=metadata_filter, + kb_ids=kb_ids, ) result = await self._vector_retriever.retrieve(ctx) @@ -347,6 +780,7 @@ class KbSearchDynamicTool: "content": hit.text, "score": hit.score, "metadata": hit.metadata, + "kb_id": hit.metadata.get("kb_id"), }) return hits[:top_k] @@ -384,6 +818,7 @@ async def create_kb_search_dynamic_handler( scene: str = "open_consult", top_k: int = DEFAULT_TOP_K, context: dict[str, Any] | None = None, + **kwargs, # 接受系统注入的额外参数(user_id, session_id 等) ) -> dict[str, Any]: """ KB 动态检索 handler。 @@ -391,9 +826,9 @@ async def create_kb_search_dynamic_handler( Args: query: 检索查询 tenant_id: 租户 ID - scene: 场景标识 + scene: 场景标识(默认值,会被 context 中的 scene 覆盖) top_k: 返回数量 - context: 上下文 + context: 上下文(包含 scene 等过滤字段) Returns: 检索结果字典 @@ -442,6 +877,7 @@ def register_kb_search_dynamic_tool( scene: str = "open_consult", top_k: int = DEFAULT_TOP_K, context: dict[str, Any] | None = None, + **kwargs, # 接受系统注入的额外参数(user_id, session_id 等) ) -> dict[str, Any]: tool = KbSearchDynamicTool( session=session, diff --git a/ai-service/app/services/mid/memory_recall_tool.py b/ai-service/app/services/mid/memory_recall_tool.py index a932c64..97f8029 100644 --- a/ai-service/app/services/mid/memory_recall_tool.py +++ b/ai-service/app/services/mid/memory_recall_tool.py @@ -542,9 +542,9 @@ def register_memory_recall_tool( cfg = config or MemoryRecallConfig() async def memory_recall_handler( - tenant_id: str, - user_id: str, - session_id: str, + tenant_id: str = "", + user_id: str = "", + session_id: str = "", recall_scope: list[str] | None = None, max_recent_messages: int | None = None, ) -> dict[str, Any]: @@ -554,6 +554,7 @@ def register_memory_recall_tool( timeout_governor=timeout_governor, config=cfg, ) + result = await tool.execute( tenant_id=tenant_id, user_id=user_id, @@ -587,12 +588,9 @@ def register_memory_recall_tool( "recall_scope": {"type": "array", "description": "召回范围,例如 profile/facts/preferences/summary/slots"}, "max_recent_messages": {"type": "integer", "description": "历史回填窗口大小"} }, - "required": ["tenant_id", "user_id", "session_id"] + "required": [] }, "example_action_input": { - "tenant_id": "default", - "user_id": "u_10086", - "session_id": "s_abc_001", "recall_scope": ["profile", "facts", "preferences", "summary", "slots"], "max_recent_messages": 8 }, diff --git a/ai-service/app/services/mid/metadata_filter_builder.py b/ai-service/app/services/mid/metadata_filter_builder.py index 1e86678..ffa0848 100644 --- a/ai-service/app/services/mid/metadata_filter_builder.py +++ b/ai-service/app/services/mid/metadata_filter_builder.py @@ -165,21 +165,56 @@ class MetadataFilterBuilder: ) -> list[FilterFieldInfo]: """ [AC-MRS-11] 获取可过滤的字段定义。 + 优先从 Redis 缓存获取,未缓存则从数据库查询并缓存。 条件: - 状态=生效 (active) - field_roles 包含 resource_filter """ + import time + start_time = time.time() + + # 1. 尝试从缓存获取 + from app.services.metadata_cache_service import get_metadata_cache_service + cache_service = await get_metadata_cache_service() + cached_fields = await cache_service.get_fields(tenant_id) + + if cached_fields is not None: + # 缓存命中,直接返回 + logger.info( + f"[AC-MRS-11] Cache hit: Retrieved {len(cached_fields)} fields " + f"for tenant={tenant_id} in {(time.time() - start_time)*1000:.2f}ms" + ) + return [ + FilterFieldInfo( + field_key=f["field_key"], + label=f["label"], + field_type=f["field_type"], + required=f["required"], + options=f.get("options"), + default_value=f.get("default_value"), + is_filterable=f["is_filterable"], + ) + for f in cached_fields + ] + + # 2. 缓存未命中,从数据库查询 + logger.info(f"[AC-MRS-11] Cache miss: Querying database for tenant={tenant_id}") + db_start = time.time() + fields = await self._role_provider.get_fields_by_role( tenant_id=tenant_id, role=FieldRole.RESOURCE_FILTER.value, ) - + + db_time = (time.time() - db_start) * 1000 logger.info( - f"[AC-MRS-11] Retrieved {len(fields)} resource_filter fields for tenant={tenant_id}" + f"[AC-MRS-11] Retrieved {len(fields)} resource_filter fields " + f"for tenant={tenant_id} from DB in {db_time:.2f}ms" ) - - return [ + + # 3. 转换为 FilterFieldInfo 列表 + filter_fields = [ FilterFieldInfo( field_key=f.field_key, label=f.label, @@ -191,6 +226,29 @@ class MetadataFilterBuilder: ) for f in fields ] + + # 4. 缓存到 Redis + cache_data = [ + { + "field_key": f.field_key, + "label": f.label, + "field_type": f.field_type, + "required": f.required, + "options": f.options, + "default_value": f.default_value, + "is_filterable": f.is_filterable, + } + for f in filter_fields + ] + await cache_service.set_fields(tenant_id, cache_data) + + total_time = (time.time() - start_time) * 1000 + logger.info( + f"[AC-MRS-11] Total time for tenant={tenant_id}: {total_time:.2f}ms " + f"(DB: {db_time:.2f}ms)" + ) + + return filter_fields def _build_field_filter( self, diff --git a/ai-service/app/services/mid/role_based_field_provider.py b/ai-service/app/services/mid/role_based_field_provider.py index 858f7e7..e7c22d7 100644 --- a/ai-service/app/services/mid/role_based_field_provider.py +++ b/ai-service/app/services/mid/role_based_field_provider.py @@ -175,7 +175,9 @@ class RoleBasedFieldProvider: "slot_key": slot.slot_key, "type": slot.type, "required": slot.required, + # [AC-MRS-07-UPGRADE] 返回新旧字段 "extract_strategy": slot.extract_strategy, + "extract_strategies": slot.extract_strategies, "validation_rule": slot.validation_rule, "ask_back_prompt": slot.ask_back_prompt, "default_value": slot.default_value, @@ -217,7 +219,9 @@ class RoleBasedFieldProvider: "slot_key": field.field_key, "type": field.type, "required": field.required, + # [AC-MRS-07-UPGRADE] 返回新旧字段 "extract_strategy": None, + "extract_strategies": None, "validation_rule": None, "ask_back_prompt": None, "default_value": field.default_value,