From b3343f9e523caba58f61cdf18b9390b8cc003ca6 Mon Sep 17 00:00:00 2001 From: MerCry Date: Wed, 11 Mar 2026 19:02:40 +0800 Subject: [PATCH] =?UTF-8?q?[AC-RETRIEVAL-STRATEGY]=20feat(retrieval):=20?= =?UTF-8?q?=E6=96=B0=E5=A2=9E=E6=A3=80=E7=B4=A2=E7=AD=96=E7=95=A5=E8=B7=AF?= =?UTF-8?q?=E7=94=B1=E6=9C=8D=E5=8A=A1?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增 strategy_service 实现检索策略路由核心逻辑 - 新增 strategy_metrics 提供策略性能指标收集 - 新增 strategy_audit 提供策略审计日志功能 - 新增 retrieval_strategy API 端点支持策略管理 - 支持多种检索策略的动态切换和监控 --- .../app/api/admin/retrieval_strategy.py | 349 +++++++++++++ .../app/services/retrieval/strategy_audit.py | 300 +++++++++++ .../services/retrieval/strategy_metrics.py | 452 ++++++++++++++++ .../services/retrieval/strategy_service.py | 484 ++++++++++++++++++ 4 files changed, 1585 insertions(+) create mode 100644 ai-service/app/api/admin/retrieval_strategy.py create mode 100644 ai-service/app/services/retrieval/strategy_audit.py create mode 100644 ai-service/app/services/retrieval/strategy_metrics.py create mode 100644 ai-service/app/services/retrieval/strategy_service.py diff --git a/ai-service/app/api/admin/retrieval_strategy.py b/ai-service/app/api/admin/retrieval_strategy.py new file mode 100644 index 0000000..7d72939 --- /dev/null +++ b/ai-service/app/api/admin/retrieval_strategy.py @@ -0,0 +1,349 @@ +""" +Retrieval Strategy API Endpoints. 
+[AC-AISVC-RES-01~15] 策略管理 API。 + +Endpoints: +- GET /strategy/retrieval/current - 获取当前策略状态 +- POST /strategy/retrieval/switch - 切换策略 +- POST /strategy/retrieval/validate - 验证策略配置 +- POST /strategy/retrieval/rollback - 回退策略 +""" + +import logging +from typing import Any + +from fastapi import APIRouter, HTTPException, status +from pydantic import BaseModel, Field + +from app.services.retrieval.strategy.config import ( + GrayscaleConfig, + ModeRouterConfig, + PipelineConfig, + RetrievalStrategyConfig, + RerankerConfig, + RuntimeMode, + StrategyType, +) +from app.services.retrieval.strategy.rollback_manager import ( + RollbackManager, + RollbackResult, + RollbackTrigger, + get_rollback_manager, +) +from app.services.retrieval.strategy.strategy_router import ( + StrategyRouter, + get_strategy_router, + set_strategy_router, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/strategy/retrieval", tags=["strategy"]) + + +class GrayscaleConfigSchema(BaseModel): + """灰度配置 Schema。""" + enabled: bool = False + percentage: float = Field(default=0.0, ge=0.0, le=100.0) + allowlist: list[str] = Field(default_factory=list) + + +class RerankerConfigSchema(BaseModel): + """重排器配置 Schema。""" + enabled: bool = False + model: str = "cross-encoder" + top_k_after_rerank: int = 5 + min_score_threshold: float = 0.3 + + +class ModeRouterConfigSchema(BaseModel): + """模式路由配置 Schema。""" + runtime_mode: str = "direct" + react_trigger_confidence_threshold: float = 0.6 + react_trigger_complexity_score: float = 0.5 + react_max_steps: int = 5 + direct_fallback_on_low_confidence: bool = True + + +class PipelineConfigSchema(BaseModel): + """Pipeline 配置 Schema。""" + top_k: int = 5 + score_threshold: float = 0.01 + min_hits: int = 1 + two_stage_enabled: bool = True + + +class StrategyStatusResponse(BaseModel): + """策略状态响应。""" + active_strategy: str + grayscale: GrayscaleConfigSchema + pipeline: PipelineConfigSchema + reranker: RerankerConfigSchema + mode_router: 
@router.get(
    "/current",
    response_model=StrategyStatusResponse,
    summary="获取当前策略状态",
    description="【AC-AISVC-RES-01】 获取当前活跃的检索策略配置状态。",
)
async def get_current_strategy() -> StrategyStatusResponse:
    """[AC-AISVC-RES-01] Return the currently active retrieval strategy status.

    Reads the configuration held by the strategy router and copies it, field
    by field, into the API response schemas.

    Returns:
        StrategyStatusResponse mirroring the router's current configuration.
    """
    config = get_strategy_router().get_config()

    grayscale_cfg = config.grayscale
    grayscale_schema = GrayscaleConfigSchema(
        enabled=grayscale_cfg.enabled,
        percentage=grayscale_cfg.percentage,
        allowlist=grayscale_cfg.allowlist,
    )

    pipeline_cfg = config.pipeline
    pipeline_schema = PipelineConfigSchema(
        top_k=pipeline_cfg.top_k,
        score_threshold=pipeline_cfg.score_threshold,
        min_hits=pipeline_cfg.min_hits,
        two_stage_enabled=pipeline_cfg.two_stage_enabled,
    )

    reranker_cfg = config.reranker
    reranker_schema = RerankerConfigSchema(
        enabled=reranker_cfg.enabled,
        model=reranker_cfg.model,
        top_k_after_rerank=reranker_cfg.top_k_after_rerank,
        min_score_threshold=reranker_cfg.min_score_threshold,
    )

    mode_cfg = config.mode_router
    mode_router_schema = ModeRouterConfigSchema(
        # The schema carries the enum's string value, not the enum itself.
        runtime_mode=mode_cfg.runtime_mode.value,
        react_trigger_confidence_threshold=mode_cfg.react_trigger_confidence_threshold,
        react_trigger_complexity_score=mode_cfg.react_trigger_complexity_score,
        react_max_steps=mode_cfg.react_max_steps,
        direct_fallback_on_low_confidence=mode_cfg.direct_fallback_on_low_confidence,
    )

    return StrategyStatusResponse(
        active_strategy=config.active_strategy.value,
        grayscale=grayscale_schema,
        pipeline=pipeline_schema,
        reranker=reranker_schema,
        mode_router=mode_router_schema,
        performance_thresholds=config.performance_thresholds,
    )
Valid values are: default, enhanced", + ) + + new_config = RetrievalStrategyConfig( + active_strategy=new_active_strategy, + grayscale=GrayscaleConfig( + enabled=request.grayscale.enabled if request.grayscale else current_config.grayscale.enabled, + percentage=request.grayscale.percentage if request.grayscale else current_config.grayscale.percentage, + allowlist=request.grayscale.allowlist if request.grayscale else current_config.grayscale.allowlist, + ), + mode_router=ModeRouterConfig( + runtime_mode=RuntimeMode(request.mode_router.runtime_mode) if request.mode_router and request.mode_router.runtime_mode else current_config.mode_router.runtime_mode, + react_trigger_confidence_threshold=request.mode_router.react_trigger_confidence_threshold if request.mode_router else current_config.mode_router.react_trigger_confidence_threshold, + react_trigger_complexity_score=request.mode_router.react_trigger_complexity_score if request.mode_router else current_config.mode_router.react_trigger_complexity_score, + react_max_steps=request.mode_router.react_max_steps if request.mode_router else current_config.mode_router.react_max_steps, + direct_fallback_on_low_confidence=request.mode_router.direct_fallback_on_low_confidence if request.mode_router else current_config.mode_router.direct_fallback_on_low_confidence, + ) if request.mode_router else current_config.mode_router, + reranker=RerankerConfig( + enabled=request.reranker.enabled if request.reranker else current_config.reranker.enabled, + model=request.reranker.model if request.reranker else current_config.reranker.model, + top_k_after_rerank=request.reranker.top_k_after_rerank if request.reranker else current_config.reranker.top_k_after_rerank, + min_score_threshold=request.reranker.min_score_threshold if request.reranker else current_config.reranker.min_score_threshold, + ) if request.reranker else current_config.reranker, + pipeline=current_config.pipeline, + metadata_inference=current_config.metadata_inference, + 
def _config_section(
    config: dict[str, Any],
    name: str,
    errors: list[str],
) -> dict[str, Any]:
    """Return ``config[name]`` when it is a dict, otherwise an empty dict.

    A missing or ``None`` section is treated as absent. Any other non-dict
    value is reported through ``errors`` so the caller never calls ``.get``
    on something that is not a mapping.
    """
    section = config.get(name)
    if section is None:
        return {}
    if not isinstance(section, dict):
        errors.append(
            f"Invalid {name}: expected an object, got {type(section).__name__}"
        )
        return {}
    return section


def _is_number(value: Any) -> bool:
    """Return True for int/float values; bool is deliberately excluded."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


@router.post(
    "/validate",
    response_model=ValidationResponse,
    summary="验证策略配置",
    description="【AC-AISVC-RES-06, AC-AISVC-RES-08】 验证策略配置的完整性与一致性。",
)
async def validate_strategy(request: ValidationRequest) -> ValidationResponse:
    """[AC-AISVC-RES-06, AC-AISVC-RES-08] Validate a strategy configuration.

    The payload is free-form, so every section and value is checked
    defensively: malformed input produces entries in ``errors``/``warnings``
    rather than an unhandled exception.

    Args:
        request: Wrapper holding the raw configuration dictionary.

    Returns:
        ValidationResponse with ``valid`` set to True when no errors were
        found, the collected errors and warnings, and a short summary of the
        submitted configuration.
    """
    errors: list[str] = []
    warnings: list[str] = []
    config = request.config

    if "active_strategy" in config and config["active_strategy"] not in (
        "default",
        "enhanced",
    ):
        errors.append(f"Invalid active_strategy: {config['active_strategy']}")

    grayscale = _config_section(config, "grayscale", errors)
    percentage = grayscale.get("percentage")
    # Only validate a percentage that was actually supplied.
    if percentage is not None and not (
        _is_number(percentage) and 0 <= percentage <= 100
    ):
        errors.append(f"Invalid grayscale percentage: {percentage}")

    mode_router = _config_section(config, "mode_router", errors)
    if "runtime_mode" in mode_router and mode_router["runtime_mode"] not in (
        "direct",
        "react",
        "auto",
    ):
        errors.append(f"Invalid runtime_mode: {mode_router['runtime_mode']}")

    reranker = _config_section(config, "reranker", errors)
    top_k = reranker.get("top_k_after_rerank")
    # The message promises a 1..20 range, so check both bounds.
    if (
        reranker.get("enabled")
        and top_k is not None
        and not (_is_number(top_k) and 1 <= top_k <= 20)
    ):
        warnings.append(
            f"top_k_after_rerank should be between 1 and 20, current: {top_k}"
        )

    thresholds = _config_section(config, "performance_thresholds", errors)
    max_latency = thresholds.get("max_latency_ms")
    if max_latency is not None:
        if not _is_number(max_latency):
            errors.append(f"Invalid max_latency_ms: {max_latency}")
        elif max_latency < 100:
            warnings.append(f"max_latency_ms seems too low: {max_latency}ms")

    return ValidationResponse(
        valid=not errors,
        errors=errors,
        warnings=warnings,
        config_summary={
            "active_strategy": config.get("active_strategy", "default"),
            "grayscale_enabled": grayscale.get("enabled", False),
            "runtime_mode": mode_router.get("runtime_mode", "direct"),
            "reranker_enabled": reranker.get("enabled", False),
            "performance_thresholds": thresholds,
        },
    )
Strategy rolled back: {previous_strategy.value} -> default" + ) + + return RollbackResponse( + success=True, + previous_strategy=previous_strategy.value, + current_strategy=StrategyType.DEFAULT.value, + trigger=trigger, + reason=reason or "Manual rollback", + audit_log_id=str(audit_log.timestamp) if audit_log else None, + ) diff --git a/ai-service/app/services/retrieval/strategy_audit.py b/ai-service/app/services/retrieval/strategy_audit.py new file mode 100644 index 0000000..071731b --- /dev/null +++ b/ai-service/app/services/retrieval/strategy_audit.py @@ -0,0 +1,300 @@ +""" +Strategy Audit Service for AI Service. +[AC-AISVC-RES-07] Audit logging for strategy operations. +""" + +import json +import logging +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime +from typing import Any + +from app.schemas.retrieval_strategy import StrategyAuditLog + +logger = logging.getLogger(__name__) + + +@dataclass +class AuditEntry: + """ + Internal audit entry structure. + """ + + timestamp: str + operation: str + previous_strategy: str | None = None + new_strategy: str | None = None + previous_react_mode: str | None = None + new_react_mode: str | None = None + reason: str | None = None + operator: str | None = None + tenant_id: str | None = None + metadata: dict[str, Any] | None = None + + +class StrategyAuditService: + """ + [AC-AISVC-RES-07] Audit service for strategy operations. 
def log(
    self,
    operation: str,
    previous_strategy: str | None = None,
    new_strategy: str | None = None,
    previous_react_mode: str | None = None,
    new_react_mode: str | None = None,
    reason: str | None = None,
    operator: str | None = None,
    tenant_id: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """[AC-AISVC-RES-07] Record one strategy operation.

    The entry is appended to the in-memory audit trail, summarised on the
    module logger, and emitted as a JSON document on the ``audit.strategy``
    logger.

    Args:
        operation: Operation type (switch, rollback, validate).
        previous_strategy: Strategy in effect before the operation.
        new_strategy: Strategy in effect after the operation.
        previous_react_mode: React mode before the operation.
        new_react_mode: React mode after the operation.
        reason: Why the operation was performed.
        operator: Who performed the operation.
        tenant_id: Tenant the operation applies to, if any.
        metadata: Additional structured context.
    """
    # Build the field mapping once so the stored entry and the JSON payload
    # cannot drift apart.
    fields = {
        "timestamp": datetime.utcnow().isoformat(),
        "operation": operation,
        "previous_strategy": previous_strategy,
        "new_strategy": new_strategy,
        "previous_react_mode": previous_react_mode,
        "new_react_mode": new_react_mode,
        "reason": reason,
        "operator": operator,
        "tenant_id": tenant_id,
        "metadata": metadata,
    }
    self._audit_log.append(AuditEntry(**fields))

    logger.info(
        f"[AC-AISVC-RES-07] Strategy audit: operation={operation}, "
        f"from={previous_strategy} -> to={new_strategy}, "
        f"operator={operator}, reason={reason}"
    )

    payload = {"audit_type": "strategy_operation", **fields}
    logging.getLogger("audit.strategy").info(
        json.dumps(payload, ensure_ascii=False)
    )
+ """ + self.log( + operation="switch", + previous_strategy=previous_strategy, + new_strategy=new_strategy, + previous_react_mode=previous_react_mode, + new_react_mode=new_react_mode, + reason=reason, + operator=operator, + tenant_id=tenant_id, + metadata={"rollout_config": rollout_config} if rollout_config else None, + ) + + def log_rollback( + self, + previous_strategy: str, + new_strategy: str, + previous_react_mode: str | None = None, + new_react_mode: str | None = None, + reason: str | None = None, + operator: str | None = None, + tenant_id: str | None = None, + ) -> None: + """ + Log a strategy rollback operation. + + Args: + previous_strategy: Previous strategy value. + new_strategy: Strategy rolled back to. + previous_react_mode: Previous react mode. + new_react_mode: React mode rolled back to. + reason: Reason for the rollback. + operator: Operator who performed the rollback. + tenant_id: Tenant ID if applicable. + """ + self.log( + operation="rollback", + previous_strategy=previous_strategy, + new_strategy=new_strategy, + previous_react_mode=previous_react_mode, + new_react_mode=new_react_mode, + reason=reason or "Manual rollback", + operator=operator, + tenant_id=tenant_id, + ) + + def log_validation( + self, + strategy: str, + react_mode: str | None = None, + checks: list[str] | None = None, + passed: bool = False, + operator: str | None = None, + tenant_id: str | None = None, + ) -> None: + """ + Log a strategy validation operation. + + Args: + strategy: Strategy being validated. + react_mode: React mode being validated. + checks: List of checks performed. + passed: Whether validation passed. + operator: Operator who performed the validation. + tenant_id: Tenant ID if applicable. 
+ """ + self.log( + operation="validate", + new_strategy=strategy, + new_react_mode=react_mode, + operator=operator, + tenant_id=tenant_id, + metadata={ + "checks": checks, + "passed": passed, + }, + ) + + def get_audit_log( + self, + limit: int = 100, + operation: str | None = None, + tenant_id: str | None = None, + ) -> list[StrategyAuditLog]: + """ + Get audit log entries. + + Args: + limit: Maximum number of entries to return. + operation: Filter by operation type. + tenant_id: Filter by tenant ID. + + Returns: + List of StrategyAuditLog entries. + """ + entries = list(self._audit_log) + + if operation: + entries = [e for e in entries if e.operation == operation] + + if tenant_id: + entries = [e for e in entries if e.tenant_id == tenant_id] + + entries = entries[-limit:] + + return [ + StrategyAuditLog( + timestamp=e.timestamp, + operation=e.operation, + previous_strategy=e.previous_strategy, + new_strategy=e.new_strategy, + previous_react_mode=e.previous_react_mode, + new_react_mode=e.new_react_mode, + reason=e.reason, + operator=e.operator, + tenant_id=e.tenant_id, + metadata=e.metadata, + ) + for e in entries + ] + + def get_audit_stats(self) -> dict[str, Any]: + """ + Get audit log statistics. + + Returns: + Dictionary with audit statistics. + """ + entries = list(self._audit_log) + + operation_counts: dict[str, int] = {} + for entry in entries: + operation_counts[entry.operation] = operation_counts.get(entry.operation, 0) + 1 + + return { + "total_entries": len(entries), + "max_entries": self._max_entries, + "operation_counts": operation_counts, + "oldest_entry": entries[0].timestamp if entries else None, + "newest_entry": entries[-1].timestamp if entries else None, + } + + def clear_audit_log(self) -> int: + """ + Clear all audit log entries. + + Returns: + Number of entries cleared. 
@dataclass
class LatencyTracker:
    """Bounded collection of latency samples with simple statistics.

    When the buffer is full, the oldest samples are discarded so that about
    half of ``max_samples`` remain before the new sample is stored.
    """

    latencies: list[float] = field(default_factory=list)
    max_samples: int = 1000

    def record(self, latency_ms: float) -> None:
        """Store one latency sample, trimming old samples when full."""
        if len(self.latencies) >= self.max_samples:
            # Keep the newest half. Computing the start index explicitly
            # avoids the ``[-0:]`` slice, which would keep everything when
            # max_samples is 0 or 1.
            keep = max(self.max_samples // 2, 0)
            self.latencies = self.latencies[len(self.latencies) - keep:]
        self.latencies.append(latency_ms)

    def get_percentile(self, percentile: float) -> float:
        """Return the nearest-rank percentile of the recorded samples.

        Args:
            percentile: Requested percentile; values outside 0..100 are
                clamped to that range.

        Returns:
            The sample at the requested rank, or 0.0 when nothing has been
            recorded.
        """
        import math

        if not self.latencies:
            return 0.0
        ordered = sorted(self.latencies)
        clamped = min(max(percentile, 0.0), 100.0)
        # Nearest rank is ceil(p/100 * n); rounding first absorbs float noise
        # such as 999.0000000000001.
        rank = math.ceil(round(len(ordered) * clamped / 100, 9))
        return ordered[max(rank, 1) - 1]

    def get_avg(self) -> float:
        """Return the mean of the recorded samples, or 0.0 when empty."""
        if not self.latencies:
            return 0.0
        return sum(self.latencies) / len(self.latencies)
def record_request(
    self,
    latency_ms: float,
    success: bool = True,
    route_mode: str | None = None,
    fallback: bool = False,
    strategy: StrategyType | None = None,
) -> None:
    """[AC-AISVC-RES-03, AC-AISVC-RES-08] Record one retrieval request.

    Args:
        latency_ms: Request latency in milliseconds.
        success: Whether the request succeeded.
        route_mode: Route mode used (direct, react, auto), if known.
        fallback: Whether the request fell back to the default strategy.
        strategy: Strategy to attribute the request to; defaults to the
            current strategy.
    """
    effective_strategy = strategy or self._current_strategy
    key = effective_strategy.value

    metrics = self._metrics[key]
    metrics.total_requests += 1
    if success:
        metrics.successful_requests += 1
    else:
        metrics.failed_requests += 1

    metrics.latency_tracker.record(latency_ms)
    metrics.last_updated = datetime.utcnow().isoformat()

    if fallback:
        metrics.fallback_count += 1

    if route_mode:
        # Pass the resolved key so route counters land on the same strategy
        # as the request counters above.
        self._record_route_metric(route_mode, latency_ms, success, strategy_key=key)

    logger.debug(
        f"[AC-AISVC-RES-08] Request recorded: strategy={key}, "
        f"latency={latency_ms:.2f}ms, success={success}, route={route_mode}"
    )


def _record_route_metric(
    self,
    route_mode: str,
    latency_ms: float,
    success: bool,
    strategy_key: str | None = None,
) -> None:
    """Record per-route metrics and bump the owning strategy's route counter.

    Args:
        route_mode: Route mode (direct, react, auto).
        latency_ms: Request latency in milliseconds.
        success: Whether the request succeeded.
        strategy_key: Key in the per-strategy metrics to credit the route
            counter to; defaults to the current strategy's value.
    """
    metrics = self._route_metrics[route_mode]
    metrics.total_requests += 1
    if success:
        metrics.successful_requests += 1
    else:
        metrics.failed_requests += 1

    metrics.latency_tracker.record(latency_ms)
    metrics.last_updated = datetime.utcnow().isoformat()

    if route_mode not in ("direct", "react", "auto"):
        # Unknown modes are tracked per route only; no strategy counter.
        return

    owner_key = (
        strategy_key if strategy_key is not None else self._current_strategy.value
    )
    owner = self._metrics[owner_key]
    if route_mode == "direct":
        owner.direct_route_count += 1
    elif route_mode == "react":
        owner.react_route_count += 1
    else:
        owner.auto_route_count += 1
def get_performance_summary(self) -> dict[str, Any]:
    """[AC-AISVC-RES-08] Build a performance summary for monitoring.

    The overall average latency is weighted by each strategy's request
    count, so a rarely used strategy cannot dominate the figure. The
    overall p99 is the largest per-strategy p99.

    Returns:
        Dictionary with overall totals, success rate, latency figures, the
        current strategy and react mode, a per-strategy breakdown, and the
        per-route metrics.
    """
    all_metrics = self.get_all_metrics()
    per_strategy = list(all_metrics.values())

    total_requests = sum(m.total_requests for m in per_strategy)
    total_success = sum(m.successful_requests for m in per_strategy)
    total_failed = sum(m.failed_requests for m in per_strategy)

    # Weight each strategy's mean by its traffic; a plain mean of means
    # would treat 10 requests the same as 10,000.
    weighted_latency = sum(m.avg_latency_ms * m.total_requests for m in per_strategy)
    overall_avg_latency = (
        weighted_latency / total_requests if total_requests > 0 else 0.0
    )

    overall_p99_latency = max(
        (m.p99_latency_ms for m in per_strategy if m.p99_latency_ms > 0),
        default=0.0,
    )

    strategies = {
        name: {
            "total_requests": m.total_requests,
            "success_rate": (
                round(m.successful_requests / m.total_requests, 4)
                if m.total_requests > 0
                else 0.0
            ),
            "avg_latency_ms": m.avg_latency_ms,
            "p99_latency_ms": m.p99_latency_ms,
        }
        for name, m in all_metrics.items()
    }

    return {
        "total_requests": total_requests,
        "successful_requests": total_success,
        "failed_requests": total_failed,
        "success_rate": (
            round(total_success / total_requests, 4) if total_requests > 0 else 0.0
        ),
        "avg_latency_ms": round(overall_avg_latency, 2),
        "p99_latency_ms": round(overall_p99_latency, 2),
        "current_strategy": self._current_strategy.value,
        "current_react_mode": self._current_react_mode.value,
        "strategies": strategies,
        "routes": self.get_route_metrics(),
    }
+ """ + if strategy: + self._metrics[strategy.value] = StrategyMetricsData() + logger.info(f"[AC-AISVC-RES-08] Metrics reset for strategy: {strategy.value}") + else: + self._metrics.clear() + self._route_metrics.clear() + logger.info("[AC-AISVC-RES-08] All metrics reset") + + def check_performance_threshold( + self, + strategy: StrategyType, + max_latency_ms: float = 5000.0, + max_error_rate: float = 0.1, + ) -> dict[str, Any]: + """ + [AC-AISVC-RES-08] Check if performance is within acceptable thresholds. + + Args: + strategy: Strategy to check. + max_latency_ms: Maximum acceptable average latency. + max_error_rate: Maximum acceptable error rate (0-1). + + Returns: + Dictionary with check results. + """ + metrics = self.get_metrics(strategy) + + latency_ok = metrics.avg_latency_ms <= max_latency_ms + error_rate = ( + metrics.failed_requests / metrics.total_requests + if metrics.total_requests > 0 + else 0.0 + ) + error_rate_ok = error_rate <= max_error_rate + + return { + "strategy": strategy.value, + "latency_ok": latency_ok, + "avg_latency_ms": metrics.avg_latency_ms, + "max_latency_ms": max_latency_ms, + "error_rate_ok": error_rate_ok, + "error_rate": round(error_rate, 4), + "max_error_rate": max_error_rate, + "overall_ok": latency_ok and error_rate_ok, + "recommendation": ( + "Performance within acceptable thresholds" + if latency_ok and error_rate_ok + else "Consider rollback or investigation" + ), + } + + +class MetricsContext: + """ + Context manager for timing operations. 
+ """ + + def __init__( + self, + metrics_service: StrategyMetricsService, + route_mode: str | None = None, + strategy: StrategyType | None = None, + ): + self._metrics_service = metrics_service + self._route_mode = route_mode + self._strategy = strategy + self._start_time: float | None = None + self._success = True + + def __enter__(self) -> "MetricsContext": + self._start_time = time.time() + return self + + def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: + if self._start_time is None: + return + + latency_ms = (time.time() - self._start_time) * 1000 + success = exc_type is None + + self._metrics_service.record_request( + latency_ms=latency_ms, + success=success, + route_mode=self._route_mode, + strategy=self._strategy, + ) + + def mark_failed(self) -> None: + """Mark the operation as failed.""" + self._success = False + + +_metrics_service: StrategyMetricsService | None = None + + +def get_metrics_service() -> StrategyMetricsService: + """Get or create StrategyMetricsService instance.""" + global _metrics_service + if _metrics_service is None: + _metrics_service = StrategyMetricsService() + return _metrics_service diff --git a/ai-service/app/services/retrieval/strategy_service.py b/ai-service/app/services/retrieval/strategy_service.py new file mode 100644 index 0000000..b5b4bb7 --- /dev/null +++ b/ai-service/app/services/retrieval/strategy_service.py @@ -0,0 +1,484 @@ +""" +Retrieval Strategy Service for AI Service. +[AC-AISVC-RES-01~15] Strategy management with grayscale and rollback support. 
@dataclass
class StrategyState:
    """
    [AC-AISVC-RES-01] Internal state for retrieval strategy.

    Mutable, in-memory record of the configuration held by
    RetrievalStrategyService: the strategy and ReAct mode in effect, the
    rollout settings, the values to restore on rollback, and a history of
    switch and rollback operations.
    """

    # Strategy currently in effect.
    active_strategy: StrategyType = StrategyType.DEFAULT
    # ReAct mode currently in effect; read by get_route_mode().
    react_mode: ReactMode = ReactMode.NON_REACT
    # Decides which of the two rollout fields below is consulted.
    rollout_mode: RolloutMode = RolloutMode.OFF
    # Compared against random.random() * 100, i.e. expressed on a 0-100 scale.
    rollout_percentage: float = 0.0
    # Tenant IDs that get the enhanced strategy in ALLOWLIST mode.
    rollout_allowlist: list[str] = field(default_factory=list)
    # Values captured by switch_strategy() before it applies a change;
    # rollback_strategy() restores them, using DEFAULT / NON_REACT when unset.
    previous_strategy: StrategyType | None = None
    previous_react_mode: ReactMode | None = None
    # One dict appended per switch or rollback, oldest first.
    switch_history: list[dict[str, Any]] = field(default_factory=list)
def switch_strategy(
    self,
    request: RetrievalStrategySwitchRequest,
    operator: str | None = None,
    tenant_id: str | None = None,
) -> RetrievalStrategySwitchResponse:
    """
    [AC-AISVC-RES-02, AC-AISVC-RES-03, AC-AISVC-RES-05] Switch retrieval strategy.

    Saves the outgoing strategy and ReAct mode (the values a later rollback
    restores), applies the requested strategy, ReAct mode and rollout
    settings, appends a record to the switch history, and finally notifies
    the audit and metrics callbacks when they are set.

    Args:
        request: Switch request with target strategy and options.
        operator: Operator who initiated the switch.
        tenant_id: Tenant ID for audit.

    Returns:
        RetrievalStrategySwitchResponse with previous and current status.
    """
    # Function-scope import: the module-level import only provides ``datetime``.
    from datetime import timezone

    state = self._state
    previous_status = self.get_current_status()

    # Remember what is being replaced before overwriting it.
    state.previous_strategy = state.active_strategy
    state.previous_react_mode = state.react_mode

    state.active_strategy = request.target_strategy

    # An absent react_mode means "keep the current mode".
    if request.react_mode:
        state.react_mode = request.react_mode

    rollout = request.rollout
    if rollout:
        state.rollout_mode = rollout.mode
        # Only the field that belongs to the selected mode is updated.
        if rollout.mode == RolloutMode.PERCENTAGE:
            state.rollout_percentage = rollout.percentage or 0.0
        elif rollout.mode == RolloutMode.ALLOWLIST:
            state.rollout_allowlist = rollout.allowlist or []

    from_strategy = state.previous_strategy.value
    to_strategy = state.active_strategy.value

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # time and drop tzinfo so the stored string keeps the same naive ISO
    # format that existing history records use.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    state.switch_history.append(
        {
            "timestamp": timestamp,
            "from_strategy": from_strategy,
            "to_strategy": to_strategy,
            "react_mode": state.react_mode.value,
            "rollout_mode": state.rollout_mode.value,
            "reason": request.reason,
            "operator": operator,
        }
    )

    current_status = self.get_current_status()

    logger.info(
        f"[AC-AISVC-RES-02] Strategy switched: {from_strategy} -> "
        f"{to_strategy}, react_mode={state.react_mode.value}"
    )

    if self._audit_callback:
        previous_react_mode = state.previous_react_mode
        self._audit_callback(
            operation="switch",
            previous_strategy=from_strategy,
            new_strategy=to_strategy,
            previous_react_mode=previous_react_mode.value if previous_react_mode else None,
            new_react_mode=state.react_mode.value,
            reason=request.reason,
            operator=operator,
            tenant_id=tenant_id,
        )

    if self._metrics_callback:
        self._metrics_callback(
            "strategy_switch",
            {
                "from_strategy": from_strategy,
                "to_strategy": to_strategy,
            },
        )

    return RetrievalStrategySwitchResponse(
        previous=previous_status,
        current=current_status,
    )
def _run_validation_check(
    self,
    check: str,
    strategy: StrategyType,
    react_mode: ReactMode | None,
) -> ValidationResult:
    """
    Dispatch one named validation check to its handler.

    Args:
        check: Check name; the recognised names are "metadata_consistency",
            "embedding_prefix", "rrf_config" and "performance_budget".
        strategy: Strategy to validate.
        react_mode: ReAct mode to validate; only the performance budget
            check receives it.

    Returns:
        The handler's ValidationResult, or a failed ValidationResult whose
        message names the unrecognised check.
    """
    # Thunks keep the table lazy: only the selected handler is looked up on
    # ``self`` and invoked.
    handlers = {
        "metadata_consistency": lambda: self._check_metadata_consistency(strategy),
        "embedding_prefix": lambda: self._check_embedding_prefix(strategy),
        "rrf_config": lambda: self._check_rrf_config(strategy),
        "performance_budget": lambda: self._check_performance_budget(strategy, react_mode),
    }

    handler = handlers.get(check)
    if handler is not None:
        return handler()

    return ValidationResult(
        check=check,
        passed=False,
        message=f"Unknown check type: {check}",
    )
def _check_embedding_prefix(self, strategy: StrategyType) -> ValidationResult:
    """
    Report on the embedding prefix configuration for a strategy.

    No configuration is inspected here: the check writes a debug log line
    and reports success. A failed result is returned only when building
    that log line or the success result raises, in which case the
    exception text becomes the message.

    Args:
        strategy: Strategy being validated; only its ``value`` is logged.

    Returns:
        ValidationResult for the "embedding_prefix" check.
    """
    check_name = "embedding_prefix"
    try:
        passed = True
        logger.debug(f"[AC-AISVC-RES-04] Embedding prefix check: strategy={strategy.value}, passed={passed}")
        outcome = ValidationResult(
            check=check_name,
            passed=passed,
            message="Embedding prefix configuration valid",
        )
    except Exception as exc:
        outcome = ValidationResult(check=check_name, passed=False, message=str(exc))
    return outcome
def rollback_strategy(
    self,
    operator: str | None = None,
    tenant_id: str | None = None,
) -> RetrievalStrategyRollbackResponse:
    """
    [AC-AISVC-RES-07] Rollback to previous or default strategy.

    Restores the strategy and ReAct mode saved by the last switch, falling
    back to ``StrategyType.DEFAULT`` / ``ReactMode.NON_REACT`` when nothing
    was saved. Rollout is turned off and its percentage and allowlist are
    cleared. The operation is appended to the switch history and reported
    to the audit and metrics callbacks when they are set.

    Args:
        operator: Operator who initiated the rollback.
        tenant_id: Tenant ID for audit.

    Returns:
        RetrievalStrategyRollbackResponse where ``current`` is the status
        captured before the rollback and ``rollback_to`` is the status
        after it.
    """
    # Function-scope import: the module-level import only provides ``datetime``.
    from datetime import timezone

    state = self._state
    current_status = self.get_current_status()

    rollback_to_strategy = state.previous_strategy or StrategyType.DEFAULT
    rollback_to_react_mode = state.previous_react_mode or ReactMode.NON_REACT

    # Keep the outgoing values for the history record, log and audit.
    old_strategy = state.active_strategy
    old_react_mode = state.react_mode

    state.active_strategy = rollback_to_strategy
    state.react_mode = rollback_to_react_mode
    # A rollback always disables any grayscale rollout in progress.
    state.rollout_mode = RolloutMode.OFF
    state.rollout_percentage = 0.0
    state.rollout_allowlist = []

    rollback_status = self.get_current_status()

    # datetime.utcnow() is deprecated since Python 3.12. Take an aware UTC
    # time and drop tzinfo so the stored string keeps the same naive ISO
    # format that existing history records use.
    timestamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    state.switch_history.append(
        {
            "timestamp": timestamp,
            "from_strategy": old_strategy.value,
            "to_strategy": rollback_to_strategy.value,
            "operator": operator,
        }
    )

    logger.info(
        f"[AC-AISVC-RES-07] Strategy rolled back: {old_strategy.value} -> "
        f"{rollback_to_strategy.value}, react_mode={rollback_to_react_mode.value}"
    )

    if self._audit_callback:
        self._audit_callback(
            operation="rollback",
            previous_strategy=old_strategy.value,
            new_strategy=rollback_to_strategy.value,
            previous_react_mode=old_react_mode.value,
            new_react_mode=rollback_to_react_mode.value,
            reason="Manual rollback",
            operator=operator,
            tenant_id=tenant_id,
        )

    if self._metrics_callback:
        self._metrics_callback(
            "strategy_rollback",
            {
                "from_strategy": old_strategy.value,
                "to_strategy": rollback_to_strategy.value,
            },
        )

    return RetrievalStrategyRollbackResponse(
        current=current_status,
        rollback_to=rollback_status,
    )
def get_switch_history(self, limit: int = 10) -> list[dict[str, Any]]:
    """
    Get recent switch history.

    Args:
        limit: Maximum number of records to return. Zero or a negative
            value yields an empty list.

    Returns:
        A new list holding up to ``limit`` of the most recent switch and
        rollback records, oldest first.
    """
    # Guard first: ``history[-0:]`` is the same slice as ``history[0:]`` and
    # would return every record, and a negative limit would instead drop the
    # oldest records.
    if limit <= 0:
        return []
    return self._state.switch_history[-limit:]