ai-robot-core/ai-service/app/services/retrieval/mode_router.py

439 lines
14 KiB
Python
Raw Normal View History

"""
Mode Router for RAG Runtime Mode Selection.
[AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11] Routes to direct/react/auto mode.
Mode Descriptions:
- direct: Low-latency generic retrieval path (single KB call)
- react: Multi-step ReAct retrieval path (high accuracy)
- auto: Automatic selection based on complexity/confidence rules
"""
from __future__ import annotations
import logging
import re
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from app.services.retrieval.routing_config import (
RagRuntimeMode,
RoutingConfig,
StrategyContext,
)
if TYPE_CHECKING:
from app.services.retrieval.base import RetrievalResult
# Module-level logger used by the router and both executors below.
logger = logging.getLogger(__name__)
@dataclass
class ComplexityAnalyzer:
    """
    Analyzes query complexity for mode routing decisions.

    The score is a sum of heuristic components, capped at 1.0:

    - Query length: +0.0 below ``short_query_threshold`` characters,
      +0.3 above ``long_query_threshold``, +0.15 otherwise.
    - Condition/constraint connectives (``condition_patterns``): +0.1 / +0.2 /
      +0.3 for 1 / 2 / 3-or-more total regex matches.
    - Reasoning cues (``reasoning_patterns``): +0.15 if any pattern matches.
    - Cross-domain / aggregate cues (``cross_domain_patterns``): +0.15 if any
      pattern matches.
    - Two or more question marks (ASCII "?" or full-width "？"): +0.1.

    The default patterns are written for Chinese-language queries.
    """

    short_query_threshold: int = 20
    long_query_threshold: int = 100
    # Connectives: and/with/as well as/moreover/at the same time; or/either;
    # but/however/yet; if/suppose/assume; "both ... and"; "not only ... but also".
    condition_patterns: list[str] = field(default_factory=lambda: [
        r"和|与|及|并且|同时",
        r"或者|还是|要么",
        r"但是|不过|然而",
        r"如果|假如|假设",
        r"既.*又",
        r"不仅.*而且",
    ])
    # Reasoning cues: why/reason/rationale; how; difference;
    # compare/contrast/pros and cons; analyze/evaluate/judge.
    reasoning_patterns: list[str] = field(default_factory=lambda: [
        r"为什么|原因|理由",
        r"怎么|如何|怎样",
        r"区别|差异|不同",
        r"比较|对比|优劣",
        r"分析|评估|判断",
    ])
    # Cross-domain cues: across/multiple/each; all/entire/overall;
    # comprehensive/summary/statistics.
    cross_domain_patterns: list[str] = field(default_factory=lambda: [
        r"跨|多|各个",
        r"所有|全部|整体",
        r"综合|汇总|统计",
    ])

    def analyze(self, query: str) -> float:
        """
        Score the complexity of ``query`` in the range 0.0 ~ 1.0.

        Higher scores indicate queries that may benefit from ReAct mode.

        Args:
            query: User query text.

        Returns:
            Complexity score (0.0 = simple, capped at 1.0). An empty query
            scores 0.0.
        """
        if not query:
            return 0.0

        score = 0.0

        # Length component. The short threshold is checked first, so a query
        # below it contributes nothing even if thresholds are misconfigured.
        query_length = len(query)
        if query_length >= self.short_query_threshold:
            score += 0.3 if query_length > self.long_query_threshold else 0.15

        # Condition component: total matches across all condition patterns.
        condition_count = sum(
            len(re.findall(pattern, query)) for pattern in self.condition_patterns
        )
        if condition_count >= 3:
            score += 0.3
        elif condition_count >= 2:
            score += 0.2
        elif condition_count >= 1:
            score += 0.1

        # Reasoning and cross-domain components count at most once each.
        if any(re.search(pattern, query) for pattern in self.reasoning_patterns):
            score += 0.15
        if any(re.search(pattern, query) for pattern in self.cross_domain_patterns):
            score += 0.15

        # Count ASCII and full-width question marks. Counting the empty string
        # here would yield len(query) + 1 and wrongly reward every query.
        question_marks = query.count("?") + query.count("？")
        if question_marks >= 2:
            score += 0.1

        return min(1.0, score)
