"""
Strategy Router for Retrieval and Embedding.
[AC-AISVC-RES-01, AC-AISVC-RES-02, AC-AISVC-RES-03] Routes to default or enhanced strategy.

Key Features:
- Default strategy preserves existing online logic
- Enhanced strategy is configurable and can be rolled back
- Supports grayscale release (percentage/allowlist)
- Falls back to the default strategy (and records a rollback) when the
  enhanced strategy raises; supports manual rollback via StrategyRouter.rollback
- Logs a warning when a request exceeds its performance budget by more than
  the configured degradation threshold (detection only; no automatic rollback)
"""

from __future__ import annotations

import logging
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any

from app.services.retrieval.routing_config import (
    RagRuntimeMode,
    StrategyType,
    RoutingConfig,
    StrategyContext,
    StrategyResult,
)

if TYPE_CHECKING:
    from app.services.retrieval.base import RetrievalResult

logger = logging.getLogger(__name__)


@dataclass
class RollbackRecord:
    """Record for strategy rollback event."""

    timestamp: float  # wall-clock seconds (time.time()) when the rollback was recorded
    from_strategy: StrategyType
    to_strategy: StrategyType
    reason: str
    tenant_id: str | None = None
    request_id: str | None = None


class RollbackManager:
    """
    [AC-AISVC-RES-07] Manages strategy rollback and audit logging.

    Keeps a bounded, in-memory history of the most recent rollback events.
    """

    def __init__(self, max_records: int = 100):
        """
        Args:
            max_records: Maximum number of rollback records retained. When the
                cap is exceeded the oldest records are discarded first; a value
                <= 0 retains nothing.
        """
        self._records: list[RollbackRecord] = []
        self._max_records = max_records

    def record_rollback(
        self,
        from_strategy: StrategyType,
        to_strategy: StrategyType,
        reason: str,
        tenant_id: str | None = None,
        request_id: str | None = None,
    ) -> None:
        """
        Record a rollback event and emit an audit log line.

        Args:
            from_strategy: Strategy being rolled back from.
            to_strategy: Strategy being rolled back to.
            reason: Human-readable reason (e.g. the exception message).
            tenant_id: Optional tenant ID for audit.
            request_id: Optional request ID for audit.
        """
        record = RollbackRecord(
            timestamp=time.time(),
            from_strategy=from_strategy,
            to_strategy=to_strategy,
            reason=reason,
            tenant_id=tenant_id,
            request_id=request_id,
        )
        self._records.append(record)

        # Drop the oldest entries beyond the cap. Slicing with
        # ``[-self._max_records:]`` would keep *everything* when the cap is 0
        # (since ``-0 == 0``), so compute the overflow explicitly.
        overflow = len(self._records) - self._max_records
        if overflow > 0:
            del self._records[:overflow]

        logger.info(
            f"[AC-AISVC-RES-07] Rollback recorded: {from_strategy.value} -> {to_strategy.value}, "
            f"reason={reason}, tenant={tenant_id}"
        )

    def get_recent_rollbacks(self, limit: int = 10) -> list[RollbackRecord]:
        """
        Get recent rollback records, oldest first.

        Args:
            limit: Maximum number of records to return. Values <= 0 return an
                empty list (a bare ``[-0:]`` slice would return everything).

        Returns:
            A new list containing at most ``limit`` of the newest records.
        """
        if limit <= 0:
            return []
        return self._records[-limit:]

    def get_rollback_count(self, since_timestamp: float | None = None) -> int:
        """
        Get count of retained rollbacks, optionally since a timestamp.

        Args:
            since_timestamp: If given, only records with
                ``timestamp >= since_timestamp`` (time.time() seconds) are counted.
        """
        if since_timestamp is None:
            return len(self._records)
        return sum(1 for r in self._records if r.timestamp >= since_timestamp)


class DefaultPipeline:
    """
    [AC-AISVC-RES-01] Default pipeline that preserves existing online logic.

    This pipeline uses the existing OptimizedRetriever without any new features.
    """

    def __init__(self):
        # Resolved lazily on first execute() via get_optimized_retriever().
        self._retriever = None

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """
        Execute default retrieval strategy.

        Uses existing OptimizedRetriever with current configuration.

        Args:
            ctx: Strategy context; tenant_id, query, metadata_filter and kb_ids
                are forwarded to the retriever.

        Returns:
            The retriever's RetrievalResult.
        """
        from app.services.retrieval.optimized_retriever import get_optimized_retriever
        from app.services.retrieval.base import RetrievalContext

        if self._retriever is None:
            self._retriever = await get_optimized_retriever()

        retrieval_ctx = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.query,
            metadata_filter=ctx.metadata_filter,
            kb_ids=ctx.kb_ids,
        )

        return await self._retriever.retrieve(retrieval_ctx)


