439 lines
14 KiB
Python
439 lines
14 KiB
Python
"""
|
||
Mode Router for RAG Runtime Mode Selection.
|
||
[AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11] Routes to direct/react/auto mode.
|
||
|
||
Mode Descriptions:
|
||
- direct: Low-latency generic retrieval path (single KB call)
|
||
- react: Multi-step ReAct retrieval path (high accuracy)
|
||
- auto: Automatic selection based on complexity/confidence rules
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import logging
|
||
import re
|
||
import time
|
||
from dataclasses import dataclass, field
|
||
from typing import TYPE_CHECKING, Any
|
||
|
||
from app.services.retrieval.routing_config import (
|
||
RagRuntimeMode,
|
||
RoutingConfig,
|
||
StrategyContext,
|
||
)
|
||
|
||
if TYPE_CHECKING:
|
||
from app.services.retrieval.base import RetrievalResult
|
||
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
@dataclass
class ComplexityAnalyzer:
    """
    Heuristic scorer that estimates how complex a user query is.

    The score feeds mode routing: a higher score suggests the query may
    benefit from the multi-step ReAct path.

    Signals that contribute to the score:
    - Query length (character count)
    - Number of condition/constraint connectives
    - Presence of reasoning keywords (why / how / compare / analyze ...)
    - Presence of cross-domain keywords (across / all / summarize ...)
    - Number of question marks (ASCII "?" and full-width "？")

    All keyword patterns are regular expressions matched against the raw
    query text; the defaults target Chinese-language queries.
    """

    # Queries shorter than this many characters get no length bonus.
    short_query_threshold: int = 20
    # Queries longer than this many characters get the full length bonus.
    long_query_threshold: int = 100

    # Connectives that signal multiple conditions or constraints.
    condition_patterns: list[str] = field(default_factory=lambda: [
        r"和|与|及|并且|同时",      # and / together with / at the same time
        r"或者|还是|要么",          # or / either
        r"但是|不过|然而",          # but / however
        r"如果|假如|假设",          # if / suppose
        r"既.*又",                  # both ... and
        r"不仅.*而且",              # not only ... but also
    ])

    # Keywords that signal a need for explanation or comparison.
    reasoning_patterns: list[str] = field(default_factory=lambda: [
        r"为什么|原因|理由",        # why / cause / reason
        r"怎么|如何|怎样",          # how
        r"区别|差异|不同",          # difference
        r"比较|对比|优劣",          # compare / contrast / pros and cons
        r"分析|评估|判断",          # analyze / evaluate / judge
    ])

    # Keywords that signal a question spanning several domains or sources.
    cross_domain_patterns: list[str] = field(default_factory=lambda: [
        r"跨|多|各个",              # across / multiple / each
        r"所有|全部|整体",          # all / entire / overall
        r"综合|汇总|统计",          # combined / summarize / statistics
    ])

    def analyze(self, query: str) -> float:
        """
        Analyze query complexity and return a score (0.0 ~ 1.0).

        Higher score indicates a more complex query that may benefit from
        ReAct mode.

        Args:
            query: User query text.

        Returns:
            Complexity score (0.0 = simple, 1.0 = very complex). An empty
            query always scores 0.0.
        """
        if not query:
            return 0.0

        score = 0.0

        # Length: nothing for short queries, 0.15 for medium, 0.3 for long.
        query_length = len(query)
        if query_length >= self.short_query_threshold:
            score += 0.3 if query_length > self.long_query_threshold else 0.15

        # Conditions: every occurrence counts, capped at three tiers.
        condition_count = sum(
            len(re.findall(pattern, query)) for pattern in self.condition_patterns
        )
        if condition_count >= 3:
            score += 0.3
        elif condition_count >= 2:
            score += 0.2
        elif condition_count >= 1:
            score += 0.1

        # Reasoning / cross-domain keywords: a single flat bonus each, no
        # matter how many patterns match.
        if any(re.search(pattern, query) for pattern in self.reasoning_patterns):
            score += 0.15

        if any(re.search(pattern, query) for pattern in self.cross_domain_patterns):
            score += 0.15

        # Multiple questions in one message. Count the ASCII and the
        # full-width question mark once each; counting "?" twice would make
        # a single question look like two and would miss "？" entirely.
        question_marks = query.count("?") + query.count("？")
        if question_marks >= 2:
            score += 0.1

        return min(1.0, score)
|
||
|
||
|
||
@dataclass
class ModeRouteResult:
    """Outcome of a single mode-routing decision."""

    # Runtime mode that was selected for the request.
    mode: RagRuntimeMode
    # Confidence value attached to the decision.
    confidence: float
    # Complexity score attached to the decision.
    complexity_score: float
    # True when the result describes a fallback to the ReAct path.
    should_fallback_to_react: bool = field(default=False)
    # Short reason code for the fallback, or None when no fallback happened.
    fallback_reason: str | None = field(default=None)
    # Free-form diagnostic values; each instance gets its own dict.
    diagnostics: dict[str, Any] = field(default_factory=dict)
|
||
|
||
|
||
class DirectRetrievalExecutor:
    """
    [AC-AISVC-RES-09] Executor for the low-latency direct retrieval path.

    Performs one retriever call per request, with no multi-step reasoning.
    """

    def __init__(self):
        # The retriever is obtained lazily on the first execute() call.
        self._retriever = None

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """
        Run a direct retrieval (single KB call).

        Args:
            ctx: Strategy context; its tenant_id, query, metadata_filter and
                kb_ids are forwarded to the retriever.

        Returns:
            Whatever the retriever's ``retrieve`` coroutine returns.
        """
        # Imported here rather than at module level, as in the original.
        from app.services.retrieval.optimized_retriever import get_optimized_retriever
        from app.services.retrieval.base import RetrievalContext

        if self._retriever is None:
            self._retriever = await get_optimized_retriever()
        retriever = self._retriever

        request = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.query,
            metadata_filter=ctx.metadata_filter,
            kb_ids=ctx.kb_ids,
        )
        return await retriever.retrieve(request)
|
||
|
||
|
||
class ReactRetrievalExecutor:
    """
    [AC-AISVC-RES-10] Executor for the multi-step ReAct retrieval path.

    Delegates the reasoning loop and KB calls to AgentOrchestrator.
    """

    def __init__(self, max_steps: int = 5):
        # Upper bound on iterations; combined with the routing config's
        # limit at execution time.
        self._max_steps = max_steps

    async def execute(
        self,
        ctx: StrategyContext,
        config: RoutingConfig,
    ) -> tuple[str, "RetrievalResult | None", dict[str, Any]]:
        """
        Run a ReAct retrieval (multi-step reasoning).

        Args:
            ctx: Strategy context supplying the query, filters, KB ids,
                tenant and any additional context entries.
            config: Routing configuration; its ``react_max_steps`` is
                combined with this executor's own limit (the smaller wins).

        Returns:
            Tuple of (final_answer, retrieval_result, react_context). The
            retrieval_result slot is always None in this implementation;
            react_context carries the iteration count, the dumped tool
            calls and the final answer.
        """
        from app.services.mid.agent_orchestrator import AgentOrchestrator, AgentMode
        from app.services.mid.tool_registry import ToolRegistry
        from app.services.mid.timeout_governor import TimeoutGovernor
        from app.services.llm.factory import get_llm_config_manager

        # Best effort: a missing LLM client is logged and passed on as None.
        try:
            llm_client = get_llm_config_manager().get_client()
        except Exception as e:
            logger.warning(f"[ModeRouter] Failed to get LLM client: {e}")
            llm_client = None

        # The registry and the orchestrator each receive their own governor.
        registry = ToolRegistry(timeout_governor=TimeoutGovernor())
        governor = TimeoutGovernor()

        step_budget = min(config.react_max_steps, self._max_steps)
        orchestrator = AgentOrchestrator(
            max_iterations=step_budget,
            timeout_governor=governor,
            llm_client=llm_client,
            tool_registry=registry,
            tenant_id=ctx.tenant_id,
            mode=AgentMode.FUNCTION_CALLING,
        )

        # Additional context entries are spread last, so they win over the
        # three base keys on a name clash.
        agent_context = {
            "query": ctx.query,
            "metadata_filter": ctx.metadata_filter,
            "kb_ids": ctx.kb_ids,
            **ctx.additional_context,
        }

        final_answer, react_ctx, _trace = await orchestrator.execute(
            user_message=ctx.query,
            context=agent_context,
        )

        summary = {
            "iterations": react_ctx.iteration,
            "tool_calls": [tc.model_dump() for tc in (react_ctx.tool_calls or [])],
            "final_answer": final_answer,
        }
        return final_answer, None, summary
|
||
|
||
|
||
class ModeRouter:
    """
    [AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11]
    Mode router for RAG runtime mode selection.

    Mode Selection:
    - direct: Low-latency generic retrieval (single KB call)
    - react: Multi-step ReAct retrieval (high accuracy)
    - auto: Automatic selection based on complexity/confidence

    Auto Mode Rules:
    - Direct conditions:
      - Short query, clear intent
      - High metadata confidence
      - No cross-domain/multi-condition
    - React conditions:
      - Multi-condition/multi-constraint
      - Low metadata confidence
      - Need for secondary confirmation or multi-step reasoning

    The actual auto-mode decision is delegated to
    ``RoutingConfig.should_trigger_react_in_auto_mode``.
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
    ):
        self._config = config or RoutingConfig()
        self._complexity_analyzer = ComplexityAnalyzer()
        self._direct_executor = DirectRetrievalExecutor()
        self._react_executor = ReactRetrievalExecutor(
            max_steps=self._config.react_max_steps
        )

    @property
    def config(self) -> RoutingConfig:
        """Return the routing configuration currently in use."""
        return self._config

    def update_config(self, new_config: RoutingConfig) -> None:
        """
        [AC-AISVC-RES-15] Replace the routing configuration (hot reload).

        Also pushes the new ``react_max_steps`` into the ReAct executor.
        """
        self._config = new_config
        self._react_executor._max_steps = new_config.react_max_steps

        logger.info(
            f"[AC-AISVC-RES-15] ModeRouter config updated: "
            f"mode={new_config.rag_runtime_mode.value}, "
            f"react_max_steps={new_config.react_max_steps}, "
            f"confidence_threshold={new_config.react_trigger_confidence_threshold}"
        )

    @staticmethod
    def _configured_mode_result(mode, ctx, label) -> ModeRouteResult:
        """Build the result for a mode that is fixed by configuration."""
        return ModeRouteResult(
            mode=mode,
            confidence=ctx.metadata_confidence,
            complexity_score=ctx.complexity_score,
            diagnostics={"configured_mode": label},
        )

    def route(
        self,
        ctx: StrategyContext,
    ) -> ModeRouteResult:
        """
        [AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11]
        Pick the runtime mode from the configuration and the request context.

        A configured DIRECT or REACT mode is returned as-is. Otherwise the
        query is scored by the complexity analyzer, the larger of that score
        and ``ctx.complexity_score`` is used, and the configuration decides
        whether ReAct is triggered.

        Args:
            ctx: Strategy context with query, metadata, confidence, etc.

        Returns:
            ModeRouteResult with selected mode and diagnostics
        """
        configured_mode = self._config.get_rag_runtime_mode()

        if configured_mode == RagRuntimeMode.DIRECT:
            logger.info(
                f"[AC-AISVC-RES-09] Mode routing to DIRECT: tenant={ctx.tenant_id}"
            )
            return self._configured_mode_result(RagRuntimeMode.DIRECT, ctx, "direct")

        if configured_mode == RagRuntimeMode.REACT:
            logger.info(
                f"[AC-AISVC-RES-10] Mode routing to REACT: tenant={ctx.tenant_id}"
            )
            return self._configured_mode_result(RagRuntimeMode.REACT, ctx, "react")

        # Auto mode: use the higher of the analyzed and caller-supplied scores.
        analyzed_complexity = self._complexity_analyzer.analyze(ctx.query)
        effective_complexity = max(analyzed_complexity, ctx.complexity_score)

        if self._config.should_trigger_react_in_auto_mode(
            confidence=ctx.metadata_confidence,
            complexity_score=effective_complexity,
        ):
            selected_mode = RagRuntimeMode.REACT
        else:
            selected_mode = RagRuntimeMode.DIRECT

        logger.info(
            f"[AC-AISVC-RES-11, AC-AISVC-RES-12, AC-AISVC-RES-13] "
            f"Auto mode routing: selected={selected_mode.value}, "
            f"confidence={ctx.metadata_confidence:.2f}, "
            f"complexity={effective_complexity:.2f}, "
            f"tenant={ctx.tenant_id}"
        )

        auto_diagnostics = {
            "configured_mode": "auto",
            "analyzed_complexity": analyzed_complexity,
            "provided_complexity": ctx.complexity_score,
            "react_trigger_confidence": self._config.react_trigger_confidence_threshold,
            "react_trigger_complexity": self._config.react_trigger_complexity_score,
        }
        return ModeRouteResult(
            mode=selected_mode,
            confidence=ctx.metadata_confidence,
            complexity_score=effective_complexity,
            diagnostics=auto_diagnostics,
        )

    async def execute_direct(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """Run the direct retrieval path for ``ctx``."""
        return await self._direct_executor.execute(ctx)

    async def execute_react(
        self,
        ctx: StrategyContext,
    ) -> tuple[str, "RetrievalResult | None", dict[str, Any]]:
        """Run the ReAct retrieval path for ``ctx`` with the current config."""
        return await self._react_executor.execute(ctx, self._config)

    async def execute_with_fallback(
        self,
        ctx: StrategyContext,
    ) -> tuple["RetrievalResult | None", str | None, ModeRouteResult]:
        """
        [AC-AISVC-RES-14] Route and execute, falling back from direct to
        react when the direct result has low confidence.

        Args:
            ctx: Strategy context

        Returns:
            Tuple of (RetrievalResult or None, final_answer or None,
            ModeRouteResult). A direct result without fallback yields
            (result, None, route); any ReAct execution yields
            (None, final_answer, route).
        """
        route_result = self.route(ctx)

        routed_direct = route_result.mode == RagRuntimeMode.DIRECT
        if not routed_direct:
            final_answer, _, _ = await self._react_executor.execute(ctx, self._config)
            return None, final_answer, route_result

        retrieval_result = await self._direct_executor.execute(ctx)

        # Best hit score of the direct result; 0.0 when there are no hits.
        max_score = 0.0
        if retrieval_result and retrieval_result.hits:
            max_score = max((hit.score for hit in retrieval_result.hits), default=0.0)

        if not self._config.should_fallback_direct_to_react(max_score):
            return retrieval_result, None, route_result

        logger.info(
            f"[AC-AISVC-RES-14] Direct mode low confidence fallback to react: "
            f"confidence={max_score:.2f}, threshold={self._config.direct_fallback_confidence_threshold}"
        )

        final_answer, _, _ = await self._react_executor.execute(ctx, self._config)

        fallback_result = ModeRouteResult(
            mode=RagRuntimeMode.REACT,
            confidence=max_score,
            complexity_score=route_result.complexity_score,
            should_fallback_to_react=True,
            fallback_reason="low_confidence",
            diagnostics={
                **route_result.diagnostics,
                "fallback_from": "direct",
                "direct_confidence": max_score,
            },
        )
        return None, final_answer, fallback_result
|
||
|
||
|
||
# Module-wide ModeRouter instance, created on first use.
_mode_router: ModeRouter | None = None


def get_mode_router() -> ModeRouter:
    """Return the shared ModeRouter, creating it on the first call."""
    global _mode_router
    if _mode_router is not None:
        return _mode_router
    _mode_router = ModeRouter()
    return _mode_router
|
||
|
||
|
||
def reset_mode_router() -> None:
    """Drop the shared ModeRouter so the next lookup builds a new one (for testing)."""
    global _mode_router
    _mode_router = None
|