feat: implement mid-platform dialogue and session management with memory recall and KB search tools [AC-IDMP-01~20]
This commit is contained in:
parent
d78b72ca93
commit
248a225436
|
|
@ -55,6 +55,18 @@ from app.services.mid.segment_humanizer import HumanizeConfig, SegmentHumanizer
|
|||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
from app.services.mid.tool_registry import ToolRegistry
|
||||
from app.services.mid.trace_logger import TraceLogger
|
||||
from app.services.prompt.template_service import PromptTemplateService
|
||||
from app.services.prompt.variable_resolver import VariableResolver
|
||||
from app.services.intent.clarification import (
|
||||
ClarificationEngine,
|
||||
ClarifyReason,
|
||||
ClarifySessionManager,
|
||||
ClarifyState,
|
||||
HybridIntentResult,
|
||||
IntentCandidate,
|
||||
T_HIGH,
|
||||
T_LOW,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -118,6 +130,7 @@ _kb_search_dynamic_registered: bool = False
|
|||
_intent_hint_registered: bool = False
|
||||
_high_risk_check_registered: bool = False
|
||||
_memory_recall_registered: bool = False
|
||||
_metadata_discovery_registered: bool = False
|
||||
|
||||
|
||||
def ensure_kb_search_dynamic_registered(
|
||||
|
|
@ -134,7 +147,7 @@ def ensure_kb_search_dynamic_registered(
|
|||
config = KbSearchDynamicConfig(
|
||||
enabled=True,
|
||||
top_k=5,
|
||||
timeout_ms=2000,
|
||||
timeout_ms=10000,
|
||||
min_score_threshold=0.5,
|
||||
)
|
||||
|
||||
|
|
@ -222,6 +235,26 @@ def ensure_memory_recall_registered(
|
|||
logger.info("[AC-IDMP-13] memory_recall tool registered to registry")
|
||||
|
||||
|
||||
def ensure_metadata_discovery_registered(
    registry: ToolRegistry,
    session: AsyncSession,
) -> None:
    """Register the metadata_discovery tool on ``registry`` at most once.

    A module-level flag records that registration already happened, so every
    later call returns without touching the registry.

    Args:
        registry: Tool registry that receives the tool.
        session: Database session handed to the tool's registration helper.

    NOTE(review): the flag is global, so only the ``session`` of the first
    call is ever passed to the tool. Whether the tool keeps using that
    session afterwards is not visible from here; confirm it is safe.
    """
    global _metadata_discovery_registered

    if not _metadata_discovery_registered:
        # Imported inside the function, exactly where the original imports it.
        from app.services.mid.metadata_discovery_tool import (
            register_metadata_discovery_tool,
        )

        governor = get_timeout_governor()
        register_metadata_discovery_tool(
            registry=registry,
            session=session,
            timeout_governor=governor,
        )
        _metadata_discovery_registered = True
        logger.info("[MetadataDiscovery] metadata_discovery tool registered to registry")
|
||||
|
||||
|
||||
def get_output_guardrail_executor() -> OutputGuardrailExecutor:
|
||||
"""Get or create OutputGuardrailExecutor instance."""
|
||||
if "output_guardrail_executor" not in _mid_services:
|
||||
|
|
@ -259,6 +292,16 @@ def get_runtime_observer() -> RuntimeObserver:
|
|||
return _mid_services["runtime_observer"]
|
||||
|
||||
|
||||
def get_clarification_engine() -> ClarificationEngine:
    """Return the shared ClarificationEngine, building it on first use.

    The instance is cached in ``_mid_services`` under the key
    ``"clarification_engine"`` and is constructed with the module-level
    ``T_HIGH`` / ``T_LOW`` thresholds.
    """
    cache_key = "clarification_engine"
    if cache_key in _mid_services:
        return _mid_services[cache_key]

    engine = ClarificationEngine(t_high=T_HIGH, t_low=T_LOW)
    _mid_services[cache_key] = engine
    return engine
|
||||
|
||||
|
||||
@router.post(
|
||||
"/dialogue/respond",
|
||||
operation_id="respondDialogue",
|
||||
|
|
@ -288,6 +331,7 @@ async def respond_dialogue(
|
|||
default_kb_tool_runner: DefaultKbToolRunner = Depends(get_default_kb_tool_runner),
|
||||
segment_humanizer: SegmentHumanizer = Depends(get_segment_humanizer),
|
||||
runtime_observer: RuntimeObserver = Depends(get_runtime_observer),
|
||||
clarification_engine: ClarificationEngine = Depends(get_clarification_engine),
|
||||
) -> DialogueResponse:
|
||||
"""
|
||||
[AC-MARH-01~12] Generate dialogue response with segments and trace.
|
||||
|
|
@ -316,7 +360,8 @@ async def respond_dialogue(
|
|||
|
||||
logger.info(
|
||||
f"[AC-MARH-01] Dialogue request: tenant={tenant_id}, "
|
||||
f"session={dialogue_request.session_id}, request_id={request_id}"
|
||||
f"session={dialogue_request.session_id}, request_id={request_id}, "
|
||||
f"user_message={dialogue_request.user_message[:100] if dialogue_request.user_message else 'None'}..."
|
||||
)
|
||||
|
||||
runtime_ctx = runtime_observer.start_observation(
|
||||
|
|
@ -340,6 +385,7 @@ async def respond_dialogue(
|
|||
ensure_intent_hint_registered(tool_registry, session)
|
||||
ensure_high_risk_check_registered(tool_registry, session)
|
||||
ensure_memory_recall_registered(tool_registry, session)
|
||||
ensure_metadata_discovery_registered(tool_registry, session)
|
||||
|
||||
try:
|
||||
interrupt_ctx = interrupt_context_enricher.enrich(
|
||||
|
|
@ -541,7 +587,15 @@ async def respond_dialogue(
|
|||
|
||||
except Exception as e:
|
||||
latency_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(f"[AC-IDMP-06] Dialogue error: {e}")
|
||||
import traceback
|
||||
logger.error(
|
||||
f"[AC-IDMP-06] Dialogue error: {e}\n"
|
||||
f"Exception type: {type(e).__name__}\n"
|
||||
f"Request details: session_id={dialogue_request.session_id}, "
|
||||
f"user_message={dialogue_request.user_message[:200] if dialogue_request.user_message else 'None'}, "
|
||||
f"scene={dialogue_request.scene}, user_id={dialogue_request.user_id}\n"
|
||||
f"Traceback:\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
trace_logger.update_trace(
|
||||
request_id=request_id,
|
||||
|
|
@ -593,7 +647,7 @@ async def _match_intent(
|
|||
return IntentMatch(
|
||||
intent_id=str(result.rule.id),
|
||||
intent_name=result.rule.name,
|
||||
confidence=0.8,
|
||||
confidence=1.0,
|
||||
response_type=result.rule.response_type,
|
||||
target_kb_ids=result.rule.target_kb_ids,
|
||||
flow_id=str(result.rule.flow_id) if result.rule.flow_id else None,
|
||||
|
|
@ -608,6 +662,227 @@ async def _match_intent(
|
|||
return None
|
||||
|
||||
|
||||
async def _match_intent_hybrid(
    tenant_id: str,
    request: DialogueRequest,
    session: AsyncSession,
    clarification_engine: ClarificationEngine,
) -> HybridIntentResult:
    """
    [AC-CLARIFY] Match the user message against the tenant's intent rules.

    Tries ``IntentRouter.match_hybrid`` first. If that raises, falls back to
    ``IntentRouter.match`` and wraps a hit in a single-candidate result whose
    confidence comes from ``clarification_engine.compute_confidence``.

    Args:
        tenant_id: Tenant whose enabled rules are loaded.
        request: Dialogue request; only ``user_message`` is read.
        session: Database session used to load the rules.
        clarification_engine: Engine used to score the rule-only fallback.

    Returns:
        The hybrid result, or an empty result (no intent, confidence 0.0, no
        candidates) when there are no rules, nothing matches, or any step
        raises. This function never propagates an exception.
    """

    def _no_match() -> HybridIntentResult:
        # Fresh object (and fresh list) on every call, as in the original.
        return HybridIntentResult(intent=None, confidence=0.0, candidates=[])

    try:
        from app.services.intent.router import IntentRouter
        from app.services.intent.rule_service import IntentRuleService

        rules = await IntentRuleService(session).get_enabled_rules_for_matching(tenant_id)
        if not rules:
            return _no_match()

        router = IntentRouter()

        try:
            fusion = await router.match_hybrid(
                message=request.user_message,
                rules=rules,
                tenant_id=tenant_id,
            )
            hybrid_result = HybridIntentResult.from_fusion_result(fusion)
            matched_name = hybrid_result.intent.intent_name if hybrid_result.intent else None
            logger.info(
                f"[AC-CLARIFY] Hybrid intent match: "
                f"intent={matched_name}, "
                f"confidence={hybrid_result.confidence:.3f}, "
                f"need_clarify={hybrid_result.need_clarify}"
            )
            return hybrid_result
        except Exception as e:
            logger.warning(f"[AC-CLARIFY] Hybrid match failed, fallback to rule: {e}")

        # Reached only when the hybrid path raised: rule-only fallback.
        rule_hit = router.match(request.user_message, rules)
        if not rule_hit:
            return _no_match()

        rule_confidence = clarification_engine.compute_confidence(
            rule_score=1.0,
            semantic_score=0.0,
            llm_score=0.0,
        )
        rule = rule_hit.rule
        only_candidate = IntentCandidate(
            intent_id=str(rule.id),
            intent_name=rule.name,
            confidence=rule_confidence,
            response_type=rule.response_type,
            target_kb_ids=rule.target_kb_ids,
            flow_id=str(rule.flow_id) if rule.flow_id else None,
            fixed_reply=rule.fixed_reply,
            transfer_message=rule.transfer_message,
        )
        return HybridIntentResult(
            intent=only_candidate,
            confidence=rule_confidence,
            candidates=[only_candidate],
        )

    except Exception as e:
        logger.warning(f"[AC-CLARIFY] Intent match failed: {e}")
        return _no_match()
|
||||
|
||||
|
||||
async def _handle_clarification(
    tenant_id: str,
    request: DialogueRequest,
    hybrid_result: HybridIntentResult,
    clarification_engine: ClarificationEngine,
    trace: TraceInfo,
    session: AsyncSession,
    required_slots: list[str] | None = None,
    filled_slots: dict[str, Any] | None = None,
) -> DialogueResponse | None:
    """
    [AC-CLARIFY] Decide whether this turn must be answered with a clarifying question.

    Two paths:
      1. A clarify state already exists for the session and has not hit its
         retry limit: the message is treated as the user's answer to the
         previous question.
      2. Otherwise the engine is asked whether ``hybrid_result`` should
         trigger a new clarification; if so the new state is stored.

    Args:
        tenant_id: Not read by this function.
        request: Dialogue request; ``session_id`` and ``user_message`` are read.
        hybrid_result: Intent match result for the current message.
        clarification_engine: Engine that evaluates and phrases clarifications.
        trace: Source of ``request_id`` / ``generation_id`` for the reply trace.
        session: Not read by this function.
        required_slots: Slot keys forwarded to the trigger check.
        filled_slots: Slot values forwarded to the trigger check.

    Returns:
        A FIXED-mode DialogueResponse carrying the clarify prompt, or None
        when the dialogue may proceed normally.

    NOTE(review): the source was recovered from a view without indentation.
    The ``intent`` check is assumed to sit inside the ``not need_clarify``
    branch; confirm against the repository.
    """

    def _clarify_reply(
        prompt: str, reason_code: str, intent_name: str | None
    ) -> DialogueResponse:
        # Both paths answer with one undelayed segment and a FIXED-mode trace.
        return DialogueResponse(
            segments=[Segment(text=prompt, delay_after=0)],
            trace=TraceInfo(
                mode=ExecutionMode.FIXED,
                request_id=trace.request_id,
                generation_id=trace.generation_id,
                fallback_reason_code=reason_code,
                intent=intent_name,
            ),
        )

    session_key = request.session_id
    pending = ClarifySessionManager.get_session(session_key)

    # Path 1: the user is replying to a clarification that is still open.
    if pending and not pending.is_max_retry():
        logger.info(
            f"[AC-CLARIFY] Processing clarify response: "
            f"session={request.session_id}, retry={pending.retry_count}"
        )

        followup = clarification_engine.process_clarify_response(
            user_message=request.user_message,
            state=pending,
        )

        if not followup.need_clarify:
            ClarifySessionManager.clear_session(session_key)
            if followup.intent:
                return None

        retry_prompt = clarification_engine.generate_clarify_prompt(pending)
        top_candidate_name = pending.candidates[0].intent_name if pending.candidates else None
        return _clarify_reply(
            retry_prompt,
            f"clarify_retry_{pending.retry_count}",
            top_candidate_name,
        )

    # Path 2: no usable pending state; check whether a new one is needed.
    should_clarify, clarify_state = clarification_engine.should_trigger_clarify(
        result=hybrid_result,
        required_slots=required_slots,
        filled_slots=filled_slots,
    )
    if not (should_clarify and clarify_state):
        return None

    logger.info(
        f"[AC-CLARIFY] Clarification triggered: "
        f"session={request.session_id}, reason={clarify_state.reason}, "
        f"confidence={hybrid_result.confidence:.3f}"
    )

    ClarifySessionManager.set_session(session_key, clarify_state)
    first_prompt = clarification_engine.generate_clarify_prompt(clarify_state)
    matched_name = hybrid_result.intent.intent_name if hybrid_result.intent else None
    return _clarify_reply(
        first_prompt,
        f"clarify_{clarify_state.reason.value}",
        matched_name,
    )
|
||||
|
||||
|
||||
async def _check_hard_block(
    tenant_id: str,
    request: DialogueRequest,
    hybrid_result: HybridIntentResult,
    clarification_engine: ClarificationEngine,
    trace: TraceInfo,
    required_slots: list[str] | None = None,
    filled_slots: dict[str, Any] | None = None,
) -> DialogueResponse | None:
    """
    [AC-CLARIFY] Stop the turn with a clarify prompt when a hard block applies.

    Hard blocks, as decided by ``clarification_engine.check_hard_block``:
      1. confidence < T_high: block entering a new flow
      2. required_slots missing: block flow progression

    When blocked, a ClarifyState is stored for the session, the engine's
    clarify-trigger metric is recorded, and the clarify prompt is returned.

    Args:
        tenant_id: Not read by this function.
        request: Dialogue request; only ``session_id`` is read.
        hybrid_result: Intent match result being checked.
        clarification_engine: Engine that evaluates the block and phrases the prompt.
        trace: Source of ``request_id`` / ``generation_id`` for the reply trace.
        required_slots: Slot keys the flow needs.
        filled_slots: Slot values collected so far.

    Returns:
        A FIXED-mode DialogueResponse if blocked, None otherwise.
    """
    is_blocked, block_reason = clarification_engine.check_hard_block(
        result=hybrid_result,
        required_slots=required_slots,
        filled_slots=filled_slots,
    )
    if not is_blocked:
        return None

    logger.info(
        f"[AC-CLARIFY] Hard block triggered: "
        f"session={request.session_id}, reason={block_reason}, "
        f"confidence={hybrid_result.confidence:.3f}"
    )

    # The first required slot is recorded as the asked slot only for a
    # missing-slot block.
    if required_slots and block_reason == ClarifyReason.MISSING_SLOT:
        slot_to_ask = required_slots[0]
    else:
        slot_to_ask = None

    block_state = ClarifyState(
        reason=block_reason,
        asked_slot=slot_to_ask,
        candidates=hybrid_result.candidates,
    )
    ClarifySessionManager.set_session(request.session_id, block_state)
    # NOTE(review): reaches into the engine's private ``_metrics`` attribute.
    clarification_engine._metrics.record_clarify_trigger()

    prompt_text = clarification_engine.generate_clarify_prompt(block_state)
    matched_name = hybrid_result.intent.intent_name if hybrid_result.intent else None

    block_trace = TraceInfo(
        mode=ExecutionMode.FIXED,
        request_id=trace.request_id,
        generation_id=trace.generation_id,
        fallback_reason_code=f"hard_block_{block_reason.value}",
        intent=matched_name,
    )
    return DialogueResponse(
        segments=[Segment(text=prompt_text, delay_after=0)],
        trace=block_trace,
    )
|
||||
|
||||
|
||||
async def _handle_legacy_response(
|
||||
tenant_id: str,
|
||||
request: DialogueRequest,
|
||||
|
|
@ -793,18 +1068,32 @@ async def _execute_agent_mode(
|
|||
session: AsyncSession | None = None,
|
||||
tool_registry: ToolRegistry | None = None,
|
||||
) -> DialogueResponse:
|
||||
"""[AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-IDMP-13] Execute agent mode with ReAct loop, KB tool, and memory recall."""
|
||||
"""
|
||||
[AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-IDMP-13, AC-MRS-SLOT-META-03]
|
||||
Execute agent mode with ReAct loop, KB tool, memory recall, and slot state aggregation.
|
||||
"""
|
||||
from app.services.llm.factory import get_llm_config_manager
|
||||
from app.services.mid.slot_state_aggregator import SlotStateAggregator
|
||||
|
||||
logger.info(
|
||||
f"[DEBUG-AGENT] Starting _execute_agent_mode: tenant={tenant_id}, "
|
||||
f"session={request.session_id}, scene={request.scene}, user_id={request.user_id}, "
|
||||
f"user_message={request.user_message[:100] if request.user_message else 'None'}..."
|
||||
)
|
||||
|
||||
try:
|
||||
llm_manager = get_llm_config_manager()
|
||||
llm_client = llm_manager.get_client()
|
||||
logger.info(f"[DEBUG-AGENT] LLM client obtained successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-MARH-07] Failed to get LLM client: {e}")
|
||||
llm_client = None
|
||||
|
||||
base_context = {"history": [h.model_dump() for h in request.history]} if request.history else {}
|
||||
|
||||
if request.scene:
|
||||
base_context["scene"] = request.scene
|
||||
|
||||
if interrupt_ctx and interrupt_ctx.consumed:
|
||||
base_context["interrupted_content"] = interrupt_ctx.interrupted_content
|
||||
base_context["interrupted_segment_ids"] = interrupt_ctx.interrupted_segment_ids
|
||||
|
|
@ -813,8 +1102,50 @@ async def _execute_agent_mode(
|
|||
f"{len(interrupt_ctx.interrupted_content or '')} chars"
|
||||
)
|
||||
|
||||
# [AC-MRS-SLOT-META-03] 初始化槽位状态聚合器
|
||||
slot_state = None
|
||||
memory_context = ""
|
||||
memory_missing_slots: list[str] = []
|
||||
scene_slot_context = None
|
||||
|
||||
flow_status = None
|
||||
is_flow_runtime = False
|
||||
if session:
|
||||
from app.services.flow.engine import FlowEngine
|
||||
flow_engine = FlowEngine(session)
|
||||
flow_status = await flow_engine.get_flow_status(tenant_id, request.session_id)
|
||||
is_flow_runtime = bool(flow_status and flow_status.get("status") == "active")
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-02] 运行轨道判定: session={request.session_id}, "
|
||||
f"is_flow_runtime={is_flow_runtime}, flow_id={flow_status.get('flow_id') if flow_status else None}, "
|
||||
f"current_step={flow_status.get('current_step') if flow_status else None}"
|
||||
)
|
||||
|
||||
# [AC-SCENE-SLOT-02] 加载场景槽位包(仅流程态加载)
|
||||
if is_flow_runtime and session and request.scene:
|
||||
from app.services.mid.scene_slot_bundle_loader import SceneSlotBundleLoader
|
||||
|
||||
scene_loader = SceneSlotBundleLoader(session)
|
||||
scene_slot_context = await scene_loader.load_scene_context(
|
||||
tenant_id=tenant_id,
|
||||
scene_key=request.scene,
|
||||
)
|
||||
|
||||
if scene_slot_context:
|
||||
logger.info(
|
||||
f"[AC-SCENE-SLOT-02] 场景槽位包加载成功: "
|
||||
f"scene={request.scene}, required={len(scene_slot_context.required_slots)}, "
|
||||
f"optional={len(scene_slot_context.optional_slots)}, "
|
||||
f"threshold={scene_slot_context.completion_threshold}"
|
||||
)
|
||||
base_context["scene_slot_context"] = {
|
||||
"scene_key": scene_slot_context.scene_key,
|
||||
"scene_name": scene_slot_context.scene_name,
|
||||
"required_slots": scene_slot_context.get_required_slot_keys(),
|
||||
"optional_slots": scene_slot_context.get_optional_slot_keys(),
|
||||
}
|
||||
|
||||
if session and request.user_id:
|
||||
memory_recall_tool = MemoryRecallTool(
|
||||
session=session,
|
||||
|
|
@ -855,11 +1186,111 @@ async def _execute_agent_mode(
|
|||
f"[AC-IDMP-13] Memory recall fallback: reason={memory_result.fallback_reason_code}"
|
||||
)
|
||||
|
||||
# [AC-MRS-SLOT-META-03] 仅流程态执行槽位聚合/提取
|
||||
if is_flow_runtime:
|
||||
slot_aggregator = SlotStateAggregator(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
session_id=request.session_id,
|
||||
)
|
||||
slot_state = await slot_aggregator.aggregate(
|
||||
memory_slots=memory_result.slots,
|
||||
current_input_slots=None, # 可从 request 中解析
|
||||
context=base_context,
|
||||
scene_slot_context=scene_slot_context, # [AC-SCENE-SLOT-02] 传入场景槽位上下文
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-03] 流程态槽位聚合完成: "
|
||||
f"filled={len(slot_state.filled_slots)}, "
|
||||
f"missing={len(slot_state.missing_required_slots)}, "
|
||||
f"mappings={slot_state.slot_to_field_map}"
|
||||
)
|
||||
|
||||
# [AC-MRS-SLOT-EXTRACT-01] 自动提取槽位
|
||||
if slot_state and slot_state.missing_required_slots:
|
||||
from app.services.mid.slot_extraction_integration import SlotExtractionIntegration
|
||||
|
||||
extraction_integration = SlotExtractionIntegration(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
session_id=request.session_id,
|
||||
)
|
||||
|
||||
extraction_result = await extraction_integration.extract_and_fill(
|
||||
user_input=request.user_message,
|
||||
slot_state=slot_state,
|
||||
)
|
||||
|
||||
if extraction_result.extracted_slots:
|
||||
slot_state = await slot_aggregator.aggregate(
|
||||
memory_slots=memory_result.slots,
|
||||
current_input_slots=extraction_result.extracted_slots,
|
||||
context=base_context,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-EXTRACT-01] 流程态自动提槽完成: "
|
||||
f"extracted={list(extraction_result.extracted_slots.keys())}, "
|
||||
f"time_ms={extraction_result.total_execution_time_ms:.2f}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-03] 当前为通用问答轨道,跳过槽位聚合与缺槽追问: "
|
||||
f"session={request.session_id}"
|
||||
)
|
||||
|
||||
kb_hits = []
|
||||
kb_success = False
|
||||
kb_fallback_reason = None
|
||||
kb_applied_filter = {}
|
||||
kb_missing_slots = []
|
||||
kb_dynamic_result = None
|
||||
step_kb_binding_trace: dict[str, Any] | None = None
|
||||
|
||||
# [Step-KB-Binding] 获取当前流程步骤的 KB 配置(仅流程态)
|
||||
step_kb_config = None
|
||||
if is_flow_runtime and session and flow_status:
|
||||
current_step_no = flow_status.get("current_step")
|
||||
flow_id = flow_status.get("flow_id")
|
||||
|
||||
if flow_id and current_step_no:
|
||||
from app.models.entities import ScriptFlow
|
||||
from sqlalchemy import select
|
||||
|
||||
stmt = select(ScriptFlow).where(ScriptFlow.id == flow_id)
|
||||
result = await session.execute(stmt)
|
||||
flow = result.scalar_one_or_none()
|
||||
|
||||
if flow and flow.steps:
|
||||
step_idx = current_step_no - 1
|
||||
if 0 <= step_idx < len(flow.steps):
|
||||
current_step = flow.steps[step_idx]
|
||||
|
||||
# 构建 StepKbConfig
|
||||
from app.services.mid.kb_search_dynamic_tool import StepKbConfig
|
||||
step_kb_config = StepKbConfig(
|
||||
allowed_kb_ids=current_step.get("allowed_kb_ids"),
|
||||
preferred_kb_ids=current_step.get("preferred_kb_ids"),
|
||||
kb_query_hint=current_step.get("kb_query_hint"),
|
||||
max_kb_calls=current_step.get("max_kb_calls_per_step", 1),
|
||||
step_id=f"{flow_id}_step_{current_step_no}",
|
||||
)
|
||||
|
||||
step_kb_binding_trace = {
|
||||
"flow_id": str(flow_id),
|
||||
"flow_name": flow_status.get("flow_name"),
|
||||
"current_step": current_step_no,
|
||||
"step_id": step_kb_config.step_id,
|
||||
"allowed_kb_ids": step_kb_config.allowed_kb_ids,
|
||||
"preferred_kb_ids": step_kb_config.preferred_kb_ids,
|
||||
}
|
||||
|
||||
logger.info(
|
||||
f"[Step-KB-Binding] 步骤知识库配置加载成功: "
|
||||
f"flow={flow_status.get('flow_name')}, step={current_step_no}, "
|
||||
f"allowed_kb_ids={step_kb_config.allowed_kb_ids}"
|
||||
)
|
||||
|
||||
if session and tool_registry:
|
||||
kb_tool = KbSearchDynamicTool(
|
||||
|
|
@ -868,17 +1299,21 @@ async def _execute_agent_mode(
|
|||
config=KbSearchDynamicConfig(
|
||||
enabled=True,
|
||||
top_k=5,
|
||||
timeout_ms=2000,
|
||||
timeout_ms=10000,
|
||||
min_score_threshold=0.5,
|
||||
),
|
||||
)
|
||||
|
||||
# [AC-MRS-SLOT-META-03] 传入 slot_state 进行 KB 检索
|
||||
# [Step-KB-Binding] 传入 step_kb_config 进行步骤级别的 KB 约束
|
||||
kb_dynamic_result = await kb_tool.execute(
|
||||
query=request.user_message,
|
||||
tenant_id=tenant_id,
|
||||
scene="open_consult",
|
||||
top_k=5,
|
||||
context=base_context,
|
||||
slot_state=slot_state,
|
||||
step_kb_config=step_kb_config,
|
||||
slot_policy="flow_strict" if is_flow_runtime else "agent_relaxed",
|
||||
)
|
||||
|
||||
kb_success = kb_dynamic_result.success
|
||||
|
|
@ -894,10 +1329,44 @@ async def _execute_agent_mode(
|
|||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-05] KB dynamic search: success={kb_success}, "
|
||||
f"[AC-MARH-05] KB动态检索完成: success={kb_success}, "
|
||||
f"hits={len(kb_hits)}, filter={kb_applied_filter}, "
|
||||
f"missing_slots={kb_missing_slots}"
|
||||
f"missing_slots={kb_missing_slots}, track={'flow' if is_flow_runtime else 'agent'}"
|
||||
)
|
||||
|
||||
# [AC-MRS-SLOT-META-03] 处理缺失必填槽位 -> 追问闭环(仅流程态)
|
||||
if is_flow_runtime and kb_fallback_reason == "MISSING_REQUIRED_SLOTS" and kb_missing_slots:
|
||||
ask_back_text = await _generate_ask_back_for_missing_slots(
|
||||
slot_state=slot_state,
|
||||
missing_slots=kb_missing_slots,
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
session_id=request.session_id,
|
||||
scene_slot_context=scene_slot_context, # [AC-SCENE-SLOT-02] 传入场景槽位上下文
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-03] 流程态缺槽追问文案已生成: "
|
||||
f"{ask_back_text[:50]}..."
|
||||
)
|
||||
|
||||
return DialogueResponse(
|
||||
segments=[Segment(text=ask_back_text, delay_after=0)],
|
||||
trace=TraceInfo(
|
||||
mode=ExecutionMode.AGENT,
|
||||
request_id=trace.request_id,
|
||||
generation_id=trace.generation_id,
|
||||
fallback_reason_code="missing_required_slots",
|
||||
kb_tool_called=True,
|
||||
kb_hit=False,
|
||||
# [AC-SCENE-SLOT-02] 场景槽位追踪
|
||||
scene=request.scene,
|
||||
scene_slot_context=base_context.get("scene_slot_context"),
|
||||
missing_slots=[s.get("slot_key") for s in kb_missing_slots] if kb_missing_slots else None,
|
||||
ask_back_triggered=True,
|
||||
slot_sources=slot_state.slot_sources if slot_state else None,
|
||||
),
|
||||
)
|
||||
else:
|
||||
kb_result = await default_kb_tool_runner.execute(
|
||||
tenant_id=tenant_id,
|
||||
|
|
@ -933,7 +1402,18 @@ async def _execute_agent_mode(
|
|||
timeout_governor=timeout_governor,
|
||||
llm_client=llm_client,
|
||||
tool_registry=tool_registry,
|
||||
template_service=PromptTemplateService,
|
||||
variable_resolver=VariableResolver(),
|
||||
tenant_id=tenant_id,
|
||||
user_id=request.user_id,
|
||||
session_id=request.session_id,
|
||||
scene=request.scene,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[DEBUG-AGENT] Calling orchestrator.execute with: "
|
||||
f"user_message={request.user_message[:100] if request.user_message else 'None'}, "
|
||||
f"context_keys={list(base_context.keys())}, llm_client={llm_client is not None}"
|
||||
)
|
||||
|
||||
final_answer, react_ctx, agent_trace = await orchestrator.execute(
|
||||
|
|
@ -941,6 +1421,12 @@ async def _execute_agent_mode(
|
|||
context=base_context,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[DEBUG-AGENT] orchestrator.execute completed: "
|
||||
f"final_answer={final_answer[:200] if final_answer else 'None'}, "
|
||||
f"iterations={react_ctx.iteration}, tool_calls={len(react_ctx.tool_calls) if react_ctx.tool_calls else 0}"
|
||||
)
|
||||
|
||||
runtime_observer.record_react(request_id, react_ctx.iteration, react_ctx.tool_calls)
|
||||
|
||||
trace_logger.update_trace(
|
||||
|
|
@ -964,10 +1450,114 @@ async def _execute_agent_mode(
|
|||
kb_tool_called=True,
|
||||
kb_hit=kb_success and len(kb_hits) > 0,
|
||||
fallback_reason_code=kb_fallback_reason,
|
||||
# [AC-SCENE-SLOT-02] 场景槽位追踪
|
||||
scene=request.scene,
|
||||
scene_slot_context=base_context.get("scene_slot_context"),
|
||||
missing_slots=[s.get("slot_key") for s in kb_missing_slots] if kb_missing_slots else None,
|
||||
ask_back_triggered=False,
|
||||
slot_sources=slot_state.slot_sources if slot_state else None,
|
||||
kb_filter_sources=kb_dynamic_result.filter_sources if kb_dynamic_result else None,
|
||||
# [Step-KB-Binding] 步骤知识库绑定追踪
|
||||
step_kb_binding=kb_dynamic_result.step_kb_binding if kb_dynamic_result else step_kb_binding_trace,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
async def _generate_ask_back_for_missing_slots(
    slot_state: Any,
    missing_slots: list[dict[str, str]],
    session: AsyncSession,
    tenant_id: str,
    session_id: str | None = None,
    max_ask_back_slots: int = 2,
    scene_slot_context: Any = None,  # [AC-SCENE-SLOT-02] scene slot context
) -> str:
    """
    [AC-MRS-SLOT-META-03, AC-MRS-SLOT-ASKBACK-01] Build the ask-back text for missing required slots.
    [AC-SCENE-SLOT-02] Honors the ask-back order carried by the scene slot context.

    Resolution order:
      1. No missing slots: a generic "please provide more information" text.
      2. ``scene_slot_context.ask_back_order == "parallel"``: merge the first
         ``max_ask_back_slots`` prompts locally.
      3. Exactly one missing slot: its own ``ask_back_prompt``, else a
         sentence built from its label.
      4. Several missing slots: delegate to BatchAskBackService; if it raises
         or produces nothing, merge the prompts locally.

    Args:
        slot_state: Not read by this function.
        missing_slots: Dicts read via ``ask_back_prompt``, ``label`` and
            ``slot_key``.
        session: Database session passed to BatchAskBackService.
        tenant_id: Tenant passed to BatchAskBackService.
        session_id: Session passed to BatchAskBackService ("" when None).
        max_ask_back_slots: Maximum number of slots asked about in one turn.
        scene_slot_context: Optional object whose ``ask_back_order`` attribute
            is read (default "priority").

    Returns:
        The text to send to the user. Never an empty string.
    """
    if not missing_slots:
        return "请提供更多信息以便我更好地帮助您。"

    # [AC-SCENE-SLOT-02] A scene configured for parallel ask-back skips the
    # batch service and merges the prompts directly.
    if scene_slot_context:
        ask_back_order = getattr(scene_slot_context, "ask_back_order", "priority")
        if ask_back_order == "parallel":
            return _merge_ask_back_prompts(missing_slots, max_ask_back_slots)

    if len(missing_slots) == 1:
        return _merge_ask_back_prompts(missing_slots, 1)

    try:
        from app.services.mid.batch_ask_back_service import (
            BatchAskBackConfig,
            BatchAskBackService,
        )

        config = BatchAskBackConfig(
            max_ask_back_slots_per_turn=max_ask_back_slots,
            prefer_required=True,
            merge_prompts=True,
        )
        ask_back_service = BatchAskBackService(
            session=session,
            tenant_id=tenant_id,
            session_id=session_id or "",
            config=config,
        )
        result = await ask_back_service.generate_batch_ask_back(
            missing_slots=missing_slots,
        )
        if result.has_ask_back():
            return result.get_prompt()
    except Exception as e:
        # Deliberate best effort: any failure falls through to the local merge.
        logger.warning(f"[AC-MRS-SLOT-ASKBACK-01] Batch ask-back failed, fallback to single: {e}")

    return _merge_ask_back_prompts(missing_slots, max_ask_back_slots)


def _merge_ask_back_prompts(
    missing_slots: list[dict[str, str]],
    max_ask_back_slots: int,
) -> str:
    """Merge the prompts of the first few missing slots into one sentence.

    A slot's own ``ask_back_prompt`` is used verbatim; otherwise a fragment is
    generated from its ``label`` (falling back to ``slot_key``). At least one
    slot is always asked about, so a non-positive ``max_ask_back_slots``
    cannot produce an empty prompt list.
    """
    selected = missing_slots[: max(1, max_ask_back_slots)]
    if not selected:
        return "请提供更多信息以便我更好地帮助您。"

    def _label_of(slot: dict[str, str]) -> str:
        return slot.get("label") or slot.get("slot_key") or "相关信息"

    if len(selected) == 1:
        only_slot = selected[0]
        custom_prompt = only_slot.get("ask_back_prompt")
        if custom_prompt:
            return custom_prompt
        # A lone generated fragment is wrapped in a full sentence.
        return f"为了更好地为您提供帮助,请告诉我您的{_label_of(only_slot)}。"

    prompts = [
        slot.get("ask_back_prompt") or f"您的{_label_of(slot)}"
        for slot in selected
    ]
    if len(prompts) == 2:
        return f"为了更好地为您服务,请告诉我{prompts[0]}和{prompts[1]}。"

    all_but_last = "、".join(prompts[:-1])
    return f"为了更好地为您服务,请告诉我{all_but_last},以及{prompts[-1]}。"
|
||||
|
||||
|
||||
async def _execute_micro_flow_mode(
|
||||
tenant_id: str,
|
||||
request: DialogueRequest,
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from typing import Annotated
|
|||
|
||||
from fastapi import APIRouter, Depends, Path
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
|
|
@ -26,6 +27,82 @@ router = APIRouter(prefix="/mid", tags=["Mid Platform Sessions"])
|
|||
_session_modes: dict[str, SessionMode] = {}
|
||||
|
||||
|
||||
class CancelFlowResponse(BaseModel):
    """Response body returned by the cancel-flow endpoint."""

    # Whether the cancel request was handled without an error.
    success: bool
    # Human-readable outcome of the request.
    message: str
    # Session the request was issued for.
    session_id: str
|
||||
|
||||
|
||||
@router.post(
    "/sessions/{sessionId}/cancel-flow",
    operation_id="cancelActiveFlow",
    summary="Cancel active flow",
    description="""
    Cancel the active flow for a session.

    Use this when you encounter "Session already has an active flow" error.
    """,
    responses={
        200: {"description": "Flow cancelled successfully", "model": CancelFlowResponse},
        404: {"description": "No active flow found"},
    },
)
async def cancel_active_flow(
    sessionId: Annotated[str, Path(description="Session ID")],
    session: Annotated[AsyncSession, Depends(get_session)],
) -> CancelFlowResponse:
    """
    Cancel the active flow for a session.

    This endpoint allows you to cancel any active flow instance
    so that a new flow can be started.

    Args:
        sessionId: Session whose active flow should be cancelled.
        session: Database session injected by FastAPI.

    Returns:
        CancelFlowResponse. ``success`` is True both when a flow was cancelled
        and when no active flow existed; it is False only when cancelling
        raised.

    Raises:
        MissingTenantIdException: When no tenant id is available.

    NOTE(review): the route declares a 404 response, but this handler never
    returns one; a missing flow is reported with ``success=True``.
    """
    tenant_id = get_tenant_id()

    if not tenant_id:
        from app.core.exceptions import MissingTenantIdException

        raise MissingTenantIdException()

    logger.info(
        f"[Cancel Flow] Cancelling active flow: tenant={tenant_id}, session={sessionId}"
    )

    try:
        from app.services.flow.engine import FlowEngine

        flow_engine = FlowEngine(session)
        cancelled = await flow_engine.cancel_flow(
            tenant_id=tenant_id,
            session_id=sessionId,
            reason="User requested cancellation via API",
        )
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception(f"[Cancel Flow] Failed to cancel flow: {e}")
        return CancelFlowResponse(
            success=False,
            message=f"Failed to cancel flow: {str(e)}",
            session_id=sessionId,
        )

    if cancelled:
        logger.info(f"[Cancel Flow] Flow cancelled: session={sessionId}")
        return CancelFlowResponse(
            success=True,
            message="Active flow cancelled successfully",
            session_id=sessionId,
        )

    logger.info(f"[Cancel Flow] No active flow found: session={sessionId}")
    return CancelFlowResponse(
        success=True,
        message="No active flow found for this session",
        session_id=sessionId,
    )
|
||||
|
||||
|
||||
@router.post(
|
||||
"/sessions/{sessionId}/mode",
|
||||
operation_id="switchSessionMode",
|
||||
|
|
|
|||
|
|
@ -73,6 +73,7 @@ class DialogueRequest(BaseModel):
|
|||
history: list[HistoryMessage] = Field(default_factory=list, description="已送达历史")
|
||||
interrupted_segments: list[InterruptedSegment] | None = Field(default=None, description="打断的分段")
|
||||
feature_flags: FeatureFlags | None = Field(default=None, description="特性开关")
|
||||
scene: str | None = Field(default=None, description="场景标识,用于KB过滤,如 'open_consult', 'after_sale'")
|
||||
|
||||
|
||||
class Segment(BaseModel):
|
||||
|
|
@ -127,6 +128,7 @@ class TraceInfo(BaseModel):
|
|||
)
|
||||
tools_used: list[str] | None = Field(default=None, description="使用的工具列表")
|
||||
tool_calls: list[ToolCallTraceModel] | None = Field(default=None, description="工具调用追踪")
|
||||
step_kb_binding: dict[str, Any] | None = Field(default=None, description="[Step-KB-Binding] 步骤知识库绑定信息")
|
||||
|
||||
|
||||
class DialogueResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -86,6 +86,7 @@ class DialogueRequest(BaseModel):
|
|||
humanize_config: HumanizeConfigRequest | None = Field(
|
||||
default=None, description="Humanize config for segment delay"
|
||||
)
|
||||
scene: str | None = Field(default=None, description="Scene identifier for KB filtering, e.g., 'open_consult', 'after_sale'")
|
||||
|
||||
|
||||
class Segment(BaseModel):
|
||||
|
|
@ -122,6 +123,8 @@ class ToolCallTrace(BaseModel):
|
|||
error_code: str | None = Field(default=None, description="Error code if failed")
|
||||
args_digest: str | None = Field(default=None, description="Arguments digest for logging")
|
||||
result_digest: str | None = Field(default=None, description="Result digest for logging")
|
||||
arguments: dict[str, Any] | None = Field(default=None, description="Full tool call arguments")
|
||||
result: Any = Field(default=None, description="Full tool call result")
|
||||
|
||||
|
||||
class SegmentStats(BaseModel):
|
||||
|
|
@ -133,7 +136,7 @@ class SegmentStats(BaseModel):
|
|||
|
||||
class TraceInfo(BaseModel):
|
||||
"""[AC-MARH-02, AC-MARH-03, AC-MARH-05, AC-MARH-06, AC-MARH-07, AC-MARH-11,
|
||||
AC-MARH-12, AC-MARH-18, AC-MARH-19, AC-MARH-20] Trace info for observability."""
|
||||
AC-MARH-12, AC-MARH-18, AC-MARH-19, AC-MARH-20, AC-SCENE-SLOT-02] Trace info for observability."""
|
||||
mode: ExecutionMode = Field(..., description="Execution mode")
|
||||
intent: str | None = Field(default=None, description="Matched intent")
|
||||
request_id: str | None = Field(
|
||||
|
|
@ -156,6 +159,17 @@ class TraceInfo(BaseModel):
|
|||
high_risk_policy_set: list[HighRiskScenario] | None = Field(default=None, description="Active high-risk policy set")
|
||||
tools_used: list[str] | None = Field(default=None, description="Tools used in this request")
|
||||
tool_calls: list[ToolCallTrace] | None = Field(default=None, description="Tool call traces")
|
||||
duration_ms: int = Field(default=0, ge=0, description="Execution duration in milliseconds")
|
||||
created_at: str | None = Field(default=None, description="Creation timestamp")
|
||||
# [AC-SCENE-SLOT-02] 场景槽位追踪字段
|
||||
scene: str | None = Field(default=None, description="当前场景标识")
|
||||
scene_slot_context: dict[str, Any] | None = Field(default=None, description="场景槽位上下文信息")
|
||||
missing_slots: list[str] | None = Field(default=None, description="缺失的必填槽位列表")
|
||||
ask_back_triggered: bool | None = Field(default=False, description="是否触发了追问")
|
||||
slot_sources: dict[str, str] | None = Field(default=None, description="槽位值来源映射")
|
||||
kb_filter_sources: dict[str, str] | None = Field(default=None, description="KB 过滤条件来源映射")
|
||||
# [Step-KB-Binding] 步骤知识库绑定追踪
|
||||
step_kb_binding: dict[str, Any] | None = Field(default=None, description="步骤知识库绑定信息,包含 step_id, allowed_kb_ids, used_kb_ids 等")
|
||||
|
||||
|
||||
class DialogueResponse(BaseModel):
|
||||
|
|
|
|||
|
|
@ -17,6 +17,18 @@ from .memory_adapter import MemoryAdapter, UserMemory
|
|||
from .default_kb_tool_runner import DefaultKbToolRunner, KbToolResult, KbToolConfig, get_default_kb_tool_runner
|
||||
from .segment_humanizer import SegmentHumanizer, HumanizeConfig, LengthBucket, get_segment_humanizer
|
||||
from .runtime_observer import RuntimeObserver, RuntimeContext, get_runtime_observer
|
||||
from .slot_validation_service import (
|
||||
SlotValidationService,
|
||||
ValidationResult,
|
||||
SlotValidationError,
|
||||
BatchValidationResult,
|
||||
SlotValidationErrorCode,
|
||||
)
|
||||
from .slot_manager import (
|
||||
SlotManager,
|
||||
SlotWriteResult,
|
||||
create_slot_manager,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"PolicyRouter",
|
||||
|
|
@ -54,4 +66,12 @@ __all__ = [
|
|||
"RuntimeObserver",
|
||||
"RuntimeContext",
|
||||
"get_runtime_observer",
|
||||
"SlotValidationService",
|
||||
"ValidationResult",
|
||||
"SlotValidationError",
|
||||
"BatchValidationResult",
|
||||
"SlotValidationErrorCode",
|
||||
"SlotManager",
|
||||
"SlotWriteResult",
|
||||
"create_slot_manager",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -17,15 +17,20 @@ import logging
|
|||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.mid.schemas import ToolCallStatus, ToolCallTrace, ToolType
|
||||
from app.services.mid.metadata_filter_builder import (
|
||||
FilterBuildResult,
|
||||
FilterFieldInfo,
|
||||
MetadataFilterBuilder,
|
||||
)
|
||||
from app.services.mid.slot_state_aggregator import (
|
||||
SlotState,
|
||||
SlotStateAggregator,
|
||||
)
|
||||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
@ -37,6 +42,9 @@ DEFAULT_TOP_K = 5
|
|||
DEFAULT_TIMEOUT_MS = 2000
|
||||
KB_SEARCH_DYNAMIC_TOOL_NAME = "kb_search_dynamic"
|
||||
|
||||
_TOOL_SCHEMA_CACHE: dict[str, tuple[float, dict[str, Any]]] = {}
|
||||
_TOOL_SCHEMA_CACHE_TTL_SECONDS = 300 # 5 minutes
|
||||
|
||||
|
||||
@dataclass
|
||||
class KbSearchDynamicResult:
|
||||
|
|
@ -46,9 +54,21 @@ class KbSearchDynamicResult:
|
|||
applied_filter: dict[str, Any] = field(default_factory=dict)
|
||||
missing_required_slots: list[dict[str, str]] = field(default_factory=list)
|
||||
filter_debug: dict[str, Any] = field(default_factory=dict)
|
||||
filter_sources: dict[str, str] = field(default_factory=dict) # [AC-SCENE-SLOT-02] 过滤条件来源
|
||||
fallback_reason_code: str | None = None
|
||||
duration_ms: int = 0
|
||||
tool_trace: ToolCallTrace | None = None
|
||||
step_kb_binding: dict[str, Any] | None = None # [Step-KB-Binding] 步骤知识库绑定信息
|
||||
|
||||
|
||||
@dataclass
|
||||
class StepKbConfig:
|
||||
"""[Step-KB-Binding] 步骤级别的知识库配置。"""
|
||||
allowed_kb_ids: list[str] | None = None
|
||||
preferred_kb_ids: list[str] | None = None
|
||||
kb_query_hint: str | None = None
|
||||
max_kb_calls: int = 1
|
||||
step_id: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -86,12 +106,14 @@ class KbSearchDynamicTool:
|
|||
session: AsyncSession,
|
||||
timeout_governor: TimeoutGovernor | None = None,
|
||||
config: KbSearchDynamicConfig | None = None,
|
||||
slot_state_aggregator: SlotStateAggregator | None = None,
|
||||
):
|
||||
self._session = session
|
||||
self._timeout_governor = timeout_governor or TimeoutGovernor()
|
||||
self._config = config or KbSearchDynamicConfig()
|
||||
self._vector_retriever = None
|
||||
self._filter_builder: MetadataFilterBuilder | None = None
|
||||
self._slot_state_aggregator = slot_state_aggregator
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
|
|
@ -140,6 +162,128 @@ class KbSearchDynamicTool:
|
|||
},
|
||||
}
|
||||
|
||||
async def get_dynamic_tool_schema(self, tenant_id: str) -> dict[str, Any]:
    """Return the tool schema for a tenant, including its metadata filter fields.

    The schema is served from the module-level ``_TOOL_SCHEMA_CACHE`` while the
    cached entry is younger than ``_TOOL_SCHEMA_CACHE_TTL_SECONDS``. Otherwise
    the tenant's filterable fields are loaded through ``MetadataFilterBuilder``
    and exposed as properties of a nested ``context`` object.

    If loading the fields fails, the base schema (``query``, ``top_k`` and an
    untyped ``context``) is returned but is NOT cached, so a transient failure
    does not hide the tenant's filter fields for a whole TTL window.

    Args:
        tenant_id: Tenant identifier.

    Returns:
        The tool schema dict with ``name``, ``description`` and ``parameters``.
        A cache hit returns the cached dict object itself, so callers should
        treat the result as read-only.
    """
    current_time = time.time()

    cache_key = f"tool_schema:{tenant_id}"
    cached_entry = _TOOL_SCHEMA_CACHE.get(cache_key)
    if cached_entry is not None:
        cached_time, cached_schema = cached_entry
        if current_time - cached_time < _TOOL_SCHEMA_CACHE_TTL_SECONDS:
            logger.debug(f"[AC-MARH-05] Tool schema cache hit for tenant={tenant_id}")
            return cached_schema

    logger.info(f"[AC-MARH-05] Building dynamic tool schema for tenant={tenant_id}")

    base_properties: dict[str, Any] = {
        "query": {
            "type": "string",
            "description": "检索查询文本",
        },
        "top_k": {
            "type": "integer",
            "description": "返回结果数量",
            "default": DEFAULT_TOP_K,
        },
    }

    required_fields = ["query"]
    context_properties: dict[str, Any] = {}
    fields_loaded = False

    try:
        if self._filter_builder is None:
            self._filter_builder = MetadataFilterBuilder(self._session)

        filterable_fields = await self._filter_builder._get_filterable_fields(tenant_id)

        # Collect into locals first so that a failure half-way through leaves
        # the schema as the plain base schema, matching the warning below.
        loaded_properties: dict[str, Any] = {}
        loaded_required: list[str] = []
        for field_info in filterable_fields:
            loaded_properties[field_info.field_key] = self._build_field_schema(field_info)
            if field_info.required:
                loaded_required.append(field_info.field_key)

        context_properties = loaded_properties
        # NOTE(review): required filter keys are listed in the top-level
        # ``required`` array although the properties live under ``context``;
        # kept as-is for compatibility - confirm how consumers interpret it.
        required_fields.extend(loaded_required)
        fields_loaded = True

        logger.info(
            f"[AC-MARH-05] Dynamic schema built: tenant={tenant_id}, "
            f"context_fields={len(context_properties)}, required={required_fields}"
        )
    except Exception as e:
        logger.warning(f"[AC-MARH-05] Failed to get filterable fields: {e}, using base schema")

    if context_properties:
        base_properties["context"] = {
            "type": "object",
            "description": "过滤条件,根据用户意图选择合适的字段传递",
            "properties": context_properties,
        }
    else:
        base_properties["context"] = {
            "type": "object",
            "description": "过滤条件(当前租户未配置元数据字段)",
        }

    schema = {
        "name": self.name,
        "description": self.description,
        "parameters": {
            "type": "object",
            "properties": base_properties,
            "required": required_fields,
        },
    }

    # Only cache a schema built from successfully loaded field definitions; a
    # fallback schema must be rebuilt on the next call.
    if fields_loaded:
        _TOOL_SCHEMA_CACHE[cache_key] = (current_time, schema)

    return schema
|
||||
|
||||
def _build_field_schema(self, field_info: "FilterFieldInfo") -> dict[str, Any]:
    """Build the JSON Schema fragment describing one metadata filter field.

    The description starts from the field label (falling back to the field
    key), is suffixed when the field is required, and lists the allowed
    values for enum-like fields.

    Args:
        field_info: Field definition; ``field_key``, ``label``, ``field_type``,
            ``required``, ``options`` and ``default_value`` are read.

    Returns:
        A dict with ``description`` and ``type``, plus ``enum`` for enum-like
        fields that have options and ``default`` when a default value is set.
    """
    schema: dict[str, Any] = {
        "description": field_info.label or field_info.field_key,
    }

    if field_info.required:
        schema["description"] += "(必填)"

    field_type = field_info.field_type.lower() if field_info.field_type else "string"

    if field_type in ("enum", "select", "array_enum", "multi_select"):
        schema["type"] = "string"
        if field_info.options:
            # Copy so that later edits to the schema cannot alter the field
            # definition's own option list.
            options = list(field_info.options)
            schema["enum"] = options
            # str.join() rejects non-string items, so stringify each option;
            # numeric options would otherwise raise TypeError here.
            rendered_options = ", ".join(str(option) for option in options)
            schema["description"] += f",可选值:{rendered_options}"
    elif field_type in ("number", "integer", "float"):
        schema["type"] = "number"
    elif field_type == "boolean":
        schema["type"] = "boolean"
    else:
        schema["type"] = "string"

    if field_info.default_value is not None:
        schema["default"] = field_info.default_value

    return schema
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
query: str,
|
||||
|
|
@ -147,16 +291,24 @@ class KbSearchDynamicTool:
|
|||
scene: str = "open_consult",
|
||||
top_k: int | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
slot_state: SlotState | None = None,
|
||||
step_kb_config: StepKbConfig | None = None,
|
||||
slot_policy: Literal["flow_strict", "agent_relaxed"] = "flow_strict",
|
||||
) -> KbSearchDynamicResult:
|
||||
"""
|
||||
[AC-MARH-05] 执行 KB 动态检索。
|
||||
[AC-MRS-SLOT-META-02] 支持槽位状态聚合和过滤构建优先级
|
||||
[Step-KB-Binding] 支持步骤级别的知识库约束
|
||||
|
||||
Args:
|
||||
query: 检索查询
|
||||
tenant_id: 租户 ID
|
||||
scene: 场景标识
|
||||
scene: 场景标识(默认值,会被 context 中的 scene 覆盖)
|
||||
top_k: 返回数量
|
||||
context: 上下文(包含动态过滤值)
|
||||
context: 上下文(包含动态过滤值,包括 scene)
|
||||
slot_state: 预聚合的槽位状态(可选,优先使用)
|
||||
step_kb_config: 步骤级别的知识库配置(可选)
|
||||
slot_policy: 槽位策略(flow_strict=流程严格模式,agent_relaxed=通用问答宽松模式)
|
||||
|
||||
Returns:
|
||||
KbSearchDynamicResult 包含检索结果和追踪信息
|
||||
|
|
@ -171,82 +323,150 @@ class KbSearchDynamicTool:
|
|||
start_time = time.time()
|
||||
top_k = top_k or self._config.top_k
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-05] Starting KB dynamic search: tenant={tenant_id}, "
|
||||
f"query={query[:50]}..., scene={scene}, top_k={top_k}"
|
||||
)
|
||||
effective_context = dict(context) if context else {}
|
||||
effective_scene = effective_context.get("scene", scene)
|
||||
|
||||
filter_result: FilterBuildResult | None = None
|
||||
|
||||
try:
|
||||
if self._filter_builder is None:
|
||||
self._filter_builder = MetadataFilterBuilder(self._session)
|
||||
|
||||
filter_result = await self._filter_builder.build_filter(
|
||||
tenant_id=tenant_id,
|
||||
context=context,
|
||||
# [Step-KB-Binding] 记录步骤知识库约束
|
||||
step_kb_binding_info: dict[str, Any] = {}
|
||||
if step_kb_config:
|
||||
step_kb_binding_info = {
|
||||
"step_id": step_kb_config.step_id,
|
||||
"allowed_kb_ids": step_kb_config.allowed_kb_ids,
|
||||
"preferred_kb_ids": step_kb_config.preferred_kb_ids,
|
||||
"kb_query_hint": step_kb_config.kb_query_hint,
|
||||
"max_kb_calls": step_kb_config.max_kb_calls,
|
||||
}
|
||||
logger.info(
|
||||
f"[Step-KB-Binding] Step KB config applied: "
|
||||
f"step_id={step_kb_config.step_id}, "
|
||||
f"allowed_kb_ids={step_kb_config.allowed_kb_ids}, "
|
||||
f"preferred_kb_ids={step_kb_config.preferred_kb_ids}"
|
||||
)
|
||||
|
||||
if filter_result.missing_required_slots:
|
||||
logger.warning(
|
||||
f"[AC-MARH-05] Missing required slots: "
|
||||
f"{filter_result.missing_required_slots}"
|
||||
)
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.info(
|
||||
f"[AC-MARH-05] 开始执行KB动态检索: tenant={tenant_id}, "
|
||||
f"query={query[:50]}..., scene={effective_scene}, top_k={top_k}, "
|
||||
f"slot_policy={slot_policy}, context_keys={list(effective_context.keys())}"
|
||||
)
|
||||
|
||||
tool_trace = ToolCallTrace(
|
||||
tool_name=self.name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="MISSING_REQUIRED_SLOTS",
|
||||
args_digest=f"query={query[:50]}, scene={scene}",
|
||||
result_digest=f"missing={len(filter_result.missing_required_slots)}",
|
||||
# [AC-MRS-SLOT-META-02] 如果提供了 slot_state,优先使用
|
||||
if slot_state is not None:
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-02] Using provided slot_state: "
|
||||
f"filled={len(slot_state.filled_slots)}, "
|
||||
f"missing={len(slot_state.missing_required_slots)}"
|
||||
)
|
||||
|
||||
# 检查是否有缺失的必填槽位(仅在流程严格模式下阻断)
|
||||
if slot_state.missing_required_slots:
|
||||
if slot_policy == "flow_strict":
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-03] 流程严格模式命中缺失必填槽位,触发追问: "
|
||||
f"tenant={tenant_id}, missing={len(slot_state.missing_required_slots)}"
|
||||
)
|
||||
|
||||
tool_trace = ToolCallTrace(
|
||||
tool_name=self.name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="MISSING_REQUIRED_SLOTS",
|
||||
args_digest=f"query={query[:50]}, scene={effective_scene}",
|
||||
result_digest=f"missing={len(slot_state.missing_required_slots)}",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
result={"missing_required_slots": slot_state.missing_required_slots},
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=False,
|
||||
applied_filter={},
|
||||
missing_required_slots=slot_state.missing_required_slots,
|
||||
filter_debug={
|
||||
"source": "slot_state",
|
||||
"filled_slots": slot_state.filled_slots,
|
||||
"slot_to_field_map": slot_state.slot_to_field_map,
|
||||
},
|
||||
fallback_reason_code="MISSING_REQUIRED_SLOTS",
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-03] 通用问答宽松模式检测到缺失槽位但不阻断检索: "
|
||||
f"tenant={tenant_id}, missing={len(slot_state.missing_required_slots)}"
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=False,
|
||||
applied_filter=filter_result.applied_filter,
|
||||
missing_required_slots=filter_result.missing_required_slots,
|
||||
filter_debug=filter_result.debug_info,
|
||||
fallback_reason_code="MISSING_REQUIRED_SLOTS",
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
)
|
||||
# 使用 slot_state 构建 filter
|
||||
metadata_filter, filter_sources = await self._build_filter_from_slot_state(
|
||||
tenant_id=tenant_id,
|
||||
slot_state=slot_state,
|
||||
context=effective_context,
|
||||
scene_slot_context=effective_context.get("scene_slot_context"), # [AC-SCENE-SLOT-02]
|
||||
)
|
||||
else:
|
||||
# 原有逻辑:构建元数据 filter
|
||||
# 如果 context 简单(只有键值对),直接构造 filter,跳过 MetadataFilterBuilder
|
||||
metadata_filter = await self._build_filter_legacy(
|
||||
tenant_id=tenant_id,
|
||||
context=effective_context,
|
||||
query=query,
|
||||
effective_scene=effective_scene,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
metadata_filter = filter_result.applied_filter if filter_result.success else None
|
||||
if isinstance(metadata_filter, KbSearchDynamicResult):
|
||||
# 有错误,直接返回
|
||||
return metadata_filter
|
||||
|
||||
try:
|
||||
hits = await self._retrieve_with_timeout(
|
||||
tenant_id=tenant_id,
|
||||
query=query,
|
||||
metadata_filter=metadata_filter,
|
||||
top_k=top_k,
|
||||
step_kb_config=step_kb_config,
|
||||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
kb_hit = len(hits) > 0
|
||||
|
||||
# [Step-KB-Binding] 记录命中的知识库
|
||||
hit_kb_ids = list(set(hit.get("kb_id") for hit in hits if hit.get("kb_id")))
|
||||
if step_kb_binding_info:
|
||||
step_kb_binding_info["used_kb_ids"] = hit_kb_ids
|
||||
step_kb_binding_info["kb_hit"] = kb_hit
|
||||
|
||||
tool_trace = ToolCallTrace(
|
||||
tool_name=self.name,
|
||||
tool_type=ToolType.INTERNAL,
|
||||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.OK,
|
||||
args_digest=f"query={query[:50]}, scene={scene}",
|
||||
args_digest=f"query={query[:50]}, scene={effective_scene}",
|
||||
result_digest=f"hits={len(hits)}",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
result={"hits_count": len(hits), "kb_hit": kb_hit},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-05] KB dynamic search completed: tenant={tenant_id}, "
|
||||
f"hits={len(hits)}, duration_ms={duration_ms}, kb_hit={kb_hit}"
|
||||
f"hits={len(hits)}, duration_ms={duration_ms}, kb_hit={kb_hit}, "
|
||||
f"hit_kb_ids={hit_kb_ids}"
|
||||
)
|
||||
|
||||
# 确定 filter 来源用于调试
|
||||
filter_source = "slot_state" if slot_state is not None else "builder"
|
||||
kb_filter_sources = filter_sources if slot_state is not None else {}
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=True,
|
||||
hits=hits,
|
||||
applied_filter=filter_result.applied_filter if filter_result else {},
|
||||
filter_debug=filter_result.debug_info if filter_result else {},
|
||||
applied_filter=metadata_filter or {},
|
||||
filter_debug={"source": filter_source, "filter_sources": kb_filter_sources},
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
step_kb_binding=step_kb_binding_info if step_kb_binding_info else None,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
|
|
@ -262,16 +482,18 @@ class KbSearchDynamicTool:
|
|||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.TIMEOUT,
|
||||
error_code="KB_TIMEOUT",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=False,
|
||||
applied_filter=filter_result.applied_filter if filter_result else {},
|
||||
missing_required_slots=filter_result.missing_required_slots if filter_result else [],
|
||||
filter_debug=filter_result.debug_info if filter_result else {},
|
||||
applied_filter=metadata_filter or {},
|
||||
missing_required_slots=[],
|
||||
filter_debug={"error": "timeout"},
|
||||
fallback_reason_code="KB_TIMEOUT",
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
step_kb_binding=step_kb_binding_info if step_kb_binding_info else None,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -287,31 +509,224 @@ class KbSearchDynamicTool:
|
|||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="KB_ERROR",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
success=False,
|
||||
applied_filter=filter_result.applied_filter if filter_result else {},
|
||||
missing_required_slots=filter_result.missing_required_slots if filter_result else [],
|
||||
applied_filter=metadata_filter or {},
|
||||
missing_required_slots=[],
|
||||
filter_debug={"error": str(e)},
|
||||
fallback_reason_code="KB_ERROR",
|
||||
duration_ms=duration_ms,
|
||||
tool_trace=tool_trace,
|
||||
)
|
||||
|
||||
async def _build_filter_legacy(
    self,
    tenant_id: str,
    context: dict[str, Any],
    query: str,
    effective_scene: str,
    start_time: float,
) -> dict[str, Any] | KbSearchDynamicResult:
    """[AC-MRS-SLOT-META-02] Build the metadata filter when no slot state is given.

    A context whose values are all plain scalars (str, int, float, bool) is
    used as the filter directly, without any lookup. Any other context is
    delegated to ``_build_filter_with_builder``.

    Args:
        tenant_id: Tenant identifier, forwarded to the builder path.
        context: Request context holding the candidate filter values.
        query: Search query, forwarded to the builder path.
        effective_scene: Resolved scene, forwarded to the builder path.
        start_time: Start timestamp of the search, forwarded to the builder path.

    Returns:
        The filter dict on success, or the ``KbSearchDynamicResult`` error
        produced by the builder path.
    """
    scalar_types = (str, int, float, bool)
    needs_builder = any(
        not isinstance(item, scalar_types) for item in context.values()
    )

    if needs_builder:
        # Non-scalar values present: let MetadataFilterBuilder handle them.
        built = await self._build_filter_with_builder(
            tenant_id=tenant_id,
            context=context,
            query=query,
            effective_scene=effective_scene,
            start_time=start_time,
        )
        if isinstance(built, KbSearchDynamicResult):
            # The builder reported an error result; hand it back unchanged.
            return built
        return built or {}

    # Scalar-only context: the values are used as the filter as they are.
    metadata_filter = context
    logger.info(
        f"[AC-MARH-05] Using simple context as filter directly: "
        f"{metadata_filter}"
    )
    return metadata_filter or {}
|
||||
|
||||
async def _build_filter_from_slot_state(
    self,
    tenant_id: str,
    slot_state: SlotState,
    context: dict[str, Any],
    scene_slot_context: Any = None,  # [AC-SCENE-SLOT-02] scene slot context
) -> tuple[dict[str, Any], dict[str, str]]:
    """[AC-MRS-SLOT-META-02] Build filter conditions from the aggregated slot state.

    For every filterable field of the tenant, the value is resolved in this
    order and the first non-None value wins:

    1. a filled slot (via ``slot_to_field_map``, then by identical key),
    2. an explicit value in the request ``context``,
    3. the default value of the field definition.

    [AC-SCENE-SLOT-02] When a scene slot context is supplied, its slot keys
    are only written to the debug log here.

    Args:
        tenant_id: Tenant identifier.
        slot_state: Aggregated slot state.
        context: Request context.
        scene_slot_context: Optional scene slot context.

    Returns:
        A tuple ``(applied_filter, sources)`` where ``sources`` maps each
        filtered field key to ``"slot"``, ``"context"`` or ``"default"``.
    """
    if self._filter_builder is None:
        self._filter_builder = MetadataFilterBuilder(self._session)
    builder = self._filter_builder

    # Filterable field definitions for this tenant.
    filterable_fields = await builder._get_filterable_fields(tenant_id)

    applied_filter: dict[str, Any] = {}
    filter_debug_sources: dict[str, str] = {}

    if scene_slot_context:
        scene_slot_keys = set(scene_slot_context.get_all_slot_keys())
        logger.debug(
            f"[AC-SCENE-SLOT-02] Processing scene slots: "
            f"scene={scene_slot_context.scene_key}, slots={scene_slot_keys}"
        )

    for field_info in filterable_fields:
        field_key = field_info.field_key
        value = None
        source = None

        # Priority 1a: the first filled slot mapped onto this field.
        slot_map = slot_state.slot_to_field_map
        if slot_map:
            mapped_values = (
                slot_state.filled_slots[slot_key]
                for slot_key, mapped_field_key in slot_map.items()
                if mapped_field_key == field_key
                and slot_key in slot_state.filled_slots
            )
            for value in mapped_values:
                source = "slot"
                break

        # Priority 1b: a filled slot whose key equals the field key.
        if value is None and field_key in slot_state.filled_slots:
            value, source = slot_state.filled_slots[field_key], "slot"

        # Priority 2: explicit value in the request context.
        if value is None and field_key in context:
            value, source = context[field_key], "context"

        # Priority 3: default from the field definition.
        if value is None and field_info.default_value is not None:
            value, source = field_info.default_value, "default"

        if value is None:
            continue

        filter_value = builder._build_field_filter(field_info, value)
        if filter_value is not None:
            applied_filter[field_key] = filter_value
            filter_debug_sources[field_key] = source

    logger.info(
        f"[AC-MRS-SLOT-META-02] Filter built from slot_state: "
        f"fields={len(applied_filter)}, sources={filter_debug_sources}"
    )

    return applied_filter, filter_debug_sources
|
||||
|
||||
async def _build_filter_with_builder(
    self,
    tenant_id: str,
    context: dict[str, Any],
    query: str,
    effective_scene: str,
    start_time: float,
) -> dict[str, Any] | KbSearchDynamicResult:
    """Build the metadata filter through ``MetadataFilterBuilder``.

    Args:
        tenant_id: Tenant identifier.
        context: Request context passed to the builder.
        query: Search query; used only for the trace of an error result.
        effective_scene: Resolved scene; used only for the trace of an error result.
        start_time: Start timestamp of the search, used to compute the
            duration reported in an error result.

    Returns:
        The applied filter dict when the build succeeded, an empty dict when
        the builder reported failure without missing slots, or a
        ``KbSearchDynamicResult`` with ``MISSING_REQUIRED_SLOTS`` when
        required slots are missing.
    """
    if self._filter_builder is None:
        self._filter_builder = MetadataFilterBuilder(self._session)

    filter_result = await self._filter_builder.build_filter(
        tenant_id=tenant_id,
        context=context,
    )

    if filter_result.missing_required_slots:
        logger.warning(
            f"[AC-MARH-05] Missing required slots: "
            f"{filter_result.missing_required_slots}"
        )
        duration_ms = int((time.time() - start_time) * 1000)

        tool_trace = ToolCallTrace(
            tool_name=self.name,
            tool_type=ToolType.INTERNAL,
            duration_ms=duration_ms,
            status=ToolCallStatus.ERROR,
            error_code="MISSING_REQUIRED_SLOTS",
            args_digest=f"query={query[:50]}, scene={effective_scene}",
            result_digest=f"missing={len(filter_result.missing_required_slots)}",
            arguments={"query": query, "scene": effective_scene, "context": context},
            result={"missing_required_slots": filter_result.missing_required_slots},
        )

        return KbSearchDynamicResult(
            success=False,
            applied_filter=filter_result.applied_filter,
            missing_required_slots=filter_result.missing_required_slots,
            filter_debug=filter_result.debug_info,
            fallback_reason_code="MISSING_REQUIRED_SLOTS",
            duration_ms=duration_ms,
            tool_trace=tool_trace,
        )

    if not filter_result.success:
        # The search continues with an empty filter in this case; log it so
        # the fall-back is visible instead of silent.
        logger.warning(
            f"[AC-MARH-05] Filter build reported failure, falling back to "
            f"empty filter: tenant={tenant_id}"
        )
        return {}

    return filter_result.applied_filter
|
||||
|
||||
async def _retrieve_with_timeout(
|
||||
self,
|
||||
tenant_id: str,
|
||||
query: str,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
step_kb_config: StepKbConfig | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""带超时控制的检索。"""
|
||||
timeout_seconds = self._config.timeout_ms / 1000.0
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self._do_retrieve(tenant_id, query, metadata_filter, top_k),
|
||||
self._do_retrieve(tenant_id, query, metadata_filter, top_k, step_kb_config),
|
||||
timeout=timeout_seconds,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
|
|
@ -323,18 +738,36 @@ class KbSearchDynamicTool:
|
|||
query: str,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
step_kb_config: StepKbConfig | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""执行实际检索。"""
|
||||
"""执行实际检索。[Step-KB-Binding] 支持步骤级别的知识库约束。"""
|
||||
if self._vector_retriever is None:
|
||||
from app.services.retrieval.vector_retriever import get_vector_retriever
|
||||
self._vector_retriever = await get_vector_retriever()
|
||||
|
||||
from app.services.retrieval.base import RetrievalContext
|
||||
|
||||
# [Step-KB-Binding] 确定要检索的知识库范围
|
||||
kb_ids = None
|
||||
if step_kb_config:
|
||||
# 如果配置了 allowed_kb_ids,则只检索这些知识库
|
||||
if step_kb_config.allowed_kb_ids:
|
||||
kb_ids = step_kb_config.allowed_kb_ids
|
||||
logger.info(
|
||||
f"[Step-KB-Binding] Restricting KB search to: {kb_ids}"
|
||||
)
|
||||
# 如果只配置了 preferred_kb_ids,优先检索这些知识库
|
||||
elif step_kb_config.preferred_kb_ids:
|
||||
kb_ids = step_kb_config.preferred_kb_ids
|
||||
logger.info(
|
||||
f"[Step-KB-Binding] Preferring KB search in: {kb_ids}"
|
||||
)
|
||||
|
||||
ctx = RetrievalContext(
|
||||
tenant_id=tenant_id,
|
||||
query=query,
|
||||
metadata=metadata_filter,
|
||||
metadata_filter=metadata_filter,
|
||||
kb_ids=kb_ids,
|
||||
)
|
||||
|
||||
result = await self._vector_retriever.retrieve(ctx)
|
||||
|
|
@ -347,6 +780,7 @@ class KbSearchDynamicTool:
|
|||
"content": hit.text,
|
||||
"score": hit.score,
|
||||
"metadata": hit.metadata,
|
||||
"kb_id": hit.metadata.get("kb_id"),
|
||||
})
|
||||
|
||||
return hits[:top_k]
|
||||
|
|
@ -384,6 +818,7 @@ async def create_kb_search_dynamic_handler(
|
|||
scene: str = "open_consult",
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
context: dict[str, Any] | None = None,
|
||||
**kwargs, # 接受系统注入的额外参数(user_id, session_id 等)
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
KB 动态检索 handler。
|
||||
|
|
@ -391,9 +826,9 @@ async def create_kb_search_dynamic_handler(
|
|||
Args:
|
||||
query: 检索查询
|
||||
tenant_id: 租户 ID
|
||||
scene: 场景标识
|
||||
scene: 场景标识(默认值,会被 context 中的 scene 覆盖)
|
||||
top_k: 返回数量
|
||||
context: 上下文
|
||||
context: 上下文(包含 scene 等过滤字段)
|
||||
|
||||
Returns:
|
||||
检索结果字典
|
||||
|
|
@ -442,6 +877,7 @@ def register_kb_search_dynamic_tool(
|
|||
scene: str = "open_consult",
|
||||
top_k: int = DEFAULT_TOP_K,
|
||||
context: dict[str, Any] | None = None,
|
||||
**kwargs, # 接受系统注入的额外参数(user_id, session_id 等)
|
||||
) -> dict[str, Any]:
|
||||
tool = KbSearchDynamicTool(
|
||||
session=session,
|
||||
|
|
|
|||
|
|
@ -542,9 +542,9 @@ def register_memory_recall_tool(
|
|||
cfg = config or MemoryRecallConfig()
|
||||
|
||||
async def memory_recall_handler(
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
tenant_id: str = "",
|
||||
user_id: str = "",
|
||||
session_id: str = "",
|
||||
recall_scope: list[str] | None = None,
|
||||
max_recent_messages: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
|
|
@ -554,6 +554,7 @@ def register_memory_recall_tool(
|
|||
timeout_governor=timeout_governor,
|
||||
config=cfg,
|
||||
)
|
||||
|
||||
result = await tool.execute(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
|
|
@ -587,12 +588,9 @@ def register_memory_recall_tool(
|
|||
"recall_scope": {"type": "array", "description": "召回范围,例如 profile/facts/preferences/summary/slots"},
|
||||
"max_recent_messages": {"type": "integer", "description": "历史回填窗口大小"}
|
||||
},
|
||||
"required": ["tenant_id", "user_id", "session_id"]
|
||||
"required": []
|
||||
},
|
||||
"example_action_input": {
|
||||
"tenant_id": "default",
|
||||
"user_id": "u_10086",
|
||||
"session_id": "s_abc_001",
|
||||
"recall_scope": ["profile", "facts", "preferences", "summary", "slots"],
|
||||
"max_recent_messages": 8
|
||||
},
|
||||
|
|
|
|||
|
|
@ -165,21 +165,56 @@ class MetadataFilterBuilder:
|
|||
) -> list[FilterFieldInfo]:
|
||||
"""
|
||||
[AC-MRS-11] 获取可过滤的字段定义。
|
||||
优先从 Redis 缓存获取,未缓存则从数据库查询并缓存。
|
||||
|
||||
条件:
|
||||
- 状态=生效 (active)
|
||||
- field_roles 包含 resource_filter
|
||||
"""
|
||||
import time
|
||||
start_time = time.time()
|
||||
|
||||
# 1. 尝试从缓存获取
|
||||
from app.services.metadata_cache_service import get_metadata_cache_service
|
||||
cache_service = await get_metadata_cache_service()
|
||||
cached_fields = await cache_service.get_fields(tenant_id)
|
||||
|
||||
if cached_fields is not None:
|
||||
# 缓存命中,直接返回
|
||||
logger.info(
|
||||
f"[AC-MRS-11] Cache hit: Retrieved {len(cached_fields)} fields "
|
||||
f"for tenant={tenant_id} in {(time.time() - start_time)*1000:.2f}ms"
|
||||
)
|
||||
return [
|
||||
FilterFieldInfo(
|
||||
field_key=f["field_key"],
|
||||
label=f["label"],
|
||||
field_type=f["field_type"],
|
||||
required=f["required"],
|
||||
options=f.get("options"),
|
||||
default_value=f.get("default_value"),
|
||||
is_filterable=f["is_filterable"],
|
||||
)
|
||||
for f in cached_fields
|
||||
]
|
||||
|
||||
# 2. 缓存未命中,从数据库查询
|
||||
logger.info(f"[AC-MRS-11] Cache miss: Querying database for tenant={tenant_id}")
|
||||
db_start = time.time()
|
||||
|
||||
fields = await self._role_provider.get_fields_by_role(
|
||||
tenant_id=tenant_id,
|
||||
role=FieldRole.RESOURCE_FILTER.value,
|
||||
)
|
||||
|
||||
db_time = (time.time() - db_start) * 1000
|
||||
logger.info(
|
||||
f"[AC-MRS-11] Retrieved {len(fields)} resource_filter fields for tenant={tenant_id}"
|
||||
f"[AC-MRS-11] Retrieved {len(fields)} resource_filter fields "
|
||||
f"for tenant={tenant_id} from DB in {db_time:.2f}ms"
|
||||
)
|
||||
|
||||
return [
|
||||
# 3. 转换为 FilterFieldInfo 列表
|
||||
filter_fields = [
|
||||
FilterFieldInfo(
|
||||
field_key=f.field_key,
|
||||
label=f.label,
|
||||
|
|
@ -192,6 +227,29 @@ class MetadataFilterBuilder:
|
|||
for f in fields
|
||||
]
|
||||
|
||||
# 4. 缓存到 Redis
|
||||
cache_data = [
|
||||
{
|
||||
"field_key": f.field_key,
|
||||
"label": f.label,
|
||||
"field_type": f.field_type,
|
||||
"required": f.required,
|
||||
"options": f.options,
|
||||
"default_value": f.default_value,
|
||||
"is_filterable": f.is_filterable,
|
||||
}
|
||||
for f in filter_fields
|
||||
]
|
||||
await cache_service.set_fields(tenant_id, cache_data)
|
||||
|
||||
total_time = (time.time() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[AC-MRS-11] Total time for tenant={tenant_id}: {total_time:.2f}ms "
|
||||
f"(DB: {db_time:.2f}ms)"
|
||||
)
|
||||
|
||||
return filter_fields
|
||||
|
||||
def _build_field_filter(
|
||||
self,
|
||||
field_info: FilterFieldInfo,
|
||||
|
|
|
|||
|
|
@ -175,7 +175,9 @@ class RoleBasedFieldProvider:
|
|||
"slot_key": slot.slot_key,
|
||||
"type": slot.type,
|
||||
"required": slot.required,
|
||||
# [AC-MRS-07-UPGRADE] 返回新旧字段
|
||||
"extract_strategy": slot.extract_strategy,
|
||||
"extract_strategies": slot.extract_strategies,
|
||||
"validation_rule": slot.validation_rule,
|
||||
"ask_back_prompt": slot.ask_back_prompt,
|
||||
"default_value": slot.default_value,
|
||||
|
|
@ -217,7 +219,9 @@ class RoleBasedFieldProvider:
|
|||
"slot_key": field.field_key,
|
||||
"type": field.type,
|
||||
"required": field.required,
|
||||
# [AC-MRS-07-UPGRADE] 返回新旧字段
|
||||
"extract_strategy": None,
|
||||
"extract_strategies": None,
|
||||
"validation_rule": None,
|
||||
"ask_back_prompt": None,
|
||||
"default_value": field.default_value,
|
||||
|
|
|
|||
Loading…
Reference in New Issue