[AC-AISVC-RES-01~15] feat(retrieval): 实现检索策略路由核心模块

- 新增 routing_config.py 路由配置模型
  - StrategyType: DEFAULT/ENHANCED 策略类型
  - RagRuntimeMode: DIRECT/REACT/AUTO 运行模式
  - RoutingConfig: 路由配置类
  - StrategyContext: 策略上下文
  - StrategyResult: 策略结果

- 新增 strategy_router.py 策略路由器
  - RollbackManager: 回滚管理器
  - DefaultPipeline: 默认检索管道
  - EnhancedPipeline: 增强检索管道
  - StrategyRouter: 策略路由器

- 新增 mode_router.py 模式路由器
  - ComplexityAnalyzer: 复杂度分析器
  - ModeRouter: 模式路由器

- 新增 strategy_integration.py 统一集成层
  - RetrievalStrategyIntegration: 策略集成器

- 更新 __init__.py 导出新模块
This commit is contained in:
MerCry 2026-03-10 21:07:01 +08:00
parent 2476da8957
commit c628181623
5 changed files with 1300 additions and 0 deletions

View File

@ -2,6 +2,7 @@
Retrieval module for AI Service.
[AC-AISVC-16] Provides retriever implementations with plugin architecture.
RAG Optimization: Two-stage retrieval, RRF hybrid ranking, metadata filtering.
[AC-AISVC-RES-01~15] Strategy routing and mode routing for retrieval pipeline.
"""
from app.services.retrieval.base import (
@ -32,6 +33,29 @@ from app.services.retrieval.optimized_retriever import (
get_optimized_retriever,
)
from app.services.retrieval.vector_retriever import VectorRetriever, get_vector_retriever
from app.services.retrieval.routing_config import (
RagRuntimeMode,
RoutingConfig,
StrategyContext,
StrategyType,
StrategyResult,
)
from app.services.retrieval.strategy_router import (
RollbackRecord,
StrategyRouter,
get_strategy_router,
)
from app.services.retrieval.mode_router import (
ComplexityAnalyzer,
ModeRouter,
ModeRouteResult,
get_mode_router,
)
from app.services.retrieval.strategy_integration import (
RetrievalStrategyIntegration,
RetrievalStrategyResult,
get_retrieval_strategy_integration,
)
__all__ = [
"BaseRetriever",
@ -55,4 +79,19 @@ __all__ = [
"get_knowledge_indexer",
"IndexingProgress",
"IndexingResult",
"RagRuntimeMode",
"RoutingConfig",
"StrategyContext",
"StrategyType",
"StrategyResult",
"RollbackRecord",
"StrategyRouter",
"get_strategy_router",
"ComplexityAnalyzer",
"ModeRouter",
"ModeRouteResult",
"get_mode_router",
"RetrievalStrategyIntegration",
"RetrievalStrategyResult",
"get_retrieval_strategy_integration",
]

View File

@ -0,0 +1,438 @@
"""
Mode Router for RAG Runtime Mode Selection.
[AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11] Routes to direct/react/auto mode.
Mode Descriptions:
- direct: Low-latency generic retrieval path (single KB call)
- react: Multi-step ReAct retrieval path (high accuracy)
- auto: Automatic selection based on complexity/confidence rules
"""
from __future__ import annotations
import logging
import re
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from app.services.retrieval.routing_config import (
RagRuntimeMode,
RoutingConfig,
StrategyContext,
)
if TYPE_CHECKING:
from app.services.retrieval.base import RetrievalResult
logger = logging.getLogger(__name__)
@dataclass
class ComplexityAnalyzer:
"""
Analyzes query complexity for mode routing decisions.
Complexity factors:
- Query length
- Number of conditions/constraints
- Presence of logical operators (and, or, not)
- Cross-domain indicators
- Multi-step reasoning requirements
"""
short_query_threshold: int = 20
long_query_threshold: int = 100
condition_patterns: list[str] = field(default_factory=lambda: [
r"和|与|及|并且|同时",
r"或者|还是|要么",
r"但是|不过|然而",
r"如果|假如|假设",
r"既.*又",
r"不仅.*而且",
])
reasoning_patterns: list[str] = field(default_factory=lambda: [
r"为什么|原因|理由",
r"怎么|如何|怎样",
r"区别|差异|不同",
r"比较|对比|优劣",
r"分析|评估|判断",
])
cross_domain_patterns: list[str] = field(default_factory=lambda: [
r"跨|多|各个",
r"所有|全部|整体",
r"综合|汇总|统计",
])
def analyze(self, query: str) -> float:
"""
Analyze query complexity and return a score (0.0 ~ 1.0).
Higher score indicates more complex query that may benefit from ReAct mode.
Args:
query: User query text
Returns:
Complexity score (0.0 = simple, 1.0 = very complex)
"""
if not query:
return 0.0
score = 0.0
query_length = len(query)
if query_length < self.short_query_threshold:
score += 0.0
elif query_length > self.long_query_threshold:
score += 0.3
else:
score += 0.15
condition_count = 0
for pattern in self.condition_patterns:
matches = re.findall(pattern, query)
condition_count += len(matches)
if condition_count >= 3:
score += 0.3
elif condition_count >= 2:
score += 0.2
elif condition_count >= 1:
score += 0.1
for pattern in self.reasoning_patterns:
if re.search(pattern, query):
score += 0.15
break
for pattern in self.cross_domain_patterns:
if re.search(pattern, query):
score += 0.15
break
question_marks = query.count("?") + query.count("")
if question_marks >= 2:
score += 0.1
return min(1.0, score)
@dataclass
class ModeRouteResult:
"""Result from mode routing decision."""
mode: RagRuntimeMode
confidence: float
complexity_score: float
should_fallback_to_react: bool = False
fallback_reason: str | None = None
diagnostics: dict[str, Any] = field(default_factory=dict)
class DirectRetrievalExecutor:
"""
[AC-AISVC-RES-09] Direct retrieval executor for low-latency path.
Single KB call without multi-step reasoning.
"""
def __init__(self):
self._retriever = None
async def execute(
self,
ctx: StrategyContext,
) -> "RetrievalResult":
"""
Execute direct retrieval (single KB call).
"""
from app.services.retrieval.optimized_retriever import get_optimized_retriever
from app.services.retrieval.base import RetrievalContext
if self._retriever is None:
self._retriever = await get_optimized_retriever()
retrieval_ctx = RetrievalContext(
tenant_id=ctx.tenant_id,
query=ctx.query,
metadata_filter=ctx.metadata_filter,
kb_ids=ctx.kb_ids,
)
return await self._retriever.retrieve(retrieval_ctx)
class ReactRetrievalExecutor:
"""
[AC-AISVC-RES-10] ReAct retrieval executor for multi-step path.
Uses AgentOrchestrator for multi-step reasoning and KB calls.
"""
def __init__(self, max_steps: int = 5):
self._max_steps = max_steps
async def execute(
self,
ctx: StrategyContext,
config: RoutingConfig,
) -> tuple[str, "RetrievalResult | None", dict[str, Any]]:
"""
Execute ReAct retrieval (multi-step reasoning).
Returns:
Tuple of (final_answer, retrieval_result, react_context)
"""
from app.services.mid.agent_orchestrator import AgentOrchestrator, AgentMode
from app.services.mid.tool_registry import ToolRegistry
from app.services.mid.timeout_governor import TimeoutGovernor
from app.services.llm.factory import get_llm_config_manager
try:
llm_manager = get_llm_config_manager()
llm_client = llm_manager.get_client()
except Exception as e:
logger.warning(f"[ModeRouter] Failed to get LLM client: {e}")
llm_client = None
tool_registry = ToolRegistry(timeout_governor=TimeoutGovernor())
timeout_governor = TimeoutGovernor()
orchestrator = AgentOrchestrator(
max_iterations=min(config.react_max_steps, self._max_steps),
timeout_governor=timeout_governor,
llm_client=llm_client,
tool_registry=tool_registry,
tenant_id=ctx.tenant_id,
mode=AgentMode.FUNCTION_CALLING,
)
base_context = {
"query": ctx.query,
"metadata_filter": ctx.metadata_filter,
"kb_ids": ctx.kb_ids,
**ctx.additional_context,
}
final_answer, react_ctx, trace = await orchestrator.execute(
user_message=ctx.query,
context=base_context,
)
return final_answer, None, {
"iterations": react_ctx.iteration,
"tool_calls": [tc.model_dump() for tc in react_ctx.tool_calls] if react_ctx.tool_calls else [],
"final_answer": final_answer,
}
class ModeRouter:
"""
[AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11]
Mode router for RAG runtime mode selection.
Mode Selection:
- direct: Low-latency generic retrieval (single KB call)
- react: Multi-step ReAct retrieval (high accuracy)
- auto: Automatic selection based on complexity/confidence
Auto Mode Rules:
- Direct conditions:
- Short query, clear intent
- High metadata confidence
- No cross-domain/multi-condition
- React conditions:
- Multi-condition/multi-constraint
- Low metadata confidence
- Need for secondary confirmation or multi-step reasoning
"""
def __init__(
self,
config: RoutingConfig | None = None,
):
self._config = config or RoutingConfig()
self._complexity_analyzer = ComplexityAnalyzer()
self._direct_executor = DirectRetrievalExecutor()
self._react_executor = ReactRetrievalExecutor(
max_steps=self._config.react_max_steps
)
@property
def config(self) -> RoutingConfig:
"""Get current configuration."""
return self._config
def update_config(self, new_config: RoutingConfig) -> None:
"""
[AC-AISVC-RES-15] Update routing configuration (hot reload).
"""
self._config = new_config
self._react_executor._max_steps = new_config.react_max_steps
logger.info(
f"[AC-AISVC-RES-15] ModeRouter config updated: "
f"mode={new_config.rag_runtime_mode.value}, "
f"react_max_steps={new_config.react_max_steps}, "
f"confidence_threshold={new_config.react_trigger_confidence_threshold}"
)
def route(
self,
ctx: StrategyContext,
) -> ModeRouteResult:
"""
[AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11]
Route to appropriate mode based on configuration and context.
Args:
ctx: Strategy context with query, metadata, confidence, etc.
Returns:
ModeRouteResult with selected mode and diagnostics
"""
configured_mode = self._config.get_rag_runtime_mode()
if configured_mode == RagRuntimeMode.DIRECT:
logger.info(
f"[AC-AISVC-RES-09] Mode routing to DIRECT: tenant={ctx.tenant_id}"
)
return ModeRouteResult(
mode=RagRuntimeMode.DIRECT,
confidence=ctx.metadata_confidence,
complexity_score=ctx.complexity_score,
diagnostics={"configured_mode": "direct"},
)
if configured_mode == RagRuntimeMode.REACT:
logger.info(
f"[AC-AISVC-RES-10] Mode routing to REACT: tenant={ctx.tenant_id}"
)
return ModeRouteResult(
mode=RagRuntimeMode.REACT,
confidence=ctx.metadata_confidence,
complexity_score=ctx.complexity_score,
diagnostics={"configured_mode": "react"},
)
complexity_score = self._complexity_analyzer.analyze(ctx.query)
effective_complexity = max(complexity_score, ctx.complexity_score)
should_use_react = self._config.should_trigger_react_in_auto_mode(
confidence=ctx.metadata_confidence,
complexity_score=effective_complexity,
)
selected_mode = RagRuntimeMode.REACT if should_use_react else RagRuntimeMode.DIRECT
logger.info(
f"[AC-AISVC-RES-11, AC-AISVC-RES-12, AC-AISVC-RES-13] "
f"Auto mode routing: selected={selected_mode.value}, "
f"confidence={ctx.metadata_confidence:.2f}, "
f"complexity={effective_complexity:.2f}, "
f"tenant={ctx.tenant_id}"
)
return ModeRouteResult(
mode=selected_mode,
confidence=ctx.metadata_confidence,
complexity_score=effective_complexity,
diagnostics={
"configured_mode": "auto",
"analyzed_complexity": complexity_score,
"provided_complexity": ctx.complexity_score,
"react_trigger_confidence": self._config.react_trigger_confidence_threshold,
"react_trigger_complexity": self._config.react_trigger_complexity_score,
},
)
async def execute_direct(
self,
ctx: StrategyContext,
) -> "RetrievalResult":
"""
Execute direct retrieval mode.
"""
return await self._direct_executor.execute(ctx)
async def execute_react(
self,
ctx: StrategyContext,
) -> tuple[str, "RetrievalResult | None", dict[str, Any]]:
"""
Execute ReAct retrieval mode.
"""
return await self._react_executor.execute(ctx, self._config)
async def execute_with_fallback(
self,
ctx: StrategyContext,
) -> tuple["RetrievalResult | None", str | None, ModeRouteResult]:
"""
[AC-AISVC-RES-14] Execute with fallback from direct to react on low confidence.
Args:
ctx: Strategy context
Returns:
Tuple of (RetrievalResult or None, final_answer or None, ModeRouteResult)
"""
route_result = self.route(ctx)
if route_result.mode == RagRuntimeMode.DIRECT:
retrieval_result = await self._direct_executor.execute(ctx)
max_score = 0.0
if retrieval_result and retrieval_result.hits:
max_score = max((h.score for h in retrieval_result.hits), default=0.0)
if self._config.should_fallback_direct_to_react(max_score):
logger.info(
f"[AC-AISVC-RES-14] Direct mode low confidence fallback to react: "
f"confidence={max_score:.2f}, threshold={self._config.direct_fallback_confidence_threshold}"
)
final_answer, _, react_ctx = await self._react_executor.execute(ctx, self._config)
return (
None,
final_answer,
ModeRouteResult(
mode=RagRuntimeMode.REACT,
confidence=max_score,
complexity_score=route_result.complexity_score,
should_fallback_to_react=True,
fallback_reason="low_confidence",
diagnostics={
**route_result.diagnostics,
"fallback_from": "direct",
"direct_confidence": max_score,
},
),
)
return retrieval_result, None, route_result
final_answer, _, react_ctx = await self._react_executor.execute(ctx, self._config)
return None, final_answer, route_result
_mode_router: ModeRouter | None = None
def get_mode_router() -> ModeRouter:
"""Get or create ModeRouter singleton."""
global _mode_router
if _mode_router is None:
_mode_router = ModeRouter()
return _mode_router
def reset_mode_router() -> None:
"""Reset ModeRouter singleton (for testing)."""
global _mode_router
_mode_router = None

View File

@ -0,0 +1,187 @@
"""
Retrieval and Embedding Strategy Configuration.
[AC-AISVC-RES-01~15] Configuration for strategy routing and mode routing.
"""
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
class StrategyType(str, Enum):
"""Strategy type for retrieval pipeline selection."""
DEFAULT = "default"
ENHANCED = "enhanced"
class RagRuntimeMode(str, Enum):
"""RAG runtime mode for execution path selection."""
DIRECT = "direct"
REACT = "react"
AUTO = "auto"
@dataclass
class RoutingConfig:
"""
[AC-AISVC-RES-01~15] Routing configuration for strategy and mode selection.
Configuration hierarchy:
1. Strategy selection (default vs enhanced)
2. Mode selection (direct/react/auto)
3. Auto routing rules (complexity/confidence thresholds)
4. Fallback behavior
"""
enabled: bool = True
strategy: StrategyType = StrategyType.DEFAULT
grayscale_percentage: float = 0.0
grayscale_allowlist: list[str] = field(default_factory=list)
rag_runtime_mode: RagRuntimeMode = RagRuntimeMode.AUTO
react_trigger_confidence_threshold: float = 0.6
react_trigger_complexity_score: float = 0.5
react_max_steps: int = 5
direct_fallback_on_low_confidence: bool = True
direct_fallback_confidence_threshold: float = 0.4
performance_budget_ms: int = 5000
performance_degradation_threshold: float = 0.2
def should_use_enhanced_strategy(self, tenant_id: str | None = None) -> bool:
"""
[AC-AISVC-RES-02, AC-AISVC-RES-03] Determine if enhanced strategy should be used.
Priority:
1. If strategy is explicitly set to ENHANCED, use enhanced
2. If strategy is DEFAULT, use default
3. If grayscale is enabled, check percentage/allowlist
"""
if self.strategy == StrategyType.ENHANCED:
return True
if self.strategy == StrategyType.DEFAULT:
return False
if self.grayscale_percentage > 0:
import hashlib
if tenant_id:
hash_val = int(hashlib.md5(tenant_id.encode()).hexdigest()[:8], 16)
return (hash_val % 100) < (self.grayscale_percentage * 100)
return False
if self.grayscale_allowlist and tenant_id:
return tenant_id in self.grayscale_allowlist
return False
def get_rag_runtime_mode(self) -> RagRuntimeMode:
"""Get the configured RAG runtime mode."""
return self.rag_runtime_mode
def should_fallback_direct_to_react(self, confidence: float) -> bool:
"""
[AC-AISVC-RES-14] Determine if direct mode should fallback to react.
Args:
confidence: Retrieval confidence score (0.0 ~ 1.0)
Returns:
True if fallback should be triggered
"""
if not self.direct_fallback_on_low_confidence:
return False
return confidence < self.direct_fallback_confidence_threshold
def should_trigger_react_in_auto_mode(
self,
confidence: float,
complexity_score: float,
) -> bool:
"""
[AC-AISVC-RES-11, AC-AISVC-RES-12, AC-AISVC-RES-13]
Determine if react mode should be triggered in auto mode.
Direct conditions (优先):
- Short query, clear intent
- High metadata confidence
- No cross-domain/multi-condition
React conditions:
- Multi-condition/multi-constraint
- Low metadata confidence
- Need for secondary confirmation or multi-step reasoning
Args:
confidence: Metadata inference confidence (0.0 ~ 1.0)
complexity_score: Query complexity score (0.0 ~ 1.0)
Returns:
True if react mode should be used
"""
if confidence < self.react_trigger_confidence_threshold:
return True
if complexity_score > self.react_trigger_complexity_score:
return True
return False
def validate(self) -> tuple[bool, list[str]]:
"""
[AC-AISVC-RES-06] Validate configuration consistency.
Returns:
(is_valid, list of error messages)
"""
errors = []
if self.grayscale_percentage < 0 or self.grayscale_percentage > 1.0:
errors.append("grayscale_percentage must be between 0.0 and 1.0")
if self.react_trigger_confidence_threshold < 0 or self.react_trigger_confidence_threshold > 1.0:
errors.append("react_trigger_confidence_threshold must be between 0.0 and 1.0")
if self.react_trigger_complexity_score < 0 or self.react_trigger_complexity_score > 1.0:
errors.append("react_trigger_complexity_score must be between 0.0 and 1.0")
if self.react_max_steps < 3 or self.react_max_steps > 10:
errors.append("react_max_steps must be between 3 and 10")
if self.direct_fallback_confidence_threshold < 0 or self.direct_fallback_confidence_threshold > 1.0:
errors.append("direct_fallback_confidence_threshold must be between 0.0 and 1.0")
if self.performance_budget_ms < 1000:
errors.append("performance_budget_ms must be at least 1000")
if self.performance_degradation_threshold < 0 or self.performance_degradation_threshold > 1.0:
errors.append("performance_degradation_threshold must be between 0.0 and 1.0")
return (len(errors) == 0, errors)
@dataclass
class StrategyContext:
"""Context for strategy routing decision."""
tenant_id: str
query: str
metadata_filter: dict[str, Any] | None = None
metadata_confidence: float = 1.0
complexity_score: float = 0.0
kb_ids: list[str] | None = None
top_k: int = 5
additional_context: dict[str, Any] = field(default_factory=dict)
@dataclass
class StrategyResult:
"""Result from strategy routing."""
strategy: StrategyType
mode: RagRuntimeMode
should_fallback: bool = False
fallback_reason: str | None = None
diagnostics: dict[str, Any] = field(default_factory=dict)

View File

@ -0,0 +1,233 @@
"""
Retrieval Strategy Integration for Dialogue Flow.
[AC-AISVC-RES-01~15] Integrates StrategyRouter and ModeRouter into dialogue pipeline.
Usage:
from app.services.retrieval.strategy_integration import RetrievalStrategyIntegration
integration = RetrievalStrategyIntegration()
result = await integration.execute(ctx)
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from app.services.retrieval.routing_config import (
RagRuntimeMode,
StrategyType,
RoutingConfig,
StrategyContext,
StrategyResult,
)
from app.services.retrieval.strategy_router import (
StrategyRouter,
get_strategy_router,
)
from app.services.retrieval.mode_router import (
ModeRouter,
ModeRouteResult,
get_mode_router,
)
if TYPE_CHECKING:
from app.services.retrieval.base import RetrievalResult
logger = logging.getLogger(__name__)
@dataclass
class RetrievalStrategyResult:
"""Combined result from strategy and mode routing."""
retrieval_result: "RetrievalResult | None"
final_answer: str | None
strategy: StrategyType
mode: RagRuntimeMode
should_fallback: bool = False
fallback_reason: str | None = None
mode_route_result: ModeRouteResult | None = None
diagnostics: dict[str, Any] = field(default_factory=dict)
duration_ms: int = 0
class RetrievalStrategyIntegration:
"""
[AC-AISVC-RES-01~15] Integration layer for retrieval strategy.
Combines StrategyRouter and ModeRouter to provide a unified interface
for the dialogue pipeline.
Flow:
1. StrategyRouter selects default or enhanced strategy
2. ModeRouter selects direct, react, or auto mode
3. Execute retrieval with selected strategy and mode
4. Handle fallback scenarios
"""
def __init__(
self,
config: RoutingConfig | None = None,
strategy_router: StrategyRouter | None = None,
mode_router: ModeRouter | None = None,
):
self._config = config or RoutingConfig()
self._strategy_router = strategy_router or get_strategy_router()
self._mode_router = mode_router or get_mode_router()
@property
def config(self) -> RoutingConfig:
"""Get current configuration."""
return self._config
def update_config(self, new_config: RoutingConfig) -> None:
"""
[AC-AISVC-RES-15] Update all routing configurations.
"""
self._config = new_config
self._strategy_router.update_config(new_config)
self._mode_router.update_config(new_config)
logger.info(
f"[AC-AISVC-RES-15] RetrievalStrategyIntegration config updated: "
f"strategy={new_config.strategy.value}, mode={new_config.rag_runtime_mode.value}"
)
async def execute(
self,
ctx: StrategyContext,
) -> RetrievalStrategyResult:
"""
Execute retrieval with strategy and mode routing.
Args:
ctx: Strategy context with tenant, query, metadata, etc.
Returns:
RetrievalStrategyResult with retrieval results and diagnostics
"""
start_time = time.time()
strategy_result = self._strategy_router.route(ctx)
mode_result = self._mode_router.route(ctx)
logger.info(
f"[AC-AISVC-RES-01~15] Strategy routing: "
f"strategy={strategy_result.strategy.value}, mode={mode_result.mode.value}, "
f"tenant={ctx.tenant_id}, query_len={len(ctx.query)}"
)
retrieval_result = None
final_answer = None
should_fallback = False
fallback_reason = None
try:
if mode_result.mode == RagRuntimeMode.DIRECT:
retrieval_result, answer, mode_result = await self._mode_router.execute_with_fallback(ctx)
if answer is not None:
final_answer = answer
should_fallback = mode_result.should_fallback_to_react
fallback_reason = mode_result.fallback_reason
elif mode_result.mode == RagRuntimeMode.REACT:
answer, retrieval_result, react_ctx = await self._mode_router.execute_react(ctx)
final_answer = answer
else:
retrieval_result, answer, mode_result = await self._mode_router.execute_with_fallback(ctx)
if answer is not None:
final_answer = answer
should_fallback = mode_result.should_fallback_to_react
fallback_reason = mode_result.fallback_reason
except Exception as e:
logger.error(
f"[AC-AISVC-RES-07] Retrieval strategy execution failed: {e}"
)
if strategy_result.strategy == StrategyType.ENHANCED:
self._strategy_router.rollback(
reason=str(e),
tenant_id=ctx.tenant_id,
)
from app.services.retrieval.optimized_retriever import get_optimized_retriever
from app.services.retrieval.base import RetrievalContext
retriever = await get_optimized_retriever()
retrieval_ctx = RetrievalContext(
tenant_id=ctx.tenant_id,
query=ctx.query,
metadata_filter=ctx.metadata_filter,
kb_ids=ctx.kb_ids,
)
retrieval_result = await retriever.retrieve(retrieval_ctx)
should_fallback = True
fallback_reason = str(e)
else:
raise
duration_ms = int((time.time() - start_time) * 1000)
return RetrievalStrategyResult(
retrieval_result=retrieval_result,
final_answer=final_answer,
strategy=strategy_result.strategy,
mode=mode_result.mode,
should_fallback=should_fallback,
fallback_reason=fallback_reason,
mode_route_result=mode_result,
diagnostics={
"strategy_diagnostics": strategy_result.diagnostics,
"mode_diagnostics": mode_result.diagnostics,
"duration_ms": duration_ms,
},
duration_ms=duration_ms,
)
def get_current_strategy(self) -> StrategyType:
"""Get current active strategy."""
return self._strategy_router.current_strategy
def get_rollback_records(self, limit: int = 10) -> list[dict[str, Any]]:
"""Get recent rollback records."""
records = self._strategy_router.get_rollback_records(limit)
return [
{
"timestamp": r.timestamp,
"from_strategy": r.from_strategy.value,
"to_strategy": r.to_strategy.value,
"reason": r.reason,
"tenant_id": r.tenant_id,
}
for r in records
]
def validate_config(self) -> tuple[bool, list[str]]:
"""Validate current configuration."""
return self._config.validate()
_integration: RetrievalStrategyIntegration | None = None
def get_retrieval_strategy_integration() -> RetrievalStrategyIntegration:
"""Get or create RetrievalStrategyIntegration singleton."""
global _integration
if _integration is None:
_integration = RetrievalStrategyIntegration()
return _integration
def reset_retrieval_strategy_integration() -> None:
"""Reset RetrievalStrategyIntegration singleton (for testing)."""
global _integration
_integration = None

