404 lines
13 KiB
Python
404 lines
13 KiB
Python
|
|
"""
Strategy Router for Retrieval and Embedding.

[AC-AISVC-RES-01, AC-AISVC-RES-02, AC-AISVC-RES-03] Routes each request to the default or enhanced strategy.

Key Features:
- Default strategy preserves existing online logic
- Enhanced strategy is configurable and can be rolled back
- Supports grayscale release (percentage/allowlist)
- Rolls back to the default strategy when the enhanced strategy errors; performance degradation is detected and logged
"""
|
||
|
|
|
||
|
|
from __future__ import annotations
|
||
|
|
|
||
|
|
import logging
|
||
|
|
import time
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from typing import TYPE_CHECKING, Any
|
||
|
|
|
||
|
|
from app.services.retrieval.routing_config import (
|
||
|
|
RagRuntimeMode,
|
||
|
|
StrategyType,
|
||
|
|
RoutingConfig,
|
||
|
|
StrategyContext,
|
||
|
|
StrategyResult,
|
||
|
|
)
|
||
|
|
|
||
|
|
if TYPE_CHECKING:
|
||
|
|
from app.services.retrieval.base import RetrievalResult
|
||
|
|
|
||
|
|
# Module logger; the router's messages carry [AC-AISVC-RES-*] acceptance-criterion tags.
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class RollbackRecord:
    """Audit entry describing one strategy rollback (e.g. ENHANCED -> DEFAULT)."""

    # Seconds since the epoch; RollbackManager fills this with time.time().
    timestamp: float
    # Strategy that was being abandoned.
    from_strategy: StrategyType
    # Strategy that traffic is moved to.
    to_strategy: StrategyType
    # Human-readable cause (for automatic fallbacks, the stringified exception).
    reason: str
    # Optional audit context; both default to None.
    tenant_id: str | None = field(default=None)
    request_id: str | None = field(default=None)
|
||
|
|
|
||
|
|
|
||
|
|
class RollbackManager:
    """
    [AC-AISVC-RES-07] Manages strategy rollback and audit logging.

    Keeps an in-memory, bounded history of RollbackRecord entries in
    chronological order (oldest first). Once more than ``max_records``
    entries are stored, the oldest ones are discarded.
    """

    def __init__(self, max_records: int = 100):
        # Chronological audit trail, oldest first.
        self._records: list[RollbackRecord] = []
        # Upper bound on retained records; values <= 0 retain nothing.
        self._max_records = max_records

    def record_rollback(
        self,
        from_strategy: StrategyType,
        to_strategy: StrategyType,
        reason: str,
        tenant_id: str | None = None,
        request_id: str | None = None,
    ) -> None:
        """
        Record a rollback event and log it.

        Args:
            from_strategy: Strategy being rolled back from.
            to_strategy: Strategy being rolled back to.
            reason: Human-readable cause of the rollback.
            tenant_id: Optional tenant ID for audit.
            request_id: Optional request ID for audit.
        """
        record = RollbackRecord(
            timestamp=time.time(),
            from_strategy=from_strategy,
            to_strategy=to_strategy,
            reason=reason,
            tenant_id=tenant_id,
            request_id=request_id,
        )

        self._records.append(record)

        # Trim the oldest entries in place. The excess is computed explicitly
        # rather than slicing with ``[-max_records:]``: for max_records == 0
        # that slice is ``[-0:]``, i.e. the whole list, so nothing would ever
        # be dropped.
        excess = len(self._records) - self._max_records
        if excess > 0:
            del self._records[:excess]

        logger.info(
            f"[AC-AISVC-RES-07] Rollback recorded: {from_strategy.value} -> {to_strategy.value}, "
            f"reason={reason}, tenant={tenant_id}"
        )

    def get_recent_rollbacks(self, limit: int = 10) -> list[RollbackRecord]:
        """
        Get up to ``limit`` most recent rollback records, oldest first.

        A non-positive ``limit`` returns an empty list. Without this guard,
        ``limit=0`` would slice ``[-0:]`` and return every record.
        """
        if limit <= 0:
            return []
        return self._records[-limit:]

    def get_rollback_count(self, since_timestamp: float | None = None) -> int:
        """Count retained rollbacks, optionally only those at or after ``since_timestamp``."""
        if since_timestamp is None:
            return len(self._records)

        return sum(1 for r in self._records if r.timestamp >= since_timestamp)