@dataclass
class ModeRouteResult:
    """
    Outcome of one mode-routing decision.

    Built by ``ModeRouter.route`` and, when a fallback happens, by
    ``ModeRouter.execute_with_fallback``.
    """

    # Runtime mode that was selected.
    mode: RagRuntimeMode
    # Confidence value attached to the decision (metadata confidence, or the
    # best direct-retrieval score after a fallback).
    confidence: float
    # Complexity score the decision was based on.
    complexity_score: float
    # True when a direct run was replaced by a react run.
    should_fallback_to_react: bool = field(default=False)
    # Short reason code for the fallback, if any.
    fallback_reason: str | None = field(default=None)
    # Free-form details about how the decision was reached.
    diagnostics: dict[str, Any] = field(default_factory=dict)
class DirectRetrievalExecutor:
    """
    [AC-AISVC-RES-09] Low-latency executor: a single knowledge-base retrieval
    call, with no multi-step reasoning.

    The optimized retriever is fetched on first use and then reused by this
    instance.
    """

    def __init__(self):
        # Filled in lazily by the first execute() call.
        self._retriever = None

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """
        Run one retrieval for ``ctx``.

        Args:
            ctx: Strategy context providing tenant id, query, metadata filter
                and knowledge-base ids.

        Returns:
            Whatever the optimized retriever's ``retrieve`` returns.
        """
        from app.services.retrieval.optimized_retriever import get_optimized_retriever
        from app.services.retrieval.base import RetrievalContext

        retriever = self._retriever
        if retriever is None:
            retriever = await get_optimized_retriever()
            self._retriever = retriever

        request = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.query,
            metadata_filter=ctx.metadata_filter,
            kb_ids=ctx.kb_ids,
        )
        return await retriever.retrieve(request)
class ReactRetrievalExecutor:
    """
    [AC-AISVC-RES-10] ReAct retrieval executor for the multi-step path.

    Runs an ``AgentOrchestrator`` in function-calling mode to perform
    multi-step reasoning over the query.
    """

    def __init__(self, max_steps: int = 5):
        """
        Args:
            max_steps: Upper bound on orchestrator iterations. The effective
                limit per call is ``min(config.react_max_steps, max_steps)``.
        """
        self._max_steps = max_steps

    async def execute(
        self,
        ctx: StrategyContext,
        config: RoutingConfig,
    ) -> tuple[str, "RetrievalResult | None", dict[str, Any]]:
        """
        Execute ReAct retrieval (multi-step reasoning).

        Args:
            ctx: Strategy context. Its query, metadata filter, KB ids and
                ``additional_context`` entries are passed to the orchestrator.
            config: Routing configuration; ``react_max_steps`` caps iterations.

        Returns:
            Tuple of (final_answer, retrieval_result, react_context).
            retrieval_result is always ``None`` on this path. react_context
            has keys ``iterations``, ``tool_calls`` (each a ``model_dump()``
            dict) and ``final_answer``.
        """
        from app.services.mid.agent_orchestrator import AgentOrchestrator, AgentMode
        from app.services.mid.tool_registry import ToolRegistry
        from app.services.mid.timeout_governor import TimeoutGovernor
        from app.services.llm.factory import get_llm_config_manager

        try:
            llm_client = get_llm_config_manager().get_client()
        except Exception as e:
            # Deliberate best effort: log and continue with llm_client=None;
            # how a missing client is handled is left to AgentOrchestrator.
            logger.warning(f"[ModeRouter] Failed to get LLM client: {e}")
            llm_client = None

        # Use one governor for both the tool registry and the orchestrator,
        # rather than two independently constructed instances.
        timeout_governor = TimeoutGovernor()
        tool_registry = ToolRegistry(timeout_governor=timeout_governor)
        orchestrator = AgentOrchestrator(
            max_iterations=min(config.react_max_steps, self._max_steps),
            timeout_governor=timeout_governor,
            llm_client=llm_client,
            tool_registry=tool_registry,
            tenant_id=ctx.tenant_id,
            mode=AgentMode.FUNCTION_CALLING,
        )

        # additional_context is spread last, so its keys override query /
        # metadata_filter / kb_ids on collision.
        base_context = {
            "query": ctx.query,
            "metadata_filter": ctx.metadata_filter,
            "kb_ids": ctx.kb_ids,
            **ctx.additional_context,
        }

        final_answer, react_ctx, _trace = await orchestrator.execute(
            user_message=ctx.query,
            context=base_context,
        )

        # A missing or empty tool_calls collection is reported as [].
        tool_calls = [tc.model_dump() for tc in react_ctx.tool_calls or []]
        return final_answer, None, {
            "iterations": react_ctx.iteration,
            "tool_calls": tool_calls,
            "final_answer": final_answer,
        }