class EnhancedPipeline:
    """
    [AC-AISVC-RES-02] Enhanced pipeline with new end-to-end retrieval features.

    Features:
    - Document preprocessing (cleaning/normalization)
    - Structured chunking (markdown/tables/FAQ)
    - Metadata generation and mounting
    - Embedding strategy (document/query prefix + Matryoshka)
    - Metadata inference and filtering (hard/soft filter)
    - Retrieval strategy (Dense + Keyword/Hybrid + RRF)
    - Optional reranking
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
    ):
        self._config = config or RoutingConfig()
        # Created lazily on first execute() with two-stage + hybrid enabled.
        self._retriever = None

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """
        Execute enhanced retrieval strategy.

        Uses OptimizedRetriever with enhanced configuration
        (``two_stage_enabled=True``, ``hybrid_enabled=True``).

        Args:
            ctx: Strategy context; tenant_id, query, metadata_filter and kb_ids
                are forwarded to the retriever.

        Returns:
            The retriever's RetrievalResult.
        """
        from app.services.retrieval.optimized_retriever import OptimizedRetriever
        from app.services.retrieval.base import RetrievalContext

        if self._retriever is None:
            self._retriever = OptimizedRetriever(
                two_stage_enabled=True,
                hybrid_enabled=True,
            )

        retrieval_ctx = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.query,
            metadata_filter=ctx.metadata_filter,
            kb_ids=ctx.kb_ids,
        )

        return await self._retriever.retrieve(retrieval_ctx)


class StrategyRouter:
    """
    [AC-AISVC-RES-01, AC-AISVC-RES-02, AC-AISVC-RES-03]
    Strategy router for retrieval and embedding.

    Decision Flow:
    1. Check if enhanced strategy is enabled via configuration
    2. Check grayscale rules (percentage/allowlist)
    3. Route to appropriate pipeline (default/enhanced)
    4. Fall back to the default pipeline if the enhanced pipeline raises, and
       log a warning on performance degradation

    Constraints:
    - Default strategy MUST preserve existing online logic
    - Enhanced strategy MUST be configurable and rollback-able
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
        rollback_manager: RollbackManager | None = None,
    ):
        self._config = config or RoutingConfig()
        self._rollback_manager = rollback_manager or RollbackManager()
        self._default_pipeline = DefaultPipeline()
        self._enhanced_pipeline = EnhancedPipeline(self._config)
        # Strategy chosen by the most recent route() call.
        self._current_strategy = StrategyType.DEFAULT
        self._strategy_enabled = True

    @property
    def current_strategy(self) -> StrategyType:
        """Get the strategy selected by the most recent routing decision."""
        return self._current_strategy

    @property
    def config(self) -> RoutingConfig:
        """Get current configuration."""
        return self._config

    def update_config(self, new_config: RoutingConfig) -> None:
        """
        [AC-AISVC-RES-15] Update routing configuration (hot reload).

        Args:
            new_config: New configuration to apply
        """
        old_strategy = self._config.strategy
        self._config = new_config
        # Keep the enhanced pipeline on the same config object as the router so
        # a hot reload does not leave it holding the stale configuration. The
        # pipeline's cached retriever is intentionally preserved.
        self._enhanced_pipeline._config = new_config

        logger.info(
            f"[AC-AISVC-RES-15] Routing config updated: "
            f"strategy={new_config.strategy.value}, "
            f"mode={new_config.rag_runtime_mode.value}, "
            f"grayscale={new_config.grayscale_percentage:.2%}"
        )

        if old_strategy != new_config.strategy:
            logger.info(
                f"[AC-AISVC-RES-02] Strategy changed: {old_strategy.value} -> {new_config.strategy.value}"
            )

    def route(
        self,
        ctx: StrategyContext,
    ) -> StrategyResult:
        """
        [AC-AISVC-RES-01, AC-AISVC-RES-02] Route to appropriate strategy.

        Args:
            ctx: Strategy context with tenant, query, metadata, etc.

        Returns:
            StrategyResult with selected strategy and mode
        """
        if not self._strategy_enabled:
            logger.info("[AC-AISVC-RES-07] Strategy disabled, using default")
            return StrategyResult(
                strategy=StrategyType.DEFAULT,
                mode=self._config.rag_runtime_mode,
                should_fallback=False,
                diagnostics={"reason": "strategy_disabled"},
            )

        use_enhanced = self._config.should_use_enhanced_strategy(ctx.tenant_id)

        if use_enhanced:
            self._current_strategy = StrategyType.ENHANCED
            logger.info(
                f"[AC-AISVC-RES-02] Routing to ENHANCED strategy: tenant={ctx.tenant_id}"
            )
        else:
            self._current_strategy = StrategyType.DEFAULT
            logger.info(
                f"[AC-AISVC-RES-01] Routing to DEFAULT strategy: tenant={ctx.tenant_id}"
            )

        return StrategyResult(
            strategy=self._current_strategy,
            mode=self._config.rag_runtime_mode,
            diagnostics={
                "grayscale_percentage": self._config.grayscale_percentage,
                "in_allowlist": ctx.tenant_id in self._config.grayscale_allowlist if ctx.tenant_id else False,
            },
        )

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> tuple["RetrievalResult", StrategyResult]:
        """
        Execute retrieval with strategy routing.

        If the ENHANCED pipeline raises, a rollback is recorded and the request
        is retried once on the DEFAULT pipeline. Failures of the DEFAULT
        pipeline propagate to the caller.

        Args:
            ctx: Strategy context

        Returns:
            Tuple of (RetrievalResult, StrategyResult)
        """
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by wall-clock adjustments.
        start_time = time.perf_counter()
        result = self.route(ctx)

        # Keep the try body limited to the pipeline call so that errors in
        # post-processing are never mistaken for an enhanced-pipeline failure
        # (which would trigger a spurious rollback and a second retrieval).
        try:
            if result.strategy == StrategyType.ENHANCED:
                retrieval_result = await self._enhanced_pipeline.execute(ctx)
            else:
                retrieval_result = await self._default_pipeline.execute(ctx)
        except Exception as e:
            logger.error(
                f"[AC-AISVC-RES-07] Strategy execution failed: {e}, "
                f"strategy={result.strategy.value}",
                exc_info=True,
            )

            if result.strategy != StrategyType.ENHANCED:
                raise

            self._rollback_manager.record_rollback(
                from_strategy=StrategyType.ENHANCED,
                to_strategy=StrategyType.DEFAULT,
                reason=str(e),
                tenant_id=ctx.tenant_id,
            )

            logger.info("[AC-AISVC-RES-07] Falling back to DEFAULT strategy")
            retrieval_result = await self._default_pipeline.execute(ctx)

            return retrieval_result, StrategyResult(
                strategy=StrategyType.DEFAULT,
                mode=result.mode,
                should_fallback=True,
                fallback_reason=str(e),
                diagnostics=result.diagnostics,
            )

        self._check_performance(start_time)
        return retrieval_result, result

    def _check_performance(self, start_time: float) -> None:
        """
        Warn when elapsed time since ``start_time`` exceeds the budget by more than the threshold.

        Args:
            start_time: A ``time.perf_counter()`` reading taken before routing.
        """
        budget_ms = self._config.performance_budget_ms
        if budget_ms <= 0:
            # A non-positive budget has no meaningful degradation ratio (and a
            # zero budget would divide by zero).
            return

        duration_ms = int((time.perf_counter() - start_time) * 1000)
        if duration_ms <= budget_ms:
            return

        degradation = (duration_ms - budget_ms) / budget_ms
        if degradation > self._config.performance_degradation_threshold:
            logger.warning(
                f"[AC-AISVC-RES-08] Performance degradation detected: "
                f"duration={duration_ms}ms, budget={budget_ms}ms, "
                f"degradation={degradation:.2%}"
            )

    def rollback(
        self,
        reason: str,
        tenant_id: str | None = None,
        request_id: str | None = None,
    ) -> None:
        """
        [AC-AISVC-RES-07] Force rollback to default strategy.

        Sets both the current strategy and ``config.strategy`` to DEFAULT. The
        config object is mutated in place. A rollback record is written when
        either the last routing decision or the configured strategy was ENHANCED.

        Args:
            reason: Reason for rollback
            tenant_id: Optional tenant ID for audit
            request_id: Optional request ID for audit
        """
        # _current_strategy only reflects the most recent request's routing
        # (which may have been DEFAULT for a non-grayscale tenant), so also
        # consult the configured strategy to avoid losing the audit record.
        was_enhanced = (
            self._current_strategy == StrategyType.ENHANCED
            or self._config.strategy == StrategyType.ENHANCED
        )
        if was_enhanced:
            self._rollback_manager.record_rollback(
                from_strategy=StrategyType.ENHANCED,
                to_strategy=StrategyType.DEFAULT,
                reason=reason,
                tenant_id=tenant_id,
                request_id=request_id,
            )

        self._current_strategy = StrategyType.DEFAULT
        self._config.strategy = StrategyType.DEFAULT

        logger.info(
            f"[AC-AISVC-RES-07] Rollback executed: reason={reason}, tenant={tenant_id}"
        )

    def get_rollback_records(self, limit: int = 10) -> list[RollbackRecord]:
        """Get recent rollback records (at most ``limit``, oldest first)."""
        return self._rollback_manager.get_recent_rollbacks(limit)

    def validate_config(self) -> tuple[bool, list[str]]:
        """
        [AC-AISVC-RES-06] Validate current configuration.

        Returns:
            Tuple of (is_valid, list of error messages)
        """
        return self._config.validate()


_strategy_router: StrategyRouter | None = None


def get_strategy_router() -> StrategyRouter:
    """Get or create StrategyRouter singleton."""
    global _strategy_router
    if _strategy_router is None:
        _strategy_router = StrategyRouter()
    return _strategy_router


def reset_strategy_router() -> None:
    """Reset StrategyRouter singleton (for testing)."""
    global _strategy_router
    _strategy_router = None