|
||
|
|
|
||
|
|
|
||
|
|
class DefaultPipeline:
    """
    [AC-AISVC-RES-01] Baseline retrieval pipeline that keeps the existing online behavior.

    Delegates to the retriever returned by ``get_optimized_retriever()`` and
    adds no extra options. The retriever is resolved on the first call and
    cached on the instance.
    """

    def __init__(self):
        # Populated lazily by the first execute() call.
        self._retriever = None

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """
        Run retrieval with the default (unchanged) strategy.

        Args:
            ctx: Routing context; its tenant_id, query, metadata_filter and
                kb_ids are forwarded to the retriever.

        Returns:
            The result of the optimized retriever's ``retrieve`` call.
        """
        # Function-level imports, presumably to avoid import cycles / startup cost.
        from app.services.retrieval.optimized_retriever import get_optimized_retriever
        from app.services.retrieval.base import RetrievalContext

        retriever = self._retriever
        if retriever is None:
            retriever = self._retriever = await get_optimized_retriever()

        request = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.query,
            metadata_filter=ctx.metadata_filter,
            kb_ids=ctx.kb_ids,
        )
        return await retriever.retrieve(request)
|
||
|
|
|
||
|
|
|
||
|
|
class EnhancedPipeline:
    """
    [AC-AISVC-RES-02] Enhanced pipeline with new end-to-end retrieval features.

    Target feature set:
    - Document preprocessing (cleaning/normalization)
    - Structured chunking (markdown/tables/FAQ)
    - Metadata generation and mounting
    - Embedding strategy (document/query prefix + Matryoshka)
    - Metadata inference and filtering (hard/soft filter)
    - Retrieval strategy (Dense + Keyword/Hybrid + RRF)
    - Optional reranking

    Current implementation: lazily builds a dedicated ``OptimizedRetriever``
    with ``two_stage_enabled=True`` and ``hybrid_enabled=True`` and delegates
    each request to it.

    NOTE(review): the routing config given to ``__init__`` is stored but is
    not consulted by ``execute``; confirm whether it should drive the
    retriever options.
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
    ):
        # Any falsy config (including None) is replaced by a default RoutingConfig.
        self._config = config if config else RoutingConfig()
        # Built on the first execute() and reused afterwards.
        self._retriever = None

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> "RetrievalResult":
        """
        Run retrieval with the enhanced strategy.

        Args:
            ctx: Routing context; its tenant_id, query, metadata_filter and
                kb_ids are forwarded to the retriever.

        Returns:
            The result of the enhanced ``OptimizedRetriever.retrieve`` call.
        """
        from app.services.retrieval.optimized_retriever import OptimizedRetriever
        from app.services.retrieval.base import RetrievalContext

        retriever = self._retriever
        if retriever is None:
            # Two-stage and hybrid retrieval are what distinguish this path
            # from the default pipeline.
            retriever_options = {"two_stage_enabled": True, "hybrid_enabled": True}
            retriever = self._retriever = OptimizedRetriever(**retriever_options)

        request = RetrievalContext(
            tenant_id=ctx.tenant_id,
            query=ctx.query,
            metadata_filter=ctx.metadata_filter,
            kb_ids=ctx.kb_ids,
        )
        return await retriever.retrieve(request)
|
||
|
|
|
||
|
|
|
||
|
|
class StrategyRouter:
    """
    [AC-AISVC-RES-01, AC-AISVC-RES-02, AC-AISVC-RES-03]
    Strategy router for retrieval and embedding.

    Decision Flow:
    1. Check the router-level strategy switch (``_strategy_enabled``)
    2. Ask the routing config whether the enhanced strategy applies to the
       tenant (grayscale percentage/allowlist rules)
    3. Route to the matching pipeline (default/enhanced)
    4. If the enhanced pipeline raises, record a rollback and fall back to the
       default pipeline; latency over budget is detected and logged only

    Constraints:
    - Default strategy MUST preserve existing online logic
    - Enhanced strategy MUST be configurable and rollback-able
    """

    def __init__(
        self,
        config: RoutingConfig | None = None,
        rollback_manager: RollbackManager | None = None,
    ):
        self._config = config or RoutingConfig()
        self._rollback_manager = rollback_manager or RollbackManager()
        self._default_pipeline = DefaultPipeline()
        # The enhanced pipeline receives the same config object at construction.
        self._enhanced_pipeline = EnhancedPipeline(self._config)

        # Strategy chosen by the most recent route() call.
        self._current_strategy = StrategyType.DEFAULT
        # Master switch read by route(); when False every request uses DEFAULT.
        self._strategy_enabled = True

    @property
    def current_strategy(self) -> StrategyType:
        """Get the strategy selected by the most recent route() call."""
        return self._current_strategy

    @property
    def config(self) -> RoutingConfig:
        """Get current configuration."""
        return self._config

    def update_config(self, new_config: RoutingConfig) -> None:
        """
        [AC-AISVC-RES-15] Update routing configuration (hot reload).

        The new config is used by subsequent route()/execute() calls.

        NOTE(review): the enhanced pipeline keeps the config object it was
        constructed with. It does not read that config today, but it will
        need syncing here if it ever starts to.

        Args:
            new_config: New configuration to apply
        """
        old_strategy = self._config.strategy
        self._config = new_config

        logger.info(
            f"[AC-AISVC-RES-15] Routing config updated: "
            f"strategy={new_config.strategy.value}, "
            f"mode={new_config.rag_runtime_mode.value}, "
            f"grayscale={new_config.grayscale_percentage:.2%}"
        )

        if old_strategy != new_config.strategy:
            logger.info(
                f"[AC-AISVC-RES-02] Strategy changed: {old_strategy.value} -> {new_config.strategy.value}"
            )

    def route(
        self,
        ctx: StrategyContext,
    ) -> StrategyResult:
        """
        [AC-AISVC-RES-01, AC-AISVC-RES-02] Route to appropriate strategy.

        Args:
            ctx: Strategy context with tenant, query, metadata, etc.

        Returns:
            StrategyResult with selected strategy and mode
        """
        if not self._strategy_enabled:
            logger.info("[AC-AISVC-RES-07] Strategy disabled, using default")
            return StrategyResult(
                strategy=StrategyType.DEFAULT,
                mode=self._config.rag_runtime_mode,
                should_fallback=False,
                diagnostics={"reason": "strategy_disabled"},
            )

        use_enhanced = self._config.should_use_enhanced_strategy(ctx.tenant_id)

        # NOTE(review): _current_strategy is router-wide state overwritten on
        # every request; with concurrent requests it reflects whichever was
        # routed last. rollback() relies on it for its audit record.
        if use_enhanced:
            self._current_strategy = StrategyType.ENHANCED
            logger.info(
                f"[AC-AISVC-RES-02] Routing to ENHANCED strategy: tenant={ctx.tenant_id}"
            )
        else:
            self._current_strategy = StrategyType.DEFAULT
            logger.info(
                f"[AC-AISVC-RES-01] Routing to DEFAULT strategy: tenant={ctx.tenant_id}"
            )

        return StrategyResult(
            strategy=self._current_strategy,
            mode=self._config.rag_runtime_mode,
            diagnostics={
                "grayscale_percentage": self._config.grayscale_percentage,
                "in_allowlist": ctx.tenant_id in self._config.grayscale_allowlist if ctx.tenant_id else False,
            },
        )

    async def execute(
        self,
        ctx: StrategyContext,
    ) -> tuple["RetrievalResult", StrategyResult]:
        """
        Execute retrieval with strategy routing.

        If the ENHANCED pipeline raises, a rollback is recorded and the request
        is retried once on the DEFAULT pipeline. Errors from the DEFAULT
        pipeline propagate to the caller. Latency above the configured budget
        is only logged; it never triggers a fallback.

        Args:
            ctx: Strategy context

        Returns:
            Tuple of (RetrievalResult, StrategyResult). On fallback the
            StrategyResult has strategy=DEFAULT, should_fallback=True and
            fallback_reason set to the original error message.

        Raises:
            Exception: whatever the DEFAULT pipeline raises, whether it was the
                routed strategy or the fallback target.
        """
        # perf_counter is monotonic; time.time() can jump with wall-clock changes.
        start_time = time.perf_counter()
        result = self.route(ctx)

        try:
            if result.strategy == StrategyType.ENHANCED:
                retrieval_result = await self._enhanced_pipeline.execute(ctx)
            else:
                retrieval_result = await self._default_pipeline.execute(ctx)
        except Exception as e:
            logger.error(
                f"[AC-AISVC-RES-07] Strategy execution failed: {e}, "
                f"strategy={result.strategy.value}",
                exc_info=True,
            )

            if result.strategy != StrategyType.ENHANCED:
                # DEFAULT has nothing to fall back to.
                raise

            self._rollback_manager.record_rollback(
                from_strategy=StrategyType.ENHANCED,
                to_strategy=StrategyType.DEFAULT,
                reason=str(e),
                tenant_id=ctx.tenant_id,
            )

            logger.info("[AC-AISVC-RES-07] Falling back to DEFAULT strategy")

            retrieval_result = await self._default_pipeline.execute(ctx)

            return retrieval_result, StrategyResult(
                strategy=StrategyType.DEFAULT,
                mode=result.mode,
                should_fallback=True,
                fallback_reason=str(e),
                diagnostics=result.diagnostics,
            )

        # Monitoring runs outside the try. A failure here must never be treated
        # as a pipeline error: that would either re-run retrieval on DEFAULT
        # with a bogus rollback record, or fail an already-successful request.
        duration_ms = int((time.perf_counter() - start_time) * 1000)
        self._check_performance(duration_ms)

        return retrieval_result, result

    def _check_performance(self, duration_ms: int) -> None:
        """Warn when ``duration_ms`` overruns the budget by more than the configured threshold."""
        budget_ms = self._config.performance_budget_ms
        # A non-positive budget disables the check, which also avoids dividing by zero.
        if budget_ms <= 0 or duration_ms <= budget_ms:
            return

        # Relative overrun: 0.5 means 50% over budget.
        degradation = (duration_ms - budget_ms) / budget_ms
        if degradation > self._config.performance_degradation_threshold:
            logger.warning(
                f"[AC-AISVC-RES-08] Performance degradation detected: "
                f"duration={duration_ms}ms, budget={budget_ms}ms, "
                f"degradation={degradation:.2%}"
            )

    def rollback(
        self,
        reason: str,
        tenant_id: str | None = None,
        request_id: str | None = None,
    ) -> None:
        """
        [AC-AISVC-RES-07] Force rollback to default strategy.

        An audit record is written only if the last routed strategy was
        ENHANCED. The router's config object is mutated in place
        (``strategy = DEFAULT``); that object may be the one the caller passed
        to ``__init__``/``update_config``.

        Args:
            reason: Reason for rollback
            tenant_id: Optional tenant ID for audit
            request_id: Optional request ID for audit
        """
        if self._current_strategy == StrategyType.ENHANCED:
            self._rollback_manager.record_rollback(
                from_strategy=StrategyType.ENHANCED,
                to_strategy=StrategyType.DEFAULT,
                reason=reason,
                tenant_id=tenant_id,
                request_id=request_id,
            )

        self._current_strategy = StrategyType.DEFAULT
        self._config.strategy = StrategyType.DEFAULT

        logger.info(
            f"[AC-AISVC-RES-07] Rollback executed: reason={reason}, tenant={tenant_id}"
        )

    def get_rollback_records(self, limit: int = 10) -> list[RollbackRecord]:
        """Get recent rollback records."""
        return self._rollback_manager.get_recent_rollbacks(limit)

    def validate_config(self) -> tuple[bool, list[str]]:
        """
        [AC-AISVC-RES-06] Validate current configuration.

        Returns:
            Tuple of (is_valid, list of error messages)
        """
        return self._config.validate()
|
||
|
|
|
||
|
|
|
||
|
|
# Process-wide router instance, created on demand by get_strategy_router().
_strategy_router: StrategyRouter | None = None


def get_strategy_router() -> StrategyRouter:
    """
    Return the shared StrategyRouter, building it on the first call.

    NOTE(review): there is no locking here, so concurrent first calls from
    multiple threads could each construct a router.
    """
    global _strategy_router
    router = _strategy_router
    if router is None:
        router = StrategyRouter()
        _strategy_router = router
    return router


def reset_strategy_router() -> None:
    """Forget the shared StrategyRouter so the next lookup builds a fresh one (for testing)."""
    global _strategy_router
    _strategy_router = None
|