class ModeRouter:
    """
    [AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11]
    Chooses and runs the RAG runtime mode for a request.

    Modes:
        - direct: one knowledge-base retrieval call (low latency).
        - react: multi-step ReAct retrieval through the agent orchestrator
          (higher accuracy).
        - auto: decided per request from metadata confidence and query
          complexity.

    Auto-mode intent (the concrete rule is
    ``RoutingConfig.should_trigger_react_in_auto_mode``):
        - Lean direct: short query with clear intent, high metadata
          confidence, no cross-domain or multi-condition phrasing.
        - Lean react: multiple conditions/constraints, low metadata
          confidence, or a need for confirmation / multi-step reasoning.
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
    ):
        """
        Args:
            config: Routing configuration. A default ``RoutingConfig()`` is
                used when it is omitted or falsy.
        """
        self._config = config or RoutingConfig()
        self._complexity_analyzer = ComplexityAnalyzer()
        self._direct_executor = DirectRetrievalExecutor()
        # The react step cap mirrors config.react_max_steps and is re-synced
        # by update_config().
        self._react_executor = ReactRetrievalExecutor(
            max_steps=self._config.react_max_steps,
        )

    @property
    def config(self) -> RoutingConfig:
        """The routing configuration currently in effect."""
        return self._config

    def update_config(self, new_config: RoutingConfig) -> None:
        """
        [AC-AISVC-RES-15] Hot-reload the routing configuration.

        Installs ``new_config``, pushes its ``react_max_steps`` into the react
        executor, and logs the key settings.
        """
        self._config = new_config
        self._react_executor._max_steps = new_config.react_max_steps
        logger.info(
            "[AC-AISVC-RES-15] ModeRouter config updated: "
            f"mode={new_config.rag_runtime_mode.value}, "
            f"react_max_steps={new_config.react_max_steps}, "
            f"confidence_threshold={new_config.react_trigger_confidence_threshold}"
        )

    def route(
        self,
        ctx: StrategyContext,
    ) -> ModeRouteResult:
        """
        [AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11]
        Decide which runtime mode should serve ``ctx``.

        An explicitly configured DIRECT or REACT mode is honoured as-is. Any
        other configured mode is treated as auto.

        Args:
            ctx: Strategy context carrying the query, metadata confidence,
                caller-provided complexity score, tenant id, etc.

        Returns:
            ModeRouteResult describing the chosen mode plus diagnostics.
        """
        mode = self._config.get_rag_runtime_mode()

        if mode == RagRuntimeMode.DIRECT:
            logger.info(
                f"[AC-AISVC-RES-09] Mode routing to DIRECT: tenant={ctx.tenant_id}"
            )
            return self._pinned_route(RagRuntimeMode.DIRECT, "direct", ctx)

        if mode == RagRuntimeMode.REACT:
            logger.info(
                f"[AC-AISVC-RES-10] Mode routing to REACT: tenant={ctx.tenant_id}"
            )
            return self._pinned_route(RagRuntimeMode.REACT, "react", ctx)

        return self._auto_route(ctx)

    @staticmethod
    def _pinned_route(
        mode: RagRuntimeMode,
        label: str,
        ctx: StrategyContext,
    ) -> ModeRouteResult:
        """Build the result for an explicitly configured (non-auto) mode."""
        return ModeRouteResult(
            mode=mode,
            confidence=ctx.metadata_confidence,
            complexity_score=ctx.complexity_score,
            diagnostics={"configured_mode": label},
        )

    def _auto_route(self, ctx: StrategyContext) -> ModeRouteResult:
        """Auto mode: choose direct or react from confidence and complexity."""
        analyzed = self._complexity_analyzer.analyze(ctx.query)
        # The larger of the analyzer's score and the caller-provided score wins.
        effective = max(analyzed, ctx.complexity_score)

        use_react = self._config.should_trigger_react_in_auto_mode(
            confidence=ctx.metadata_confidence,
            complexity_score=effective,
        )
        chosen = RagRuntimeMode.REACT if use_react else RagRuntimeMode.DIRECT

        logger.info(
            "[AC-AISVC-RES-11, AC-AISVC-RES-12, AC-AISVC-RES-13] "
            f"Auto mode routing: selected={chosen.value}, "
            f"confidence={ctx.metadata_confidence:.2f}, "
            f"complexity={effective:.2f}, "
            f"tenant={ctx.tenant_id}"
        )

        return ModeRouteResult(
            mode=chosen,
            confidence=ctx.metadata_confidence,
            complexity_score=effective,
            diagnostics={
                "configured_mode": "auto",
                "analyzed_complexity": analyzed,
                "provided_complexity": ctx.complexity_score,
                "react_trigger_confidence": self._config.react_trigger_confidence_threshold,
                "react_trigger_complexity": self._config.react_trigger_complexity_score,
            },
        )

    async def execute_direct(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """Run the direct (single KB call) executor for ``ctx``."""
        return await self._direct_executor.execute(ctx)

    async def execute_react(
        self,
        ctx: StrategyContext,
    ) -> tuple[str, "RetrievalResult | None", dict[str, Any]]:
        """Run the ReAct executor for ``ctx`` with the current configuration."""
        return await self._react_executor.execute(ctx, self._config)

    async def execute_with_fallback(
        self,
        ctx: StrategyContext,
    ) -> tuple["RetrievalResult | None", str | None, ModeRouteResult]:
        """
        [AC-AISVC-RES-14] Route ``ctx`` and execute it, falling back from
        direct to react when the direct result's best hit score is low.

        The fallback decision is delegated to
        ``RoutingConfig.should_fallback_direct_to_react(best_score)``, where
        best_score is the highest hit score (0.0 when there are no hits).

        Args:
            ctx: Strategy context.

        Returns:
            - Direct result kept: ``(retrieval_result, None, route_result)``.
            - React (routed or fallback): ``(None, final_answer, result)``. On
              fallback the result has mode REACT,
              ``should_fallback_to_react=True`` and
              ``fallback_reason="low_confidence"``.
        """
        route_result = self.route(ctx)

        if route_result.mode != RagRuntimeMode.DIRECT:
            final_answer, _, _ = await self._react_executor.execute(ctx, self._config)
            return None, final_answer, route_result

        retrieval_result = await self._direct_executor.execute(ctx)

        hits = retrieval_result.hits if retrieval_result else None
        top_score = max((hit.score for hit in hits), default=0.0) if hits else 0.0

        if not self._config.should_fallback_direct_to_react(top_score):
            return retrieval_result, None, route_result

        logger.info(
            "[AC-AISVC-RES-14] Direct mode low confidence fallback to react: "
            f"confidence={top_score:.2f}, threshold={self._config.direct_fallback_confidence_threshold}"
        )
        final_answer, _, _ = await self._react_executor.execute(ctx, self._config)

        fallback_route = ModeRouteResult(
            mode=RagRuntimeMode.REACT,
            confidence=top_score,
            complexity_score=route_result.complexity_score,
            should_fallback_to_react=True,
            fallback_reason="low_confidence",
            diagnostics={
                **route_result.diagnostics,
                "fallback_from": "direct",
                "direct_confidence": top_score,
            },
        )
        return None, final_answer, fallback_route
# Lazily created module-level router shared by get_mode_router() callers.
_mode_router: ModeRouter | None = None


def get_mode_router() -> ModeRouter:
    """
    Return the shared ModeRouter, constructing it with defaults on first use.

    Creation is not guarded by a lock.
    """
    global _mode_router
    router = _mode_router
    if router is None:
        router = ModeRouter()
        _mode_router = router
    return router


def reset_mode_router() -> None:
    """
    Discard the shared ModeRouter so the next get_mode_router() call builds a
    new one (intended for tests).
    """
    global _mode_router
    _mode_router = None