ai-robot-core/ai-service/app/services/retrieval/strategy_router.py

404 lines
13 KiB
Python

"""
Strategy Router for Retrieval and Embedding.
[AC-AISVC-RES-01, AC-AISVC-RES-02, AC-AISVC-RES-03] Routes to default or enhanced strategy.
Key Features:
- Default strategy preserves existing online logic
- Enhanced strategy is configurable and can be rolled back
- Supports grayscale release (percentage/allowlist)
- Supports rollback on error or performance degradation
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from app.services.retrieval.routing_config import (
RagRuntimeMode,
StrategyType,
RoutingConfig,
StrategyContext,
StrategyResult,
)
if TYPE_CHECKING:
from app.services.retrieval.base import RetrievalResult
logger = logging.getLogger(__name__)
@dataclass
class RollbackRecord:
    """
    Audit entry describing a single strategy rollback.

    Attributes:
        timestamp: Moment the rollback was recorded (RollbackManager fills
            this with ``time.time()``, i.e. seconds since the epoch).
        from_strategy: Strategy that was active before the rollback.
        to_strategy: Strategy the router switched to.
        reason: Free-form explanation of why the rollback happened.
        tenant_id: Tenant the rollback relates to, when known.
        request_id: Request the rollback relates to, when known.
    """

    # Required fields: every rollback has a time, a direction and a reason.
    timestamp: float
    from_strategy: StrategyType
    to_strategy: StrategyType
    reason: str

    # Optional audit context; absent values stay None.
    tenant_id: str | None = None
    request_id: str | None = None
class RollbackManager:
    """
    [AC-AISVC-RES-07] Manages strategy rollback and audit logging.

    Keeps a bounded, in-memory history of RollbackRecord entries (oldest
    entries are dropped first) and emits an info log line for every
    recorded rollback.
    """

    def __init__(self, max_records: int = 100):
        """
        Create an empty rollback history.

        Args:
            max_records: Maximum number of records retained. A value of
                zero or less means no history is retained (rollbacks are
                still logged).
        """
        self._records: list[RollbackRecord] = []
        self._max_records = max_records

    def record_rollback(
        self,
        from_strategy: StrategyType,
        to_strategy: StrategyType,
        reason: str,
        tenant_id: str | None = None,
        request_id: str | None = None,
    ) -> None:
        """
        Record a rollback event and log it.

        Args:
            from_strategy: Strategy that was active before the rollback.
            to_strategy: Strategy being switched to.
            reason: Human-readable reason for the rollback.
            tenant_id: Optional tenant ID for audit.
            request_id: Optional request ID for audit.
        """
        record = RollbackRecord(
            timestamp=time.time(),
            from_strategy=from_strategy,
            to_strategy=to_strategy,
            reason=reason,
            tenant_id=tenant_id,
            request_id=request_id,
        )
        self._records.append(record)

        # Trim by explicit overflow count rather than `records[-max:]`:
        # a negative slice with max == 0 is `records[0:]`, which keeps
        # everything and would let the history grow without bound.
        overflow = len(self._records) - max(self._max_records, 0)
        if overflow > 0:
            del self._records[:overflow]

        logger.info(
            f"[AC-AISVC-RES-07] Rollback recorded: {from_strategy.value} -> {to_strategy.value}, "
            f"reason={reason}, tenant={tenant_id}"
        )

    def get_recent_rollbacks(self, limit: int = 10) -> list[RollbackRecord]:
        """
        Get the most recent rollback records, oldest first.

        Args:
            limit: Maximum number of records to return.

        Returns:
            A new list with at most ``limit`` records; empty when
            ``limit`` is zero or negative.
        """
        # `records[-0:]` would return the full history, and a negative
        # limit would slice from the front, so handle both explicitly.
        if limit <= 0:
            return []
        return self._records[-limit:]

    def get_rollback_count(self, since_timestamp: float | None = None) -> int:
        """
        Get count of retained rollbacks, optionally since a timestamp.

        Args:
            since_timestamp: When given, only records whose timestamp is
                greater than or equal to this value are counted.

        Returns:
            Number of matching records currently held in the history.
        """
        if since_timestamp is None:
            return len(self._records)
        return sum(1 for r in self._records if r.timestamp >= since_timestamp)
class DefaultPipeline:
    """
    [AC-AISVC-RES-01] Baseline pipeline that keeps the existing online path.

    Retrieval is delegated to the retriever returned by
    ``get_optimized_retriever()``; this class adds no features of its own.
    """

    def __init__(self):
        # Resolved lazily on the first execute() call and then reused.
        self._retriever = None

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """
        Run the default retrieval strategy for one request.

        Args:
            ctx: Strategy context; its tenant_id, query, metadata_filter
                and kb_ids are copied into the retrieval context.

        Returns:
            Whatever the retriever's ``retrieve`` coroutine produces.
        """
        # Imported here rather than at module level, as in the original.
        from app.services.retrieval.optimized_retriever import get_optimized_retriever
        from app.services.retrieval.base import RetrievalContext

        retriever = self._retriever
        if retriever is None:
            retriever = await get_optimized_retriever()
            self._retriever = retriever

        request = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.query,
            metadata_filter=ctx.metadata_filter,
            kb_ids=ctx.kb_ids,
        )
        outcome = await retriever.retrieve(request)
        return outcome
class EnhancedPipeline:
    """
    [AC-AISVC-RES-02] Enhanced pipeline with new end-to-end retrieval features.

    Stated feature scope of the enhanced strategy:
    - Document preprocessing (cleaning/normalization)
    - Structured chunking (markdown/tables/FAQ)
    - Metadata generation and mounting
    - Embedding strategy (document/query prefix + Matryoshka)
    - Metadata inference and filtering (hard/soft filter)
    - Retrieval strategy (Dense + Keyword/Hybrid + RRF)
    - Optional reranking

    NOTE(review): within this class the only visible difference from the
    default pipeline is that OptimizedRetriever is built with
    ``two_stage_enabled=True`` and ``hybrid_enabled=True``; the features
    above are presumably implemented inside the retriever — confirm.
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
    ):
        """
        Args:
            config: Routing configuration to hold; a fresh RoutingConfig
                is created when a falsy value is passed.
        """
        if not config:
            config = RoutingConfig()
        self._config = config
        # Built lazily on the first execute() call and then reused.
        self._retriever = None

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """
        Run the enhanced retrieval strategy for one request.

        Args:
            ctx: Strategy context; its tenant_id, query, metadata_filter
                and kb_ids are copied into the retrieval context.

        Returns:
            Whatever the retriever's ``retrieve`` coroutine produces.
        """
        # Imported here rather than at module level, as in the original.
        from app.services.retrieval.optimized_retriever import OptimizedRetriever
        from app.services.retrieval.base import RetrievalContext

        retriever = self._retriever
        if retriever is None:
            retriever = OptimizedRetriever(
                two_stage_enabled=True,
                hybrid_enabled=True,
            )
            self._retriever = retriever

        request = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.query,
            metadata_filter=ctx.metadata_filter,
            kb_ids=ctx.kb_ids,
        )
        outcome = await retriever.retrieve(request)
        return outcome
