ai-robot-core/ai-service/app/services/retrieval/mode_router.py

439 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Mode Router for RAG Runtime Mode Selection.
[AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11] Routes to direct/react/auto mode.
Mode Descriptions:
- direct: Low-latency generic retrieval path (single KB call)
- react: Multi-step ReAct retrieval path (high accuracy)
- auto: Automatic selection based on complexity/confidence rules
"""
from __future__ import annotations
import logging
import re
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from app.services.retrieval.routing_config import (
RagRuntimeMode,
RoutingConfig,
StrategyContext,
)
if TYPE_CHECKING:
from app.services.retrieval.base import RetrievalResult
logger = logging.getLogger(__name__)
@dataclass
class ComplexityAnalyzer:
    """
    Analyzes query complexity for mode routing decisions.

    Complexity factors:
    - Query length
    - Number of conditions/constraints
    - Presence of logical operators (and, or, not)
    - Cross-domain indicators
    - Multi-step reasoning requirements

    All pattern lists hold regular-expression source strings that are matched
    against the raw query text with the ``re`` module.
    """

    # Queries shorter than this many characters get no length contribution.
    short_query_threshold: int = 20
    # Queries longer than this many characters get the full length contribution.
    long_query_threshold: int = 100
    # Connectives that indicate several conditions/constraints in one query.
    condition_patterns: list[str] = field(default_factory=lambda: [
        r"和|与|及|并且|同时",      # conjunctions: and / as well as / meanwhile
        r"或者|还是|要么",          # alternatives: or / either
        r"但是|不过|然而",          # contrast: but / however
        r"如果|假如|假设",          # conditionals: if / suppose
        r"既.*又",                  # "both ... and"
        r"不仅.*而且",              # "not only ... but also"
    ])
    # Wording that suggests the answer needs reasoning rather than lookup.
    reasoning_patterns: list[str] = field(default_factory=lambda: [
        r"为什么|原因|理由",        # why / cause / reason
        r"怎么|如何|怎样",          # how
        r"区别|差异|不同",          # difference
        r"比较|对比|优劣",          # compare / pros and cons
        r"分析|评估|判断",          # analyze / evaluate / judge
    ])
    # Wording that suggests the query spans several domains or the whole data set.
    cross_domain_patterns: list[str] = field(default_factory=lambda: [
        r"跨|多|各个",              # cross- / multi- / each
        r"所有|全部|整体",          # all / entire / overall
        r"综合|汇总|统计",          # combined / summary / statistics
    ])

    def analyze(self, query: str) -> float:
        """
        Analyze query complexity and return a score (0.0 ~ 1.0).

        Higher score indicates more complex query that may benefit from ReAct mode.
        The score is the sum of these contributions, capped at 1.0:

        - length: 0.0 (short), 0.15 (medium), 0.3 (long)
        - condition matches: 0.1 / 0.2 / 0.3 for 1 / 2 / 3+ matches
        - any reasoning pattern present: 0.15
        - any cross-domain pattern present: 0.15
        - two or more question marks (ASCII or full-width): 0.1

        Args:
            query: User query text

        Returns:
            Complexity score (0.0 = simple, 1.0 = very complex)
        """
        if not query:
            return 0.0

        score = 0.0

        # Length contribution.
        query_length = len(query)
        if query_length > self.long_query_threshold:
            score += 0.3
        elif query_length >= self.short_query_threshold:
            score += 0.15

        # Every individual match counts, so one pattern can contribute several.
        condition_count = sum(
            len(re.findall(pattern, query)) for pattern in self.condition_patterns
        )
        if condition_count >= 3:
            score += 0.3
        elif condition_count >= 2:
            score += 0.2
        elif condition_count >= 1:
            score += 0.1

        # Reasoning and cross-domain signals are binary: one hit is enough.
        if any(re.search(pattern, query) for pattern in self.reasoning_patterns):
            score += 0.15
        if any(re.search(pattern, query) for pattern in self.cross_domain_patterns):
            score += 0.15

        # Count ASCII "?" and the full-width question mark (U+FF1F). The latter
        # is written as an escape on purpose: counting an empty string would
        # return len(query) + 1 and make this bonus fire for every query.
        question_marks = query.count("?") + query.count("\uff1f")
        if question_marks >= 2:
            score += 0.1

        return min(1.0, score)
@dataclass
class ModeRouteResult:
    """Result from mode routing decision."""

    # Mode that was selected (or, after a fallback, the mode that actually ran).
    mode: RagRuntimeMode
    # Confidence associated with the decision: the context's metadata
    # confidence from route(), or the best direct-hit score after a fallback.
    confidence: float
    # Complexity score the decision was based on.
    complexity_score: float
    # True when a direct attempt was abandoned in favour of ReAct.
    should_fallback_to_react: bool = False
    # Short machine-readable reason for the fallback, if one happened.
    fallback_reason: str | None = None
    # Free-form details about how the decision was reached.
    diagnostics: dict[str, Any] = field(default_factory=dict)
class DirectRetrievalExecutor:
    """
    [AC-AISVC-RES-09] Executor for the low-latency direct path.

    Each request results in exactly one ``retrieve`` call on the optimized
    retriever; no multi-step reasoning takes place here.
    """

    def __init__(self):
        # Resolved on the first execute() call and reused afterwards.
        self._retriever = None

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """
        Run a single retrieval for the given strategy context.

        Args:
            ctx: Strategy context supplying tenant_id, query, metadata_filter
                and kb_ids for the retrieval call.

        Returns:
            Whatever the optimized retriever's ``retrieve`` returns.
        """
        # Imported at call time, as in the rest of this module's executors.
        from app.services.retrieval.optimized_retriever import get_optimized_retriever
        from app.services.retrieval.base import RetrievalContext

        if self._retriever is None:
            self._retriever = await get_optimized_retriever()

        return await self._retriever.retrieve(
            RetrievalContext(
                tenant_id=ctx.tenant_id,
                query=ctx.query,
                metadata_filter=ctx.metadata_filter,
                kb_ids=ctx.kb_ids,
            )
        )
class ReactRetrievalExecutor:
    """
    [AC-AISVC-RES-10] Executor for the multi-step ReAct path.

    Delegates the reasoning loop and KB tool calls to AgentOrchestrator.
    """

    def __init__(self, max_steps: int = 5):
        # Upper bound on iterations; the routing config may lower it further.
        self._max_steps = max_steps

    async def execute(
        self,
        ctx: StrategyContext,
        config: RoutingConfig,
    ) -> tuple[str, "RetrievalResult | None", dict[str, Any]]:
        """
        Run the ReAct loop for the given strategy context.

        Args:
            ctx: Strategy context (query, tenant, filters, extra context).
            config: Routing configuration; its ``react_max_steps`` is combined
                with this executor's own limit, the smaller one wins.

        Returns:
            Tuple of (final_answer, retrieval_result, react_context). The
            retrieval_result element is always None on this path.
        """
        from app.services.mid.agent_orchestrator import AgentOrchestrator, AgentMode
        from app.services.mid.tool_registry import ToolRegistry
        from app.services.mid.timeout_governor import TimeoutGovernor
        from app.services.llm.factory import get_llm_config_manager

        # Best effort: the orchestrator is still built when no client is available.
        try:
            llm_client = get_llm_config_manager().get_client()
        except Exception as exc:
            logger.warning(f"[ModeRouter] Failed to get LLM client: {exc}")
            llm_client = None

        # The registry and the orchestrator each receive their own governor.
        registry = ToolRegistry(timeout_governor=TimeoutGovernor())
        governor = TimeoutGovernor()
        step_limit = min(config.react_max_steps, self._max_steps)

        orchestrator = AgentOrchestrator(
            max_iterations=step_limit,
            timeout_governor=governor,
            llm_client=llm_client,
            tool_registry=registry,
            tenant_id=ctx.tenant_id,
            mode=AgentMode.FUNCTION_CALLING,
        )

        # Entries from additional_context come last and therefore override
        # the three base keys on collision.
        orchestrator_context = {
            "query": ctx.query,
            "metadata_filter": ctx.metadata_filter,
            "kb_ids": ctx.kb_ids,
            **ctx.additional_context,
        }

        final_answer, react_ctx, _trace = await orchestrator.execute(
            user_message=ctx.query,
            context=orchestrator_context,
        )

        iterations = react_ctx.iteration
        if react_ctx.tool_calls:
            tool_calls = [call.model_dump() for call in react_ctx.tool_calls]
        else:
            tool_calls = []

        summary = {
            "iterations": iterations,
            "tool_calls": tool_calls,
            "final_answer": final_answer,
        }
        return final_answer, None, summary
class ModeRouter:
    """
    [AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11]
    Chooses the RAG runtime mode for a request and runs it.

    Mode Selection:
    - direct: Low-latency generic retrieval (single KB call)
    - react: Multi-step ReAct retrieval (high accuracy)
    - auto: Automatic selection based on complexity/confidence

    Auto Mode Rules:
    - Direct conditions:
        - Short query, clear intent
        - High metadata confidence
        - No cross-domain/multi-condition
    - React conditions:
        - Multi-condition/multi-constraint
        - Low metadata confidence
        - Need for secondary confirmation or multi-step reasoning
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
    ):
        # A default RoutingConfig is created when none (or a falsy one) is given.
        self._config = config or RoutingConfig()
        self._complexity_analyzer = ComplexityAnalyzer()
        self._direct_executor = DirectRetrievalExecutor()
        self._react_executor = ReactRetrievalExecutor(
            max_steps=self._config.react_max_steps
        )

    @property
    def config(self) -> RoutingConfig:
        """Routing configuration currently in effect."""
        return self._config

    def update_config(self, new_config: RoutingConfig) -> None:
        """
        [AC-AISVC-RES-15] Swap in a new routing configuration (hot reload).

        The ReAct executor's step limit is updated to match the new config.
        """
        self._config = new_config
        self._react_executor._max_steps = new_config.react_max_steps

        message = (
            f"[AC-AISVC-RES-15] ModeRouter config updated: "
            f"mode={new_config.rag_runtime_mode.value}, "
            f"react_max_steps={new_config.react_max_steps}, "
            f"confidence_threshold={new_config.react_trigger_confidence_threshold}"
        )
        logger.info(message)

    def _forced_mode_result(
        self,
        ctx: StrategyContext,
        mode: RagRuntimeMode,
        label: str,
    ) -> ModeRouteResult:
        """Build the result for a mode that is fixed by configuration."""
        return ModeRouteResult(
            mode=mode,
            confidence=ctx.metadata_confidence,
            complexity_score=ctx.complexity_score,
            diagnostics={"configured_mode": label},
        )

    def route(
        self,
        ctx: StrategyContext,
    ) -> ModeRouteResult:
        """
        [AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11]
        Decide which mode should handle the request.

        A configured direct or react mode is returned as-is. Otherwise (auto)
        the query is analyzed, the larger of the analyzed and the provided
        complexity score is used, and the config decides whether ReAct is
        triggered.

        Args:
            ctx: Strategy context with query, metadata, confidence, etc.

        Returns:
            ModeRouteResult with selected mode and diagnostics
        """
        configured = self._config.get_rag_runtime_mode()

        if configured == RagRuntimeMode.DIRECT:
            logger.info(
                f"[AC-AISVC-RES-09] Mode routing to DIRECT: tenant={ctx.tenant_id}"
            )
            return self._forced_mode_result(ctx, RagRuntimeMode.DIRECT, "direct")

        if configured == RagRuntimeMode.REACT:
            logger.info(
                f"[AC-AISVC-RES-10] Mode routing to REACT: tenant={ctx.tenant_id}"
            )
            return self._forced_mode_result(ctx, RagRuntimeMode.REACT, "react")

        # Auto mode: combine the analyzer's score with any score the caller
        # supplied and keep the higher of the two.
        analyzed = self._complexity_analyzer.analyze(ctx.query)
        effective = max(analyzed, ctx.complexity_score)

        wants_react = self._config.should_trigger_react_in_auto_mode(
            confidence=ctx.metadata_confidence,
            complexity_score=effective,
        )
        if wants_react:
            chosen = RagRuntimeMode.REACT
        else:
            chosen = RagRuntimeMode.DIRECT

        logger.info(
            f"[AC-AISVC-RES-11, AC-AISVC-RES-12, AC-AISVC-RES-13] "
            f"Auto mode routing: selected={chosen.value}, "
            f"confidence={ctx.metadata_confidence:.2f}, "
            f"complexity={effective:.2f}, "
            f"tenant={ctx.tenant_id}"
        )

        diagnostics = {
            "configured_mode": "auto",
            "analyzed_complexity": analyzed,
            "provided_complexity": ctx.complexity_score,
            "react_trigger_confidence": self._config.react_trigger_confidence_threshold,
            "react_trigger_complexity": self._config.react_trigger_complexity_score,
        }
        return ModeRouteResult(
            mode=chosen,
            confidence=ctx.metadata_confidence,
            complexity_score=effective,
            diagnostics=diagnostics,
        )

    async def execute_direct(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """Run the direct retrieval executor for this context."""
        return await self._direct_executor.execute(ctx)

    async def execute_react(
        self,
        ctx: StrategyContext,
    ) -> tuple[str, "RetrievalResult | None", dict[str, Any]]:
        """Run the ReAct executor for this context with the current config."""
        return await self._react_executor.execute(ctx, self._config)

    async def execute_with_fallback(
        self,
        ctx: StrategyContext,
    ) -> tuple["RetrievalResult | None", str | None, ModeRouteResult]:
        """
        [AC-AISVC-RES-14] Route, execute, and fall back from direct to react
        when the direct result's confidence is too low.

        Args:
            ctx: Strategy context

        Returns:
            Tuple of (RetrievalResult or None, final_answer or None, ModeRouteResult).
            A direct success fills the first element; any ReAct run fills the
            second element and leaves the first as None.
        """
        decision = self.route(ctx)

        # Anything that is not routed to direct goes straight to ReAct.
        if decision.mode != RagRuntimeMode.DIRECT:
            answer, _, _ = await self._react_executor.execute(ctx, self._config)
            return None, answer, decision

        retrieval_result = await self._direct_executor.execute(ctx)

        # Confidence of the direct attempt is its best hit score (0.0 if no hits).
        if retrieval_result and retrieval_result.hits:
            top_score = max((hit.score for hit in retrieval_result.hits), default=0.0)
        else:
            top_score = 0.0

        if not self._config.should_fallback_direct_to_react(top_score):
            return retrieval_result, None, decision

        logger.info(
            f"[AC-AISVC-RES-14] Direct mode low confidence fallback to react: "
            f"confidence={top_score:.2f}, threshold={self._config.direct_fallback_confidence_threshold}"
        )
        answer, _, _ = await self._react_executor.execute(ctx, self._config)

        fallback_decision = ModeRouteResult(
            mode=RagRuntimeMode.REACT,
            confidence=top_score,
            complexity_score=decision.complexity_score,
            should_fallback_to_react=True,
            fallback_reason="low_confidence",
            diagnostics={
                **decision.diagnostics,
                "fallback_from": "direct",
                "direct_confidence": top_score,
            },
        )
        return None, answer, fallback_decision
# Process-wide router instance, created on first use.
_mode_router: ModeRouter | None = None


def get_mode_router() -> ModeRouter:
    """Return the shared ModeRouter, building it with defaults on first call."""
    global _mode_router
    router = _mode_router
    if router is None:
        router = _mode_router = ModeRouter()
    return router
def reset_mode_router() -> None:
    """
    Reset ModeRouter singleton (for testing).

    Clears the module-level instance so the next get_mode_router() call
    constructs a fresh ModeRouter.
    """
    global _mode_router
    _mode_router = None