diff --git a/ai-service/app/services/retrieval/__init__.py b/ai-service/app/services/retrieval/__init__.py index 7fe6292..906bf65 100644 --- a/ai-service/app/services/retrieval/__init__.py +++ b/ai-service/app/services/retrieval/__init__.py @@ -2,6 +2,7 @@ Retrieval module for AI Service. [AC-AISVC-16] Provides retriever implementations with plugin architecture. RAG Optimization: Two-stage retrieval, RRF hybrid ranking, metadata filtering. +[AC-AISVC-RES-01~15] Strategy routing and mode routing for retrieval pipeline. """ from app.services.retrieval.base import ( @@ -32,6 +33,29 @@ from app.services.retrieval.optimized_retriever import ( get_optimized_retriever, ) from app.services.retrieval.vector_retriever import VectorRetriever, get_vector_retriever +from app.services.retrieval.routing_config import ( + RagRuntimeMode, + RoutingConfig, + StrategyContext, + StrategyType, + StrategyResult, +) +from app.services.retrieval.strategy_router import ( + RollbackRecord, + StrategyRouter, + get_strategy_router, +) +from app.services.retrieval.mode_router import ( + ComplexityAnalyzer, + ModeRouter, + ModeRouteResult, + get_mode_router, +) +from app.services.retrieval.strategy_integration import ( + RetrievalStrategyIntegration, + RetrievalStrategyResult, + get_retrieval_strategy_integration, +) __all__ = [ "BaseRetriever", @@ -55,4 +79,19 @@ __all__ = [ "get_knowledge_indexer", "IndexingProgress", "IndexingResult", + "RagRuntimeMode", + "RoutingConfig", + "StrategyContext", + "StrategyType", + "StrategyResult", + "RollbackRecord", + "StrategyRouter", + "get_strategy_router", + "ComplexityAnalyzer", + "ModeRouter", + "ModeRouteResult", + "get_mode_router", + "RetrievalStrategyIntegration", + "RetrievalStrategyResult", + "get_retrieval_strategy_integration", ] diff --git a/ai-service/app/services/retrieval/mode_router.py b/ai-service/app/services/retrieval/mode_router.py new file mode 100644 index 0000000..41fdc48 --- /dev/null +++ 
b/ai-service/app/services/retrieval/mode_router.py @@ -0,0 +1,438 @@ +""" +Mode Router for RAG Runtime Mode Selection. +[AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11] Routes to direct/react/auto mode. + +Mode Descriptions: +- direct: Low-latency generic retrieval path (single KB call) +- react: Multi-step ReAct retrieval path (high accuracy) +- auto: Automatic selection based on complexity/confidence rules +""" + +from __future__ import annotations + +import logging +import re +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from app.services.retrieval.routing_config import ( + RagRuntimeMode, + RoutingConfig, + StrategyContext, +) + +if TYPE_CHECKING: + from app.services.retrieval.base import RetrievalResult + +logger = logging.getLogger(__name__) + + +@dataclass +class ComplexityAnalyzer: + """ + Analyzes query complexity for mode routing decisions. + + Complexity factors: + - Query length + - Number of conditions/constraints + - Presence of logical operators (and, or, not) + - Cross-domain indicators + - Multi-step reasoning requirements + """ + + short_query_threshold: int = 20 + long_query_threshold: int = 100 + + condition_patterns: list[str] = field(default_factory=lambda: [ + r"和|与|及|并且|同时", + r"或者|还是|要么", + r"但是|不过|然而", + r"如果|假如|假设", + r"既.*又", + r"不仅.*而且", + ]) + + reasoning_patterns: list[str] = field(default_factory=lambda: [ + r"为什么|原因|理由", + r"怎么|如何|怎样", + r"区别|差异|不同", + r"比较|对比|优劣", + r"分析|评估|判断", + ]) + + cross_domain_patterns: list[str] = field(default_factory=lambda: [ + r"跨|多|各个", + r"所有|全部|整体", + r"综合|汇总|统计", + ]) + + def analyze(self, query: str) -> float: + """ + Analyze query complexity and return a score (0.0 ~ 1.0). + + Higher score indicates more complex query that may benefit from ReAct mode. 
+ + Args: + query: User query text + + Returns: + Complexity score (0.0 = simple, 1.0 = very complex) + """ + if not query: + return 0.0 + + score = 0.0 + + query_length = len(query) + if query_length < self.short_query_threshold: + score += 0.0 + elif query_length > self.long_query_threshold: + score += 0.3 + else: + score += 0.15 + + condition_count = 0 + for pattern in self.condition_patterns: + matches = re.findall(pattern, query) + condition_count += len(matches) + + if condition_count >= 3: + score += 0.3 + elif condition_count >= 2: + score += 0.2 + elif condition_count >= 1: + score += 0.1 + + for pattern in self.reasoning_patterns: + if re.search(pattern, query): + score += 0.15 + break + + for pattern in self.cross_domain_patterns: + if re.search(pattern, query): + score += 0.15 + break + + question_marks = query.count("?") + query.count("?") + if question_marks >= 2: + score += 0.1 + + return min(1.0, score) + + +@dataclass +class ModeRouteResult: + """Result from mode routing decision.""" + mode: RagRuntimeMode + confidence: float + complexity_score: float + should_fallback_to_react: bool = False + fallback_reason: str | None = None + diagnostics: dict[str, Any] = field(default_factory=dict) + + +class DirectRetrievalExecutor: + """ + [AC-AISVC-RES-09] Direct retrieval executor for low-latency path. + + Single KB call without multi-step reasoning. + """ + + def __init__(self): + self._retriever = None + + async def execute( + self, + ctx: StrategyContext, + ) -> "RetrievalResult": + """ + Execute direct retrieval (single KB call). 
+ """ + from app.services.retrieval.optimized_retriever import get_optimized_retriever + from app.services.retrieval.base import RetrievalContext + + if self._retriever is None: + self._retriever = await get_optimized_retriever() + + retrieval_ctx = RetrievalContext( + tenant_id=ctx.tenant_id, + query=ctx.query, + metadata_filter=ctx.metadata_filter, + kb_ids=ctx.kb_ids, + ) + + return await self._retriever.retrieve(retrieval_ctx) + + +class ReactRetrievalExecutor: + """ + [AC-AISVC-RES-10] ReAct retrieval executor for multi-step path. + + Uses AgentOrchestrator for multi-step reasoning and KB calls. + """ + + def __init__(self, max_steps: int = 5): + self._max_steps = max_steps + + async def execute( + self, + ctx: StrategyContext, + config: RoutingConfig, + ) -> tuple[str, "RetrievalResult | None", dict[str, Any]]: + """ + Execute ReAct retrieval (multi-step reasoning). + + Returns: + Tuple of (final_answer, retrieval_result, react_context) + """ + from app.services.mid.agent_orchestrator import AgentOrchestrator, AgentMode + from app.services.mid.tool_registry import ToolRegistry + from app.services.mid.timeout_governor import TimeoutGovernor + from app.services.llm.factory import get_llm_config_manager + + try: + llm_manager = get_llm_config_manager() + llm_client = llm_manager.get_client() + except Exception as e: + logger.warning(f"[ModeRouter] Failed to get LLM client: {e}") + llm_client = None + + tool_registry = ToolRegistry(timeout_governor=TimeoutGovernor()) + timeout_governor = TimeoutGovernor() + + orchestrator = AgentOrchestrator( + max_iterations=min(config.react_max_steps, self._max_steps), + timeout_governor=timeout_governor, + llm_client=llm_client, + tool_registry=tool_registry, + tenant_id=ctx.tenant_id, + mode=AgentMode.FUNCTION_CALLING, + ) + + base_context = { + "query": ctx.query, + "metadata_filter": ctx.metadata_filter, + "kb_ids": ctx.kb_ids, + **ctx.additional_context, + } + + final_answer, react_ctx, trace = await orchestrator.execute( 
+ user_message=ctx.query, + context=base_context, + ) + + return final_answer, None, { + "iterations": react_ctx.iteration, + "tool_calls": [tc.model_dump() for tc in react_ctx.tool_calls] if react_ctx.tool_calls else [], + "final_answer": final_answer, + } + + +class ModeRouter: + """ + [AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11] + Mode router for RAG runtime mode selection. + + Mode Selection: + - direct: Low-latency generic retrieval (single KB call) + - react: Multi-step ReAct retrieval (high accuracy) + - auto: Automatic selection based on complexity/confidence + + Auto Mode Rules: + - Direct conditions: + - Short query, clear intent + - High metadata confidence + - No cross-domain/multi-condition + - React conditions: + - Multi-condition/multi-constraint + - Low metadata confidence + - Need for secondary confirmation or multi-step reasoning + """ + + def __init__( + self, + config: RoutingConfig | None = None, + ): + self._config = config or RoutingConfig() + self._complexity_analyzer = ComplexityAnalyzer() + self._direct_executor = DirectRetrievalExecutor() + self._react_executor = ReactRetrievalExecutor( + max_steps=self._config.react_max_steps + ) + + @property + def config(self) -> RoutingConfig: + """Get current configuration.""" + return self._config + + def update_config(self, new_config: RoutingConfig) -> None: + """ + [AC-AISVC-RES-15] Update routing configuration (hot reload). + """ + self._config = new_config + self._react_executor._max_steps = new_config.react_max_steps + + logger.info( + f"[AC-AISVC-RES-15] ModeRouter config updated: " + f"mode={new_config.rag_runtime_mode.value}, " + f"react_max_steps={new_config.react_max_steps}, " + f"confidence_threshold={new_config.react_trigger_confidence_threshold}" + ) + + def route( + self, + ctx: StrategyContext, + ) -> ModeRouteResult: + """ + [AC-AISVC-RES-09, AC-AISVC-RES-10, AC-AISVC-RES-11] + Route to appropriate mode based on configuration and context. 
+ + Args: + ctx: Strategy context with query, metadata, confidence, etc. + + Returns: + ModeRouteResult with selected mode and diagnostics + """ + configured_mode = self._config.get_rag_runtime_mode() + + if configured_mode == RagRuntimeMode.DIRECT: + logger.info( + f"[AC-AISVC-RES-09] Mode routing to DIRECT: tenant={ctx.tenant_id}" + ) + return ModeRouteResult( + mode=RagRuntimeMode.DIRECT, + confidence=ctx.metadata_confidence, + complexity_score=ctx.complexity_score, + diagnostics={"configured_mode": "direct"}, + ) + + if configured_mode == RagRuntimeMode.REACT: + logger.info( + f"[AC-AISVC-RES-10] Mode routing to REACT: tenant={ctx.tenant_id}" + ) + return ModeRouteResult( + mode=RagRuntimeMode.REACT, + confidence=ctx.metadata_confidence, + complexity_score=ctx.complexity_score, + diagnostics={"configured_mode": "react"}, + ) + + complexity_score = self._complexity_analyzer.analyze(ctx.query) + effective_complexity = max(complexity_score, ctx.complexity_score) + + should_use_react = self._config.should_trigger_react_in_auto_mode( + confidence=ctx.metadata_confidence, + complexity_score=effective_complexity, + ) + + selected_mode = RagRuntimeMode.REACT if should_use_react else RagRuntimeMode.DIRECT + + logger.info( + f"[AC-AISVC-RES-11, AC-AISVC-RES-12, AC-AISVC-RES-13] " + f"Auto mode routing: selected={selected_mode.value}, " + f"confidence={ctx.metadata_confidence:.2f}, " + f"complexity={effective_complexity:.2f}, " + f"tenant={ctx.tenant_id}" + ) + + return ModeRouteResult( + mode=selected_mode, + confidence=ctx.metadata_confidence, + complexity_score=effective_complexity, + diagnostics={ + "configured_mode": "auto", + "analyzed_complexity": complexity_score, + "provided_complexity": ctx.complexity_score, + "react_trigger_confidence": self._config.react_trigger_confidence_threshold, + "react_trigger_complexity": self._config.react_trigger_complexity_score, + }, + ) + + async def execute_direct( + self, + ctx: StrategyContext, + ) -> "RetrievalResult": + """ + 
Execute direct retrieval mode. + """ + return await self._direct_executor.execute(ctx) + + async def execute_react( + self, + ctx: StrategyContext, + ) -> tuple[str, "RetrievalResult | None", dict[str, Any]]: + """ + Execute ReAct retrieval mode. + """ + return await self._react_executor.execute(ctx, self._config) + + async def execute_with_fallback( + self, + ctx: StrategyContext, + ) -> tuple["RetrievalResult | None", str | None, ModeRouteResult]: + """ + [AC-AISVC-RES-14] Execute with fallback from direct to react on low confidence. + + Args: + ctx: Strategy context + + Returns: + Tuple of (RetrievalResult or None, final_answer or None, ModeRouteResult) + """ + route_result = self.route(ctx) + + if route_result.mode == RagRuntimeMode.DIRECT: + retrieval_result = await self._direct_executor.execute(ctx) + + max_score = 0.0 + if retrieval_result and retrieval_result.hits: + max_score = max((h.score for h in retrieval_result.hits), default=0.0) + + if self._config.should_fallback_direct_to_react(max_score): + logger.info( + f"[AC-AISVC-RES-14] Direct mode low confidence fallback to react: " + f"confidence={max_score:.2f}, threshold={self._config.direct_fallback_confidence_threshold}" + ) + + final_answer, _, react_ctx = await self._react_executor.execute(ctx, self._config) + + return ( + None, + final_answer, + ModeRouteResult( + mode=RagRuntimeMode.REACT, + confidence=max_score, + complexity_score=route_result.complexity_score, + should_fallback_to_react=True, + fallback_reason="low_confidence", + diagnostics={ + **route_result.diagnostics, + "fallback_from": "direct", + "direct_confidence": max_score, + }, + ), + ) + + return retrieval_result, None, route_result + + final_answer, _, react_ctx = await self._react_executor.execute(ctx, self._config) + + return None, final_answer, route_result + + +_mode_router: ModeRouter | None = None + + +def get_mode_router() -> ModeRouter: + """Get or create ModeRouter singleton.""" + global _mode_router + if _mode_router is 
None: + _mode_router = ModeRouter() + return _mode_router + + +def reset_mode_router() -> None: + """Reset ModeRouter singleton (for testing).""" + global _mode_router + _mode_router = None diff --git a/ai-service/app/services/retrieval/routing_config.py b/ai-service/app/services/retrieval/routing_config.py new file mode 100644 index 0000000..11b48d1 --- /dev/null +++ b/ai-service/app/services/retrieval/routing_config.py @@ -0,0 +1,187 @@ +""" +Retrieval and Embedding Strategy Configuration. +[AC-AISVC-RES-01~15] Configuration for strategy routing and mode routing. +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class StrategyType(str, Enum): + """Strategy type for retrieval pipeline selection.""" + DEFAULT = "default" + ENHANCED = "enhanced" + + +class RagRuntimeMode(str, Enum): + """RAG runtime mode for execution path selection.""" + DIRECT = "direct" + REACT = "react" + AUTO = "auto" + + +@dataclass +class RoutingConfig: + """ + [AC-AISVC-RES-01~15] Routing configuration for strategy and mode selection. + + Configuration hierarchy: + 1. Strategy selection (default vs enhanced) + 2. Mode selection (direct/react/auto) + 3. Auto routing rules (complexity/confidence thresholds) + 4. Fallback behavior + """ + + enabled: bool = True + strategy: StrategyType = StrategyType.DEFAULT + + grayscale_percentage: float = 0.0 + grayscale_allowlist: list[str] = field(default_factory=list) + + rag_runtime_mode: RagRuntimeMode = RagRuntimeMode.AUTO + + react_trigger_confidence_threshold: float = 0.6 + react_trigger_complexity_score: float = 0.5 + react_max_steps: int = 5 + + direct_fallback_on_low_confidence: bool = True + direct_fallback_confidence_threshold: float = 0.4 + + performance_budget_ms: int = 5000 + performance_degradation_threshold: float = 0.2 + + def should_use_enhanced_strategy(self, tenant_id: str | None = None) -> bool: + """ + [AC-AISVC-RES-02, AC-AISVC-RES-03] Determine if enhanced strategy should be used. 
+ + Priority: + 1. If strategy is explicitly set to ENHANCED, use enhanced + 2. If strategy is DEFAULT, use default + 3. If grayscale is enabled, check percentage/allowlist + """ + if self.strategy == StrategyType.ENHANCED: + return True + + if self.strategy == StrategyType.DEFAULT: + return False + + if self.grayscale_percentage > 0: + import hashlib + if tenant_id: + hash_val = int(hashlib.md5(tenant_id.encode()).hexdigest()[:8], 16) + return (hash_val % 100) < (self.grayscale_percentage * 100) + return False + + if self.grayscale_allowlist and tenant_id: + return tenant_id in self.grayscale_allowlist + + return False + + def get_rag_runtime_mode(self) -> RagRuntimeMode: + """Get the configured RAG runtime mode.""" + return self.rag_runtime_mode + + def should_fallback_direct_to_react(self, confidence: float) -> bool: + """ + [AC-AISVC-RES-14] Determine if direct mode should fallback to react. + + Args: + confidence: Retrieval confidence score (0.0 ~ 1.0) + + Returns: + True if fallback should be triggered + """ + if not self.direct_fallback_on_low_confidence: + return False + + return confidence < self.direct_fallback_confidence_threshold + + def should_trigger_react_in_auto_mode( + self, + confidence: float, + complexity_score: float, + ) -> bool: + """ + [AC-AISVC-RES-11, AC-AISVC-RES-12, AC-AISVC-RES-13] + Determine if react mode should be triggered in auto mode. 
+ + Direct conditions (优先): + - Short query, clear intent + - High metadata confidence + - No cross-domain/multi-condition + + React conditions: + - Multi-condition/multi-constraint + - Low metadata confidence + - Need for secondary confirmation or multi-step reasoning + + Args: + confidence: Metadata inference confidence (0.0 ~ 1.0) + complexity_score: Query complexity score (0.0 ~ 1.0) + + Returns: + True if react mode should be used + """ + if confidence < self.react_trigger_confidence_threshold: + return True + + if complexity_score > self.react_trigger_complexity_score: + return True + + return False + + def validate(self) -> tuple[bool, list[str]]: + """ + [AC-AISVC-RES-06] Validate configuration consistency. + + Returns: + (is_valid, list of error messages) + """ + errors = [] + + if self.grayscale_percentage < 0 or self.grayscale_percentage > 1.0: + errors.append("grayscale_percentage must be between 0.0 and 1.0") + + if self.react_trigger_confidence_threshold < 0 or self.react_trigger_confidence_threshold > 1.0: + errors.append("react_trigger_confidence_threshold must be between 0.0 and 1.0") + + if self.react_trigger_complexity_score < 0 or self.react_trigger_complexity_score > 1.0: + errors.append("react_trigger_complexity_score must be between 0.0 and 1.0") + + if self.react_max_steps < 3 or self.react_max_steps > 10: + errors.append("react_max_steps must be between 3 and 10") + + if self.direct_fallback_confidence_threshold < 0 or self.direct_fallback_confidence_threshold > 1.0: + errors.append("direct_fallback_confidence_threshold must be between 0.0 and 1.0") + + if self.performance_budget_ms < 1000: + errors.append("performance_budget_ms must be at least 1000") + + if self.performance_degradation_threshold < 0 or self.performance_degradation_threshold > 1.0: + errors.append("performance_degradation_threshold must be between 0.0 and 1.0") + + return (len(errors) == 0, errors) + + +@dataclass +class StrategyContext: + """Context for strategy routing 
decision.""" + tenant_id: str + query: str + metadata_filter: dict[str, Any] | None = None + metadata_confidence: float = 1.0 + complexity_score: float = 0.0 + kb_ids: list[str] | None = None + top_k: int = 5 + additional_context: dict[str, Any] = field(default_factory=dict) + + +@dataclass +class StrategyResult: + """Result from strategy routing.""" + strategy: StrategyType + mode: RagRuntimeMode + should_fallback: bool = False + fallback_reason: str | None = None + diagnostics: dict[str, Any] = field(default_factory=dict) diff --git a/ai-service/app/services/retrieval/strategy_integration.py b/ai-service/app/services/retrieval/strategy_integration.py new file mode 100644 index 0000000..db01e7d --- /dev/null +++ b/ai-service/app/services/retrieval/strategy_integration.py @@ -0,0 +1,233 @@ +""" +Retrieval Strategy Integration for Dialogue Flow. +[AC-AISVC-RES-01~15] Integrates StrategyRouter and ModeRouter into dialogue pipeline. + +Usage: + from app.services.retrieval.strategy_integration import RetrievalStrategyIntegration + + integration = RetrievalStrategyIntegration() + result = await integration.execute(ctx) +""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from app.services.retrieval.routing_config import ( + RagRuntimeMode, + StrategyType, + RoutingConfig, + StrategyContext, + StrategyResult, +) +from app.services.retrieval.strategy_router import ( + StrategyRouter, + get_strategy_router, +) +from app.services.retrieval.mode_router import ( + ModeRouter, + ModeRouteResult, + get_mode_router, +) + +if TYPE_CHECKING: + from app.services.retrieval.base import RetrievalResult + +logger = logging.getLogger(__name__) + + +@dataclass +class RetrievalStrategyResult: + """Combined result from strategy and mode routing.""" + retrieval_result: "RetrievalResult | None" + final_answer: str | None + strategy: StrategyType + mode: RagRuntimeMode + should_fallback: 
bool = False + fallback_reason: str | None = None + mode_route_result: ModeRouteResult | None = None + diagnostics: dict[str, Any] = field(default_factory=dict) + duration_ms: int = 0 + + +class RetrievalStrategyIntegration: + """ + [AC-AISVC-RES-01~15] Integration layer for retrieval strategy. + + Combines StrategyRouter and ModeRouter to provide a unified interface + for the dialogue pipeline. + + Flow: + 1. StrategyRouter selects default or enhanced strategy + 2. ModeRouter selects direct, react, or auto mode + 3. Execute retrieval with selected strategy and mode + 4. Handle fallback scenarios + """ + + def __init__( + self, + config: RoutingConfig | None = None, + strategy_router: StrategyRouter | None = None, + mode_router: ModeRouter | None = None, + ): + self._config = config or RoutingConfig() + self._strategy_router = strategy_router or get_strategy_router() + self._mode_router = mode_router or get_mode_router() + + @property + def config(self) -> RoutingConfig: + """Get current configuration.""" + return self._config + + def update_config(self, new_config: RoutingConfig) -> None: + """ + [AC-AISVC-RES-15] Update all routing configurations. + """ + self._config = new_config + self._strategy_router.update_config(new_config) + self._mode_router.update_config(new_config) + + logger.info( + f"[AC-AISVC-RES-15] RetrievalStrategyIntegration config updated: " + f"strategy={new_config.strategy.value}, mode={new_config.rag_runtime_mode.value}" + ) + + async def execute( + self, + ctx: StrategyContext, + ) -> RetrievalStrategyResult: + """ + Execute retrieval with strategy and mode routing. + + Args: + ctx: Strategy context with tenant, query, metadata, etc. 
+ + Returns: + RetrievalStrategyResult with retrieval results and diagnostics + """ + start_time = time.time() + + strategy_result = self._strategy_router.route(ctx) + + mode_result = self._mode_router.route(ctx) + + logger.info( + f"[AC-AISVC-RES-01~15] Strategy routing: " + f"strategy={strategy_result.strategy.value}, mode={mode_result.mode.value}, " + f"tenant={ctx.tenant_id}, query_len={len(ctx.query)}" + ) + + retrieval_result = None + final_answer = None + should_fallback = False + fallback_reason = None + + try: + if mode_result.mode == RagRuntimeMode.DIRECT: + retrieval_result, answer, mode_result = await self._mode_router.execute_with_fallback(ctx) + + if answer is not None: + final_answer = answer + should_fallback = mode_result.should_fallback_to_react + fallback_reason = mode_result.fallback_reason + + elif mode_result.mode == RagRuntimeMode.REACT: + answer, retrieval_result, react_ctx = await self._mode_router.execute_react(ctx) + final_answer = answer + + else: + retrieval_result, answer, mode_result = await self._mode_router.execute_with_fallback(ctx) + + if answer is not None: + final_answer = answer + should_fallback = mode_result.should_fallback_to_react + fallback_reason = mode_result.fallback_reason + + except Exception as e: + logger.error( + f"[AC-AISVC-RES-07] Retrieval strategy execution failed: {e}" + ) + + if strategy_result.strategy == StrategyType.ENHANCED: + self._strategy_router.rollback( + reason=str(e), + tenant_id=ctx.tenant_id, + ) + + from app.services.retrieval.optimized_retriever import get_optimized_retriever + from app.services.retrieval.base import RetrievalContext + + retriever = await get_optimized_retriever() + retrieval_ctx = RetrievalContext( + tenant_id=ctx.tenant_id, + query=ctx.query, + metadata_filter=ctx.metadata_filter, + kb_ids=ctx.kb_ids, + ) + retrieval_result = await retriever.retrieve(retrieval_ctx) + + should_fallback = True + fallback_reason = str(e) + + else: + raise + + duration_ms = int((time.time() - 
start_time) * 1000) + + return RetrievalStrategyResult( + retrieval_result=retrieval_result, + final_answer=final_answer, + strategy=strategy_result.strategy, + mode=mode_result.mode, + should_fallback=should_fallback, + fallback_reason=fallback_reason, + mode_route_result=mode_result, + diagnostics={ + "strategy_diagnostics": strategy_result.diagnostics, + "mode_diagnostics": mode_result.diagnostics, + "duration_ms": duration_ms, + }, + duration_ms=duration_ms, + ) + + def get_current_strategy(self) -> StrategyType: + """Get current active strategy.""" + return self._strategy_router.current_strategy + + def get_rollback_records(self, limit: int = 10) -> list[dict[str, Any]]: + """Get recent rollback records.""" + records = self._strategy_router.get_rollback_records(limit) + return [ + { + "timestamp": r.timestamp, + "from_strategy": r.from_strategy.value, + "to_strategy": r.to_strategy.value, + "reason": r.reason, + "tenant_id": r.tenant_id, + } + for r in records + ] + + def validate_config(self) -> tuple[bool, list[str]]: + """Validate current configuration.""" + return self._config.validate() + + +_integration: RetrievalStrategyIntegration | None = None + + +def get_retrieval_strategy_integration() -> RetrievalStrategyIntegration: + """Get or create RetrievalStrategyIntegration singleton.""" + global _integration + if _integration is None: + _integration = RetrievalStrategyIntegration() + return _integration + + +def reset_retrieval_strategy_integration() -> None: + """Reset RetrievalStrategyIntegration singleton (for testing).""" + global _integration + _integration = None diff --git a/ai-service/app/services/retrieval/strategy_router.py b/ai-service/app/services/retrieval/strategy_router.py new file mode 100644 index 0000000..6fcfa5e --- /dev/null +++ b/ai-service/app/services/retrieval/strategy_router.py @@ -0,0 +1,403 @@ +""" +Strategy Router for Retrieval and Embedding. 
+[AC-AISVC-RES-01, AC-AISVC-RES-02, AC-AISVC-RES-03] Routes to default or enhanced strategy. + +Key Features: +- Default strategy preserves existing online logic +- Enhanced strategy is configurable and can be rolled back +- Supports grayscale release (percentage/allowlist) +- Supports rollback on error or performance degradation +""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any + +from app.services.retrieval.routing_config import ( + RagRuntimeMode, + StrategyType, + RoutingConfig, + StrategyContext, + StrategyResult, +) + +if TYPE_CHECKING: + from app.services.retrieval.base import RetrievalResult + +logger = logging.getLogger(__name__) + + +@dataclass +class RollbackRecord: + """Record for strategy rollback event.""" + timestamp: float + from_strategy: StrategyType + to_strategy: StrategyType + reason: str + tenant_id: str | None = None + request_id: str | None = None + + +class RollbackManager: + """ + [AC-AISVC-RES-07] Manages strategy rollback and audit logging. 
+ """ + + def __init__(self, max_records: int = 100): + self._records: list[RollbackRecord] = [] + self._max_records = max_records + + def record_rollback( + self, + from_strategy: StrategyType, + to_strategy: StrategyType, + reason: str, + tenant_id: str | None = None, + request_id: str | None = None, + ) -> None: + """Record a rollback event.""" + record = RollbackRecord( + timestamp=time.time(), + from_strategy=from_strategy, + to_strategy=to_strategy, + reason=reason, + tenant_id=tenant_id, + request_id=request_id, + ) + + self._records.append(record) + + if len(self._records) > self._max_records: + self._records = self._records[-self._max_records:] + + logger.info( + f"[AC-AISVC-RES-07] Rollback recorded: {from_strategy.value} -> {to_strategy.value}, " + f"reason={reason}, tenant={tenant_id}" + ) + + def get_recent_rollbacks(self, limit: int = 10) -> list[RollbackRecord]: + """Get recent rollback records.""" + return self._records[-limit:] + + def get_rollback_count(self, since_timestamp: float | None = None) -> int: + """Get count of rollbacks, optionally since a timestamp.""" + if since_timestamp is None: + return len(self._records) + + return sum(1 for r in self._records if r.timestamp >= since_timestamp) + + +class DefaultPipeline: + """ + [AC-AISVC-RES-01] Default pipeline that preserves existing online logic. + + This pipeline uses the existing OptimizedRetriever without any new features. + """ + + def __init__(self): + self._retriever = None + + async def execute( + self, + ctx: StrategyContext, + ) -> "RetrievalResult": + """ + Execute default retrieval strategy. + + Uses existing OptimizedRetriever with current configuration. 
class EnhancedPipeline:
    """
    [AC-AISVC-RES-02] Enhanced pipeline with new end-to-end retrieval features.

    Features:
    - Document preprocessing (cleaning/normalization)
    - Structured chunking (markdown/tables/FAQ)
    - Metadata generation and mounting
    - Embedding strategy (document/query prefix + Matryoshka)
    - Metadata inference and filtering (hard/soft filter)
    - Retrieval strategy (Dense + Keyword/Hybrid + RRF)
    - Optional reranking
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
    ):
        """
        Args:
            config: Routing configuration; a default ``RoutingConfig`` is
                created when omitted.
        """
        # Explicit None check so a caller-supplied config is never replaced
        # merely because it evaluates as falsy.
        self._config = RoutingConfig() if config is None else config
        # Created on the first execute() call and reused afterwards.
        self._retriever = None

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """
        Execute enhanced retrieval strategy.

        Uses OptimizedRetriever with two-stage and hybrid retrieval enabled.

        Args:
            ctx: Strategy context supplying tenant_id, query, metadata_filter
                and kb_ids for the retrieval call.

        Returns:
            Whatever the retriever's ``retrieve()`` coroutine returns.
        """
        # Function-scope imports, as in the original code.
        # NOTE(review): presumably this avoids an import cycle - confirm.
        from app.services.retrieval.optimized_retriever import OptimizedRetriever
        from app.services.retrieval.base import RetrievalContext

        if self._retriever is None:
            self._retriever = OptimizedRetriever(
                two_stage_enabled=True,
                hybrid_enabled=True,
            )

        retrieval_ctx = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.query,
            metadata_filter=ctx.metadata_filter,
            kb_ids=ctx.kb_ids,
        )

        return await self._retriever.retrieve(retrieval_ctx)