class StrategyRouter:
    """
    [AC-AISVC-RES-01, AC-AISVC-RES-02, AC-AISVC-RES-03]
    Strategy router for retrieval and embedding.

    Decision Flow:
    1. Check if enhanced strategy is enabled via configuration
    2. Check grayscale rules (percentage/allowlist)
    3. Route to appropriate pipeline (default/enhanced)
    4. Handle rollback on error or performance degradation

    Constraints:
    - Default strategy MUST preserve existing online logic
    - Enhanced strategy MUST be configurable and rollback-able

    NOTE(review): steps 1-2 are delegated to
    ``RoutingConfig.should_use_enhanced_strategy``; performance
    degradation is currently only logged, it does not trigger a rollback.
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
        rollback_manager: RollbackManager | None = None,
    ):
        """
        Args:
            config: Routing configuration; a fresh RoutingConfig is used
                when omitted.
            rollback_manager: Rollback audit store; a fresh RollbackManager
                is used when omitted.
        """
        self._config = config or RoutingConfig()
        self._rollback_manager = rollback_manager or RollbackManager()
        self._default_pipeline = DefaultPipeline()
        self._enhanced_pipeline = EnhancedPipeline(self._config)
        # Strategy chosen by the most recent route() call (or rollback()).
        self._current_strategy = StrategyType.DEFAULT
        # Switch checked by route(); while False, routing always yields DEFAULT.
        self._strategy_enabled = True

    @property
    def current_strategy(self) -> StrategyType:
        """Get current active strategy."""
        return self._current_strategy

    @property
    def config(self) -> RoutingConfig:
        """Get current configuration."""
        return self._config

    def update_config(self, new_config: RoutingConfig) -> None:
        """
        [AC-AISVC-RES-15] Update routing configuration (hot reload).

        Args:
            new_config: New configuration to apply
        """
        old_strategy = self._config.strategy
        self._config = new_config
        # The enhanced pipeline received the previous config object at
        # construction time; hand it the new one so both stay in sync
        # without discarding its cached retriever.
        self._enhanced_pipeline._config = new_config
        logger.info(
            f"[AC-AISVC-RES-15] Routing config updated: "
            f"strategy={new_config.strategy.value}, "
            f"mode={new_config.rag_runtime_mode.value}, "
            f"grayscale={new_config.grayscale_percentage:.2%}"
        )
        if old_strategy != new_config.strategy:
            logger.info(
                f"[AC-AISVC-RES-02] Strategy changed: {old_strategy.value} -> {new_config.strategy.value}"
            )

    def route(
        self,
        ctx: StrategyContext,
    ) -> StrategyResult:
        """
        [AC-AISVC-RES-01, AC-AISVC-RES-02] Route to appropriate strategy.

        Also updates ``current_strategy`` to the selected strategy unless
        routing is disabled.

        Args:
            ctx: Strategy context with tenant, query, metadata, etc.

        Returns:
            StrategyResult with selected strategy and mode
        """
        if not self._strategy_enabled:
            logger.info("[AC-AISVC-RES-07] Strategy disabled, using default")
            return StrategyResult(
                strategy=StrategyType.DEFAULT,
                mode=self._config.rag_runtime_mode,
                should_fallback=False,
                diagnostics={"reason": "strategy_disabled"},
            )

        use_enhanced = self._config.should_use_enhanced_strategy(ctx.tenant_id)
        if use_enhanced:
            self._current_strategy = StrategyType.ENHANCED
            logger.info(
                f"[AC-AISVC-RES-02] Routing to ENHANCED strategy: tenant={ctx.tenant_id}"
            )
        else:
            self._current_strategy = StrategyType.DEFAULT
            logger.info(
                f"[AC-AISVC-RES-01] Routing to DEFAULT strategy: tenant={ctx.tenant_id}"
            )

        return StrategyResult(
            strategy=self._current_strategy,
            mode=self._config.rag_runtime_mode,
            diagnostics={
                "grayscale_percentage": self._config.grayscale_percentage,
                # A missing/empty tenant is never reported as allowlisted.
                "in_allowlist": ctx.tenant_id in self._config.grayscale_allowlist if ctx.tenant_id else False,
            },
        )

    def _performance_degradation(self, duration_ms: int) -> float | None:
        """Return the relative budget overshoot if it exceeds the configured threshold, else None."""
        budget_ms = self._config.performance_budget_ms
        # A non-positive budget has no meaningful relative overshoot (and
        # would divide by zero), so the check is skipped.
        if budget_ms <= 0 or duration_ms <= budget_ms:
            return None
        degradation = (duration_ms - budget_ms) / budget_ms
        if degradation > self._config.performance_degradation_threshold:
            return degradation
        return None

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> tuple["RetrievalResult", StrategyResult]:
        """
        Execute retrieval with strategy routing.

        If the enhanced pipeline raises, the failure is recorded as a
        rollback and the request is retried once on the default pipeline.
        A failure of the default pipeline is re-raised to the caller.

        Args:
            ctx: Strategy context

        Returns:
            Tuple of (RetrievalResult, StrategyResult)

        Raises:
            Exception: Whatever the default pipeline raised, either as the
                routed strategy or during fallback.
        """
        # Monotonic clock: wall-clock adjustments must not distort the
        # measured latency. The measurement includes routing time.
        start_time = time.monotonic()
        result = self.route(ctx)

        try:
            if result.strategy == StrategyType.ENHANCED:
                retrieval_result = await self._enhanced_pipeline.execute(ctx)
            else:
                retrieval_result = await self._default_pipeline.execute(ctx)
        except Exception as e:
            # logger.exception logs at ERROR level and keeps the traceback.
            logger.exception(
                f"[AC-AISVC-RES-07] Strategy execution failed: {e}, "
                f"strategy={result.strategy.value}"
            )
            if result.strategy != StrategyType.ENHANCED:
                # Nothing to fall back to: default is already the last resort.
                raise

            self._rollback_manager.record_rollback(
                from_strategy=StrategyType.ENHANCED,
                to_strategy=StrategyType.DEFAULT,
                reason=str(e),
                tenant_id=ctx.tenant_id,
            )
            logger.info("[AC-AISVC-RES-07] Falling back to DEFAULT strategy")
            retrieval_result = await self._default_pipeline.execute(ctx)
            return retrieval_result, StrategyResult(
                strategy=StrategyType.DEFAULT,
                mode=result.mode,
                should_fallback=True,
                fallback_reason=str(e),
                diagnostics=result.diagnostics,
            )

        # The budget check runs outside the try block so that a problem in
        # it can never discard a successful retrieval or trigger a fallback.
        duration_ms = int((time.monotonic() - start_time) * 1000)
        degradation = self._performance_degradation(duration_ms)
        if degradation is not None:
            logger.warning(
                f"[AC-AISVC-RES-08] Performance degradation detected: "
                f"duration={duration_ms}ms, budget={self._config.performance_budget_ms}ms, "
                f"degradation={degradation:.2%}"
            )
        return retrieval_result, result

    def rollback(
        self,
        reason: str,
        tenant_id: str | None = None,
        request_id: str | None = None,
    ) -> None:
        """
        [AC-AISVC-RES-07] Force rollback to default strategy.

        A rollback record is written only when the current strategy is
        ENHANCED. The current strategy and the strategy field of the
        active configuration object are both set to DEFAULT (the config
        object is mutated in place).

        Args:
            reason: Reason for rollback
            tenant_id: Optional tenant ID for audit
            request_id: Optional request ID for audit
        """
        if self._current_strategy == StrategyType.ENHANCED:
            self._rollback_manager.record_rollback(
                from_strategy=StrategyType.ENHANCED,
                to_strategy=StrategyType.DEFAULT,
                reason=reason,
                tenant_id=tenant_id,
                request_id=request_id,
            )
        self._current_strategy = StrategyType.DEFAULT
        self._config.strategy = StrategyType.DEFAULT
        logger.info(
            f"[AC-AISVC-RES-07] Rollback executed: reason={reason}, tenant={tenant_id}"
        )

    def get_rollback_records(self, limit: int = 10) -> list[RollbackRecord]:
        """Get recent rollback records."""
        return self._rollback_manager.get_recent_rollbacks(limit)

    def validate_config(self) -> tuple[bool, list[str]]:
        """
        [AC-AISVC-RES-06] Validate current configuration.

        Delegates to the configuration object's own ``validate()``.

        Returns:
            Tuple of (is_valid, list of error messages)
        """
        return self._config.validate()
# Process-wide StrategyRouter instance; created on first access.
_strategy_router: StrategyRouter | None = None


def get_strategy_router() -> StrategyRouter:
    """
    Return the shared StrategyRouter, building it on first use.

    The instance is constructed with default arguments and cached in the
    module-level ``_strategy_router`` variable.
    """
    global _strategy_router
    router = _strategy_router
    if router is None:
        router = StrategyRouter()
        _strategy_router = router
    return router
def reset_strategy_router() -> None:
    """
    Drop the cached StrategyRouter singleton (for testing).

    The module-level ``_strategy_router`` variable is set back to None.
    """
    global _strategy_router
    _strategy_router = None