439 lines
14 KiB
Python
439 lines
14 KiB
Python
|
|
"""
|
|||
|
|
Mode Router for RAG Runtime Mode Selection.
|
|||
|
|
[AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11] Routes to direct/react/auto mode.
|
|||
|
|
|
|||
|
|
Mode Descriptions:
|
|||
|
|
- direct: Low-latency generic retrieval path (single KB call)
|
|||
|
|
- react: Multi-step ReAct retrieval path (high accuracy)
|
|||
|
|
- auto: Automatic selection based on complexity/confidence rules
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from __future__ import annotations
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
import re
|
|||
|
|
import time
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from typing import TYPE_CHECKING, Any
|
|||
|
|
|
|||
|
|
from app.services.retrieval.routing_config import (
|
|||
|
|
RagRuntimeMode,
|
|||
|
|
RoutingConfig,
|
|||
|
|
StrategyContext,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
if TYPE_CHECKING:
|
|||
|
|
from app.services.retrieval.base import RetrievalResult
|
|||
|
|
|
|||
|
|
# Module logger used for routing decisions, config hot-reloads and fallback messages.
logger = logging.getLogger(__name__)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ComplexityAnalyzer:
    """
    Analyzes query complexity for mode routing decisions.

    The score is a sum of independent contributions, capped at 1.0:

    - Query length: +0.0 below ``short_query_threshold`` characters,
      +0.3 above ``long_query_threshold``, +0.15 otherwise.
    - Condition/constraint connectives (``condition_patterns``): +0.1, +0.2
      or +0.3 for 1, 2, or 3+ total matches across all patterns.
    - Any reasoning cue (``reasoning_patterns``): +0.15 (counted once).
    - Any cross-domain / aggregation cue (``cross_domain_patterns``):
      +0.15 (counted once).
    - Two or more question marks, ASCII ``?`` or full-width ``？``: +0.1.

    The default patterns target Chinese-language queries.
    """

    short_query_threshold: int = 20
    long_query_threshold: int = 100

    # Connectives that introduce extra conditions or constraints.
    condition_patterns: list[str] = field(default_factory=lambda: [
        r"和|与|及|并且|同时",  # and / with / as well as / and also / at the same time
        r"或者|还是|要么",  # or / or (in questions) / either
        r"但是|不过|然而",  # but / however / nevertheless
        r"如果|假如|假设",  # if / suppose / assume
        r"既.*又",  # both ... and ...
        r"不仅.*而且",  # not only ... but also ...
    ])

    # Cues that the user wants explanation, comparison or judgement.
    reasoning_patterns: list[str] = field(default_factory=lambda: [
        r"为什么|原因|理由",  # why / cause / rationale
        r"怎么|如何|怎样",  # how
        r"区别|差异|不同",  # difference / discrepancy / different
        r"比较|对比|优劣",  # compare / contrast / pros and cons
        r"分析|评估|判断",  # analyze / evaluate / judge
    ])

    # Cues that the query spans several domains or asks for aggregation.
    cross_domain_patterns: list[str] = field(default_factory=lambda: [
        r"跨|多|各个",  # cross- / multiple / each
        r"所有|全部|整体",  # all / entire / overall
        r"综合|汇总|统计",  # comprehensive / aggregate / statistics
    ])

    def analyze(self, query: str) -> float:
        """
        Analyze query complexity and return a score (0.0 ~ 1.0).

        Higher score indicates a more complex query that may benefit from ReAct mode.

        Args:
            query: User query text. Empty or ``None`` yields 0.0.

        Returns:
            Complexity score (0.0 = simple, 1.0 = very complex).
        """
        if not query:
            return 0.0

        score = 0.0

        # Length: short queries contribute nothing; the short check is applied
        # first, so it wins even if the thresholds are misconfigured.
        query_length = len(query)
        if query_length >= self.short_query_threshold:
            score += 0.3 if query_length > self.long_query_threshold else 0.15

        condition_count = sum(
            len(re.findall(pattern, query)) for pattern in self.condition_patterns
        )
        if condition_count >= 3:
            score += 0.3
        elif condition_count >= 2:
            score += 0.2
        elif condition_count >= 1:
            score += 0.1

        if any(re.search(pattern, query) for pattern in self.reasoning_patterns):
            score += 0.15

        if any(re.search(pattern, query) for pattern in self.cross_domain_patterns):
            score += 0.15

        # Several questions in one message suggest multi-step handling. Count
        # both the ASCII and the full-width (CJK) question mark.
        question_marks = query.count("?") + query.count("？")
        if question_marks >= 2:
            score += 0.1

        return min(1.0, score)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class ModeRouteResult:
    """
    Outcome of one mode-routing decision.

    Fields:
        mode: Runtime mode selected for the request.
        confidence: Confidence attached to the decision. This is the context's
            metadata confidence when routing, or the best direct-retrieval hit
            score when a direct-to-react fallback occurred.
        complexity_score: Complexity score the decision was based on.
        should_fallback_to_react: True when the result comes from a
            direct-to-react fallback.
        fallback_reason: Short machine-readable reason for the fallback, if any.
        diagnostics: Free-form details about how the decision was reached.
    """

    mode: RagRuntimeMode
    confidence: float
    complexity_score: float
    should_fallback_to_react: bool = field(default=False)
    fallback_reason: str | None = field(default=None)
    # default_factory gives every result its own dict instead of a shared one.
    diagnostics: dict[str, Any] = field(default_factory=dict)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class DirectRetrievalExecutor:
    """
    [AC-AISVC-RES-09] Low-latency executor: exactly one knowledge-base
    retrieval call, with no multi-step reasoning.
    """

    def __init__(self):
        # The optimized retriever is resolved on the first execute() call,
        # not at construction time.
        self._retriever = None

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """
        Run a single retrieval through the optimized retriever.

        Args:
            ctx: Strategy context supplying tenant_id, query, metadata_filter
                and kb_ids for the retrieval request.

        Returns:
            The result of the optimized retriever's ``retrieve`` call.
        """
        from app.services.retrieval.optimized_retriever import get_optimized_retriever
        from app.services.retrieval.base import RetrievalContext

        retriever = self._retriever
        if retriever is None:
            retriever = await get_optimized_retriever()
            self._retriever = retriever

        request = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.query,
            metadata_filter=ctx.metadata_filter,
            kb_ids=ctx.kb_ids,
        )
        return await retriever.retrieve(request)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ReactRetrievalExecutor:
    """
    [AC-AISVC-RES-10] ReAct retrieval executor for the multi-step path.

    Delegates to ``AgentOrchestrator`` in function-calling mode. The
    orchestrator may run several reasoning iterations and KB/tool calls before
    producing a final answer.
    """

    def __init__(self, max_steps: int = 5):
        # Executor-level iteration cap. The limit used per call is
        # min(config.react_max_steps, max_steps).
        self._max_steps = max_steps

    async def execute(
        self,
        ctx: StrategyContext,
        config: RoutingConfig,
    ) -> tuple[str, "RetrievalResult | None", dict[str, Any]]:
        """
        Execute ReAct retrieval (multi-step reasoning).

        Args:
            ctx: Strategy context. Its query, metadata_filter, kb_ids and
                additional_context are passed to the orchestrator as context.
            config: Routing configuration. ``react_max_steps`` caps the
                orchestrator's iterations together with this executor's cap.

        Returns:
            Tuple of (final_answer, retrieval_result, react_context). The
            retrieval_result slot is always ``None`` on this path. react_context
            holds ``iterations``, ``tool_calls`` (each dumped via
            ``model_dump()``) and ``final_answer``.
        """
        from app.services.mid.agent_orchestrator import AgentOrchestrator, AgentMode
        from app.services.mid.tool_registry import ToolRegistry
        from app.services.mid.timeout_governor import TimeoutGovernor
        from app.services.llm.factory import get_llm_config_manager

        # Best-effort: if no LLM client can be obtained, log it and continue
        # building the orchestrator with llm_client=None.
        try:
            llm_client = get_llm_config_manager().get_client()
        except Exception as e:
            logger.warning("[ModeRouter] Failed to get LLM client: %s", e)
            llm_client = None

        # Use one governor for both the tool registry and the orchestrator,
        # rather than two unrelated instances.
        timeout_governor = TimeoutGovernor()
        tool_registry = ToolRegistry(timeout_governor=timeout_governor)

        orchestrator = AgentOrchestrator(
            max_iterations=min(config.react_max_steps, self._max_steps),
            timeout_governor=timeout_governor,
            llm_client=llm_client,
            tool_registry=tool_registry,
            tenant_id=ctx.tenant_id,
            mode=AgentMode.FUNCTION_CALLING,
        )

        # additional_context is unpacked last, so its keys override the
        # standard ones when names collide.
        base_context = {
            "query": ctx.query,
            "metadata_filter": ctx.metadata_filter,
            "kb_ids": ctx.kb_ids,
            **ctx.additional_context,
        }

        final_answer, react_ctx, _trace = await orchestrator.execute(
            user_message=ctx.query,
            context=base_context,
        )

        tool_calls = (
            [tc.model_dump() for tc in react_ctx.tool_calls]
            if react_ctx.tool_calls
            else []
        )
        return final_answer, None, {
            "iterations": react_ctx.iteration,
            "tool_calls": tool_calls,
            "final_answer": final_answer,
        }
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ModeRouter:
    """
    [AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11]
    Selects and runs the RAG runtime mode for a request.

    Modes:
    - direct: single KB retrieval call (low latency)
    - react:  multi-step ReAct retrieval through the agent orchestrator
    - auto:   chosen per request. The decision is delegated to
              ``RoutingConfig.should_trigger_react_in_auto_mode`` with the
              context's metadata confidence and the larger of the analyzed and
              caller-supplied complexity scores.

    Intended auto-mode behavior:
    - Direct for short, clear-intent queries with high metadata confidence
      and no cross-domain or multi-condition structure.
    - React for multi-condition/multi-constraint queries, low metadata
      confidence, or queries that need secondary confirmation or multi-step
      reasoning.
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
    ):
        cfg = config or RoutingConfig()
        self._config = cfg
        self._complexity_analyzer = ComplexityAnalyzer()
        self._direct_executor = DirectRetrievalExecutor()
        self._react_executor = ReactRetrievalExecutor(max_steps=cfg.react_max_steps)

    @property
    def config(self) -> RoutingConfig:
        """The routing configuration currently in effect."""
        return self._config

    def update_config(self, new_config: RoutingConfig) -> None:
        """
        [AC-AISVC-RES-15] Hot-reload the routing configuration.

        Also pushes ``react_max_steps`` into the ReAct executor so its own
        iteration cap follows the new config.
        """
        self._config = new_config
        self._react_executor._max_steps = new_config.react_max_steps

        logger.info(
            "[AC-AISVC-RES-15] ModeRouter config updated: "
            f"mode={new_config.rag_runtime_mode.value}, "
            f"react_max_steps={new_config.react_max_steps}, "
            f"confidence_threshold={new_config.react_trigger_confidence_threshold}"
        )

    def route(
        self,
        ctx: StrategyContext,
    ) -> ModeRouteResult:
        """
        [AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11]
        Decide which mode should serve this request.

        A configured direct or react mode is honored as-is. Anything else is
        treated as auto mode.

        Args:
            ctx: Strategy context with query, metadata confidence,
                complexity score and tenant id.

        Returns:
            ModeRouteResult with the selected mode and diagnostics.
        """
        configured_mode = self._config.get_rag_runtime_mode()

        if configured_mode == RagRuntimeMode.DIRECT:
            logger.info(
                f"[AC-AISVC-RES-09] Mode routing to DIRECT: tenant={ctx.tenant_id}"
            )
            return self._fixed_mode_result(RagRuntimeMode.DIRECT, "direct", ctx)

        if configured_mode == RagRuntimeMode.REACT:
            logger.info(
                f"[AC-AISVC-RES-10] Mode routing to REACT: tenant={ctx.tenant_id}"
            )
            return self._fixed_mode_result(RagRuntimeMode.REACT, "react", ctx)

        return self._route_auto(ctx)

    @staticmethod
    def _fixed_mode_result(
        mode: RagRuntimeMode,
        label: str,
        ctx: StrategyContext,
    ) -> ModeRouteResult:
        """Build the result for an explicitly configured (non-auto) mode."""
        return ModeRouteResult(
            mode=mode,
            confidence=ctx.metadata_confidence,
            complexity_score=ctx.complexity_score,
            diagnostics={"configured_mode": label},
        )

    def _route_auto(self, ctx: StrategyContext) -> ModeRouteResult:
        """Auto mode: combine analyzed and supplied complexity, then ask the config."""
        analyzed_complexity = self._complexity_analyzer.analyze(ctx.query)
        # Trust whichever estimate is higher: ours or the caller's.
        effective_complexity = max(analyzed_complexity, ctx.complexity_score)

        use_react = self._config.should_trigger_react_in_auto_mode(
            confidence=ctx.metadata_confidence,
            complexity_score=effective_complexity,
        )
        if use_react:
            selected_mode = RagRuntimeMode.REACT
        else:
            selected_mode = RagRuntimeMode.DIRECT

        logger.info(
            "[AC-AISVC-RES-11, AC-AISVC-RES-12, AC-AISVC-RES-13] "
            f"Auto mode routing: selected={selected_mode.value}, "
            f"confidence={ctx.metadata_confidence:.2f}, "
            f"complexity={effective_complexity:.2f}, "
            f"tenant={ctx.tenant_id}"
        )

        diagnostics = {
            "configured_mode": "auto",
            "analyzed_complexity": analyzed_complexity,
            "provided_complexity": ctx.complexity_score,
            "react_trigger_confidence": self._config.react_trigger_confidence_threshold,
            "react_trigger_complexity": self._config.react_trigger_complexity_score,
        }
        return ModeRouteResult(
            mode=selected_mode,
            confidence=ctx.metadata_confidence,
            complexity_score=effective_complexity,
            diagnostics=diagnostics,
        )

    async def execute_direct(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """Run the direct (single KB call) path for ``ctx``."""
        return await self._direct_executor.execute(ctx)

    async def execute_react(
        self,
        ctx: StrategyContext,
    ) -> tuple[str, "RetrievalResult | None", dict[str, Any]]:
        """Run the ReAct (multi-step) path for ``ctx`` under the current config."""
        return await self._react_executor.execute(ctx, self._config)

    async def execute_with_fallback(
        self,
        ctx: StrategyContext,
    ) -> tuple["RetrievalResult | None", str | None, ModeRouteResult]:
        """
        [AC-AISVC-RES-14] Route, execute, and fall back from direct to react
        when the direct results look weak.

        Args:
            ctx: Strategy context.

        Returns:
            ``(retrieval_result, None, route_result)`` when the direct path is
            kept, or ``(None, final_answer, route_result)`` when ReAct runs.
            ReAct runs either because routing chose it or because the config's
            ``should_fallback_direct_to_react`` accepted the best direct hit
            score. In the fallback case the returned route result is a new one
            marked with ``fallback_reason="low_confidence"``.
        """
        route_result = self.route(ctx)

        is_direct = route_result.mode == RagRuntimeMode.DIRECT
        if not is_direct:
            final_answer, _, _ = await self._react_executor.execute(ctx, self._config)
            return None, final_answer, route_result

        retrieval_result = await self._direct_executor.execute(ctx)

        # Best hit score from the direct attempt; 0.0 when nothing came back.
        if retrieval_result and retrieval_result.hits:
            top_score = max((hit.score for hit in retrieval_result.hits), default=0.0)
        else:
            top_score = 0.0

        if not self._config.should_fallback_direct_to_react(top_score):
            return retrieval_result, None, route_result

        logger.info(
            "[AC-AISVC-RES-14] Direct mode low confidence fallback to react: "
            f"confidence={top_score:.2f}, threshold={self._config.direct_fallback_confidence_threshold}"
        )

        final_answer, _, _ = await self._react_executor.execute(ctx, self._config)

        fallback_route = ModeRouteResult(
            mode=RagRuntimeMode.REACT,
            confidence=top_score,
            complexity_score=route_result.complexity_score,
            should_fallback_to_react=True,
            fallback_reason="low_confidence",
            diagnostics={
                **route_result.diagnostics,
                "fallback_from": "direct",
                "direct_confidence": top_score,
            },
        )
        return None, final_answer, fallback_route
|
|||
|
|
|
|||
|
|
|
|||
|
|
# Lazily created module-level ModeRouter instance. get_mode_router() fills it
# on first use; reset_mode_router() clears it.
_mode_router: ModeRouter | None = None
|
|||
|
|
|
|||
|
|
|
|||
|
|
def get_mode_router() -> ModeRouter:
    """
    Return the module-level ModeRouter, creating it with the default
    configuration the first time it is requested.
    """
    global _mode_router
    router = _mode_router
    if router is None:
        router = ModeRouter()
        _mode_router = router
    return router
|
|||
|
|
|
|||
|
|
|
|||
|
|
def reset_mode_router() -> None:
    """
    Discard the cached module-level ModeRouter (intended for tests), so the
    next get_mode_router() call constructs a fresh one.
    """
    globals()["_mode_router"] = None
|