View File

@ -0,0 +1,403 @@
"""
Strategy Router for Retrieval and Embedding.
[AC-AISVC-RES-01, AC-AISVC-RES-02, AC-AISVC-RES-03] Routes to default or enhanced strategy.
Key Features:
- Default strategy preserves existing online logic
- Enhanced strategy is configurable and can be rolled back
- Supports grayscale release (percentage/allowlist)
- Supports rollback on error or performance degradation
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from app.services.retrieval.routing_config import (
RagRuntimeMode,
StrategyType,
RoutingConfig,
StrategyContext,
StrategyResult,
)
if TYPE_CHECKING:
from app.services.retrieval.base import RetrievalResult
logger = logging.getLogger(__name__)
@dataclass
class RollbackRecord:
"""Record for strategy rollback event."""
timestamp: float
from_strategy: StrategyType
to_strategy: StrategyType
reason: str
tenant_id: str | None = None
request_id: str | None = None
class RollbackManager:
"""
[AC-AISVC-RES-07] Manages strategy rollback and audit logging.
"""
def __init__(self, max_records: int = 100):
self._records: list[RollbackRecord] = []
self._max_records = max_records
def record_rollback(
self,
from_strategy: StrategyType,
to_strategy: StrategyType,
reason: str,
tenant_id: str | None = None,
request_id: str | None = None,
) -> None:
"""Record a rollback event."""
record = RollbackRecord(
timestamp=time.time(),
from_strategy=from_strategy,
to_strategy=to_strategy,
reason=reason,
tenant_id=tenant_id,
request_id=request_id,
)
self._records.append(record)
if len(self._records) > self._max_records:
self._records = self._records[-self._max_records:]
logger.info(
f"[AC-AISVC-RES-07] Rollback recorded: {from_strategy.value} -> {to_strategy.value}, "
f"reason={reason}, tenant={tenant_id}"
)
def get_recent_rollbacks(self, limit: int = 10) -> list[RollbackRecord]:
"""Get recent rollback records."""
return self._records[-limit:]
def get_rollback_count(self, since_timestamp: float | None = None) -> int:
"""Get count of rollbacks, optionally since a timestamp."""
if since_timestamp is None:
return len(self._records)
return sum(1 for r in self._records if r.timestamp >= since_timestamp)
class DefaultPipeline:
"""
[AC-AISVC-RES-01] Default pipeline that preserves existing online logic.
This pipeline uses the existing OptimizedRetriever without any new features.
"""
def __init__(self):
self._retriever = None
async def execute(
self,
ctx: StrategyContext,
) -> "RetrievalResult":
"""
Execute default retrieval strategy.
Uses existing OptimizedRetriever with current configuration.
"""
from app.services.retrieval.optimized_retriever import get_optimized_retriever
from app.services.retrieval.base import RetrievalContext
if self._retriever is None:
self._retriever = await get_optimized_retriever()
retrieval_ctx = RetrievalContext(
tenant_id=ctx.tenant_id,
query=ctx.query,
metadata_filter=ctx.metadata_filter,
kb_ids=ctx.kb_ids,
)
return await self._retriever.retrieve(retrieval_ctx)
class EnhancedPipeline:
"""
[AC-AISVC-RES-02] Enhanced pipeline with new end-to-end retrieval features.
Features:
- Document preprocessing (cleaning/normalization)
- Structured chunking (markdown/tables/FAQ)
- Metadata generation and mounting
- Embedding strategy (document/query prefix + Matryoshka)
- Metadata inference and filtering (hard/soft filter)
- Retrieval strategy (Dense + Keyword/Hybrid + RRF)
- Optional reranking
"""
def __init__(
self,
config: RoutingConfig | None = None,
):
self._config = config or RoutingConfig()
self._retriever = None
async def execute(
self,
ctx: StrategyContext,
) -> "RetrievalResult":
"""
Execute enhanced retrieval strategy.
Uses OptimizedRetriever with enhanced configuration.
"""
from app.services.retrieval.optimized_retriever import OptimizedRetriever
from app.services.retrieval.base import RetrievalContext
if self._retriever is None:
self._retriever = OptimizedRetriever(
two_stage_enabled=True,
hybrid_enabled=True,
)
retrieval_ctx = RetrievalContext(
tenant_id=ctx.tenant_id,
query=ctx.query,
metadata_filter=ctx.metadata_filter,
kb_ids=ctx.kb_ids,
)
return await self._retriever.retrieve(retrieval_ctx)
class StrategyRouter:
"""
[AC-AISVC-RES-01, AC-AISVC-RES-02, AC-AISVC-RES-03]
Strategy router for retrieval and embedding.
Decision Flow:
1. Check if enhanced strategy is enabled via configuration
2. Check grayscale rules (percentage/allowlist)
3. Route to appropriate pipeline (default/enhanced)
4. Handle rollback on error or performance degradation
Constraints:
- Default strategy MUST preserve existing online logic
- Enhanced strategy MUST be configurable and rollback-able
"""
def __init__(
self,
config: RoutingConfig | None = None,
rollback_manager: RollbackManager | None = None,
):
self._config = config or RoutingConfig()
self._rollback_manager = rollback_manager or RollbackManager()
self._default_pipeline = DefaultPipeline()
self._enhanced_pipeline = EnhancedPipeline(self._config)
self._current_strategy = StrategyType.DEFAULT
self._strategy_enabled = True
@property
def current_strategy(self) -> StrategyType:
"""Get current active strategy."""
return self._current_strategy
@property
def config(self) -> RoutingConfig:
"""Get current configuration."""
return self._config
def update_config(self, new_config: RoutingConfig) -> None:
"""
[AC-AISVC-RES-15] Update routing configuration (hot reload).
Args:
new_config: New configuration to apply
"""
old_strategy = self._config.strategy
self._config = new_config
logger.info(
f"[AC-AISVC-RES-15] Routing config updated: "
f"strategy={new_config.strategy.value}, "
f"mode={new_config.rag_runtime_mode.value}, "
f"grayscale={new_config.grayscale_percentage:.2%}"
)
if old_strategy != new_config.strategy:
logger.info(
f"[AC-AISVC-RES-02] Strategy changed: {old_strategy.value} -> {new_config.strategy.value}"
)
def route(
self,
ctx: StrategyContext,
) -> StrategyResult:
"""
[AC-AISVC-RES-01, AC-AISVC-RES-02] Route to appropriate strategy.
Args:
ctx: Strategy context with tenant, query, metadata, etc.
Returns:
StrategyResult with selected strategy and mode
"""
if not self._strategy_enabled:
logger.info("[AC-AISVC-RES-07] Strategy disabled, using default")
return StrategyResult(
strategy=StrategyType.DEFAULT,
mode=self._config.rag_runtime_mode,
should_fallback=False,
diagnostics={"reason": "strategy_disabled"},
)
use_enhanced = self._config.should_use_enhanced_strategy(ctx.tenant_id)
if use_enhanced:
self._current_strategy = StrategyType.ENHANCED
logger.info(
f"[AC-AISVC-RES-02] Routing to ENHANCED strategy: tenant={ctx.tenant_id}"
)
else:
self._current_strategy = StrategyType.DEFAULT
logger.info(
f"[AC-AISVC-RES-01] Routing to DEFAULT strategy: tenant={ctx.tenant_id}"
)
return StrategyResult(
strategy=self._current_strategy,
mode=self._config.rag_runtime_mode,
diagnostics={
"grayscale_percentage": self._config.grayscale_percentage,
"in_allowlist": ctx.tenant_id in self._config.grayscale_allowlist if ctx.tenant_id else False,
},
)
async def execute(
self,
ctx: StrategyContext,
) -> tuple["RetrievalResult", StrategyResult]:
"""
Execute retrieval with strategy routing.
Args:
ctx: Strategy context
Returns:
Tuple of (RetrievalResult, StrategyResult)
"""
start_time = time.time()
result = self.route(ctx)
try:
if result.strategy == StrategyType.ENHANCED:
retrieval_result = await self._enhanced_pipeline.execute(ctx)
else:
retrieval_result = await self._default_pipeline.execute(ctx)
duration_ms = int((time.time() - start_time) * 1000)
if duration_ms > self._config.performance_budget_ms:
degradation = (duration_ms - self._config.performance_budget_ms) / self._config.performance_budget_ms
if degradation > self._config.performance_degradation_threshold:
logger.warning(
f"[AC-AISVC-RES-08] Performance degradation detected: "
f"duration={duration_ms}ms, budget={self._config.performance_budget_ms}ms, "
f"degradation={degradation:.2%}"
)
return retrieval_result, result
except Exception as e:
logger.error(
f"[AC-AISVC-RES-07] Strategy execution failed: {e}, "
f"strategy={result.strategy.value}"
)
if result.strategy == StrategyType.ENHANCED:
self._rollback_manager.record_rollback(
from_strategy=StrategyType.ENHANCED,
to_strategy=StrategyType.DEFAULT,
reason=str(e),
tenant_id=ctx.tenant_id,
)
logger.info("[AC-AISVC-RES-07] Falling back to DEFAULT strategy")
retrieval_result = await self._default_pipeline.execute(ctx)
return retrieval_result, StrategyResult(
strategy=StrategyType.DEFAULT,
mode=result.mode,
should_fallback=True,
fallback_reason=str(e),
diagnostics=result.diagnostics,
)
raise
def rollback(
self,
reason: str,
tenant_id: str | None = None,
request_id: str | None = None,
) -> None:
"""
[AC-AISVC-RES-07] Force rollback to default strategy.
Args:
reason: Reason for rollback
tenant_id: Optional tenant ID for audit
request_id: Optional request ID for audit
"""
if self._current_strategy == StrategyType.ENHANCED:
self._rollback_manager.record_rollback(
from_strategy=StrategyType.ENHANCED,
to_strategy=StrategyType.DEFAULT,
reason=reason,
tenant_id=tenant_id,
request_id=request_id,
)
self._current_strategy = StrategyType.DEFAULT
self._config.strategy = StrategyType.DEFAULT
logger.info(
f"[AC-AISVC-RES-07] Rollback executed: reason={reason}, tenant={tenant_id}"
)
def get_rollback_records(self, limit: int = 10) -> list[RollbackRecord]:
"""Get recent rollback records."""
return self._rollback_manager.get_recent_rollbacks(limit)
def validate_config(self) -> tuple[bool, list[str]]:
"""
[AC-AISVC-RES-06] Validate current configuration.
Returns:
Tuple of (is_valid, list of error messages)
"""
return self._config.validate()
_strategy_router: StrategyRouter | None = None
def get_strategy_router() -> StrategyRouter:
"""Get or create StrategyRouter singleton."""
global _strategy_router
if _strategy_router is None:
_strategy_router = StrategyRouter()
return _strategy_router
def reset_strategy_router() -> None:
"""Reset StrategyRouter singleton (for testing)."""
global _strategy_router
_strategy_router = None