class StrategyRouter:
    """
    [AC-AISVC-RES-01, AC-AISVC-RES-02, AC-AISVC-RES-03]
    Strategy router for retrieval and embedding.

    Decision Flow:
    1. Check if enhanced strategy is enabled via configuration
    2. Check grayscale rules (percentage/allowlist)
    3. Route to appropriate pipeline (default/enhanced)
    4. Handle rollback on error or performance degradation

    Constraints:
    - Default strategy MUST preserve existing online logic
    - Enhanced strategy MUST be configurable and rollback-able
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
        rollback_manager: RollbackManager | None = None,
    ):
        """
        Args:
            config: Routing configuration; a default ``RoutingConfig`` is
                created when omitted.
            rollback_manager: Audit sink for rollback records; a fresh
                ``RollbackManager`` is created when omitted.
        """
        # Explicit None checks: an injected collaborator must not be swapped
        # for a new instance just because it evaluates as falsy.
        self._config = RoutingConfig() if config is None else config
        self._rollback_manager = (
            RollbackManager() if rollback_manager is None else rollback_manager
        )
        self._default_pipeline = DefaultPipeline()
        self._enhanced_pipeline = EnhancedPipeline(self._config)

        self._current_strategy = StrategyType.DEFAULT
        self._strategy_enabled = True

    @property
    def current_strategy(self) -> StrategyType:
        """Get current active strategy (the one chosen by the latest route()/rollback())."""
        return self._current_strategy

    @property
    def config(self) -> RoutingConfig:
        """Get current configuration."""
        return self._config

    def update_config(self, new_config: RoutingConfig) -> None:
        """
        [AC-AISVC-RES-15] Update routing configuration (hot reload).

        Args:
            new_config: New configuration to apply
        """
        old_strategy = self._config.strategy
        self._config = new_config
        # The enhanced pipeline was constructed with the previous config
        # object; hand it the new one so both keep sharing a single config.
        self._enhanced_pipeline._config = new_config

        logger.info(
            f"[AC-AISVC-RES-15] Routing config updated: "
            f"strategy={new_config.strategy.value}, "
            f"mode={new_config.rag_runtime_mode.value}, "
            f"grayscale={new_config.grayscale_percentage:.2%}"
        )

        if old_strategy != new_config.strategy:
            logger.info(
                f"[AC-AISVC-RES-02] Strategy changed: {old_strategy.value} -> {new_config.strategy.value}"
            )

    def route(
        self,
        ctx: StrategyContext,
    ) -> StrategyResult:
        """
        [AC-AISVC-RES-01, AC-AISVC-RES-02] Route to appropriate strategy.

        Also records the decision in ``current_strategy`` (except when routing
        is disabled, in which case the stored value is left untouched).

        Args:
            ctx: Strategy context with tenant, query, metadata, etc.

        Returns:
            StrategyResult with selected strategy and mode
        """
        if not self._strategy_enabled:
            logger.info("[AC-AISVC-RES-07] Strategy disabled, using default")
            return StrategyResult(
                strategy=StrategyType.DEFAULT,
                mode=self._config.rag_runtime_mode,
                should_fallback=False,
                diagnostics={"reason": "strategy_disabled"},
            )

        use_enhanced = self._config.should_use_enhanced_strategy(ctx.tenant_id)

        if use_enhanced:
            self._current_strategy = StrategyType.ENHANCED
            logger.info(
                f"[AC-AISVC-RES-02] Routing to ENHANCED strategy: tenant={ctx.tenant_id}"
            )
        else:
            self._current_strategy = StrategyType.DEFAULT
            logger.info(
                f"[AC-AISVC-RES-01] Routing to DEFAULT strategy: tenant={ctx.tenant_id}"
            )

        return StrategyResult(
            strategy=self._current_strategy,
            mode=self._config.rag_runtime_mode,
            diagnostics={
                "grayscale_percentage": self._config.grayscale_percentage,
                "in_allowlist": ctx.tenant_id in self._config.grayscale_allowlist if ctx.tenant_id else False,
            },
        )

    @staticmethod
    def _degradation_ratio(duration_ms: int, budget_ms: int) -> float | None:
        """
        Return the relative budget overrun, or None when there is none.

        Args:
            duration_ms: Measured duration in milliseconds.
            budget_ms: Allowed duration in milliseconds.

        Returns:
            ``(duration_ms - budget_ms) / budget_ms`` when the budget is
            positive and exceeded; ``None`` when the duration is within
            budget or the budget is not positive (no meaningful ratio, and
            dividing by zero must not happen).
        """
        if budget_ms <= 0 or duration_ms <= budget_ms:
            return None
        return (duration_ms - budget_ms) / budget_ms

    def _warn_on_degradation(self, duration_ms: int) -> None:
        """Log an [AC-AISVC-RES-08] warning when the overrun exceeds the configured threshold."""
        degradation = self._degradation_ratio(
            duration_ms, self._config.performance_budget_ms
        )
        if degradation is None:
            return
        if degradation > self._config.performance_degradation_threshold:
            logger.warning(
                f"[AC-AISVC-RES-08] Performance degradation detected: "
                f"duration={duration_ms}ms, budget={self._config.performance_budget_ms}ms, "
                f"degradation={degradation:.2%}"
            )

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> tuple["RetrievalResult", StrategyResult]:
        """
        Execute retrieval with strategy routing.

        The routed pipeline is executed; if the ENHANCED pipeline raises, the
        failure is recorded with the rollback manager and the DEFAULT pipeline
        is run instead. A failure of the DEFAULT pipeline is re-raised.

        Args:
            ctx: Strategy context

        Returns:
            Tuple of (RetrievalResult, StrategyResult)

        Raises:
            Exception: Whatever the DEFAULT pipeline raises (either as the
                routed strategy or as the fallback).
        """
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments.
        start_time = time.perf_counter()
        result = self.route(ctx)

        if result.strategy == StrategyType.ENHANCED:
            pipeline = self._enhanced_pipeline
        else:
            pipeline = self._default_pipeline

        # Only the pipeline call is guarded: a problem in the diagnostics
        # below must never discard a successful retrieval or trigger a
        # spurious rollback.
        try:
            retrieval_result = await pipeline.execute(ctx)
        except Exception as e:
            logger.error(
                f"[AC-AISVC-RES-07] Strategy execution failed: {e}, "
                f"strategy={result.strategy.value}",
                exc_info=True,
            )

            if result.strategy != StrategyType.ENHANCED:
                # Nothing to fall back to when the default pipeline fails.
                raise

            self._rollback_manager.record_rollback(
                from_strategy=StrategyType.ENHANCED,
                to_strategy=StrategyType.DEFAULT,
                reason=str(e),
                tenant_id=ctx.tenant_id,
            )

            logger.info("[AC-AISVC-RES-07] Falling back to DEFAULT strategy")

            retrieval_result = await self._default_pipeline.execute(ctx)

            return retrieval_result, StrategyResult(
                strategy=StrategyType.DEFAULT,
                mode=result.mode,
                should_fallback=True,
                fallback_reason=str(e),
                diagnostics=result.diagnostics,
            )

        duration_ms = int((time.perf_counter() - start_time) * 1000)
        self._warn_on_degradation(duration_ms)

        return retrieval_result, result

    def rollback(
        self,
        reason: str,
        tenant_id: str | None = None,
        request_id: str | None = None,
    ) -> None:
        """
        [AC-AISVC-RES-07] Force rollback to default strategy.

        A rollback record is written only when the current strategy is
        ENHANCED; the current strategy and the configured strategy are set to
        DEFAULT in every case.

        Args:
            reason: Reason for rollback
            tenant_id: Optional tenant ID for audit
            request_id: Optional request ID for audit
        """
        if self._current_strategy == StrategyType.ENHANCED:
            self._rollback_manager.record_rollback(
                from_strategy=StrategyType.ENHANCED,
                to_strategy=StrategyType.DEFAULT,
                reason=reason,
                tenant_id=tenant_id,
                request_id=request_id,
            )

        self._current_strategy = StrategyType.DEFAULT
        # Mutates the active config object in place, so later route() calls
        # see the DEFAULT strategy setting.
        self._config.strategy = StrategyType.DEFAULT

        logger.info(
            f"[AC-AISVC-RES-07] Rollback executed: reason={reason}, tenant={tenant_id}"
        )

    def get_rollback_records(self, limit: int = 10) -> list[RollbackRecord]:
        """Get recent rollback records (at most ``limit``, as returned by the rollback manager)."""
        return self._rollback_manager.get_recent_rollbacks(limit)

    def validate_config(self) -> tuple[bool, list[str]]:
        """
        [AC-AISVC-RES-06] Validate current configuration.

        Returns:
            Tuple of (is_valid, list of error messages)
        """
        return self._config.validate()


# Module-level singleton; created on first use by get_strategy_router().
_strategy_router: StrategyRouter | None = None


def get_strategy_router() -> StrategyRouter:
    """Get or create StrategyRouter singleton."""
    global _strategy_router
    if _strategy_router is None:
        _strategy_router = StrategyRouter()
    return _strategy_router


def reset_strategy_router() -> None:
    """Reset StrategyRouter singleton (for testing)."""
    global _strategy_router
    _strategy_router = None