[AC-RETRIEVAL-STRATEGY] feat(retrieval): 新增检索策略路由服务

- 新增 strategy_service 实现检索策略路由核心逻辑
- 新增 strategy_metrics 提供策略性能指标收集
- 新增 strategy_audit 提供策略审计日志功能
- 新增 retrieval_strategy API 端点支持策略管理
- 支持多种检索策略的动态切换和监控
This commit is contained in:
MerCry 2026-03-11 19:02:40 +08:00
parent 6fec2a755a
commit b3343f9e52
4 changed files with 1585 additions and 0 deletions

View File

@ -0,0 +1,349 @@
"""
Retrieval Strategy API Endpoints.
[AC-AISVC-RES-01~15] 策略管理 API
Endpoints:
- GET /strategy/retrieval/current - 获取当前策略状态
- POST /strategy/retrieval/switch - 切换策略
- POST /strategy/retrieval/validate - 验证策略配置
- POST /strategy/retrieval/rollback - 回退策略
"""
import logging
from typing import Any
from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel, Field
from app.services.retrieval.strategy.config import (
GrayscaleConfig,
ModeRouterConfig,
PipelineConfig,
RetrievalStrategyConfig,
RerankerConfig,
RuntimeMode,
StrategyType,
)
from app.services.retrieval.strategy.rollback_manager import (
RollbackManager,
RollbackResult,
RollbackTrigger,
get_rollback_manager,
)
from app.services.retrieval.strategy.strategy_router import (
StrategyRouter,
get_strategy_router,
set_strategy_router,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/strategy/retrieval", tags=["strategy"])
class GrayscaleConfigSchema(BaseModel):
    """Grayscale (gradual rollout) configuration schema."""

    # Whether grayscale rollout is switched on.
    enabled: bool = Field(default=False)
    # Rollout percentage; pydantic enforces 0.0 <= percentage <= 100.0.
    percentage: float = Field(default=0.0, ge=0.0, le=100.0)
    # String entries of the rollout allowlist (a fresh list per instance).
    allowlist: list[str] = Field(default_factory=list)
class RerankerConfigSchema(BaseModel):
    """Reranker configuration schema."""

    # Whether the reranker is enabled.
    enabled: bool = Field(default=False)
    # Name of the reranker model.
    model: str = Field(default="cross-encoder")
    # Number of results kept after reranking.
    top_k_after_rerank: int = Field(default=5)
    # Minimum score threshold applied by the reranker.
    min_score_threshold: float = Field(default=0.3)
class ModeRouterConfigSchema(BaseModel):
    """Mode router configuration schema."""

    # Runtime mode name; the switch endpoint converts it with RuntimeMode(...).
    runtime_mode: str = Field(default="direct")
    # Confidence threshold used when deciding whether to trigger ReAct.
    react_trigger_confidence_threshold: float = Field(default=0.6)
    # Complexity score used when deciding whether to trigger ReAct.
    react_trigger_complexity_score: float = Field(default=0.5)
    # Maximum number of ReAct steps.
    react_max_steps: int = Field(default=5)
    # Whether to fall back to direct mode on low confidence.
    direct_fallback_on_low_confidence: bool = Field(default=True)
class PipelineConfigSchema(BaseModel):
    """Retrieval pipeline configuration schema."""

    # Number of results requested from the pipeline.
    top_k: int = Field(default=5)
    # Score threshold applied to hits.
    score_threshold: float = Field(default=0.01)
    # Minimum number of hits.
    min_hits: int = Field(default=1)
    # Whether two-stage retrieval is enabled.
    two_stage_enabled: bool = Field(default=True)
class StrategyStatusResponse(BaseModel):
    """Response body describing the current strategy status."""

    # Value of the active strategy enum member.
    active_strategy: str
    # Snapshots of the individual configuration sections.
    grayscale: GrayscaleConfigSchema
    pipeline: PipelineConfigSchema
    reranker: RerankerConfigSchema
    mode_router: ModeRouterConfigSchema
    # Threshold name -> numeric value; defaults to an empty dict.
    performance_thresholds: dict[str, float] = Field(default_factory=dict)
class SwitchRequest(BaseModel):
    """Request body for switching the strategy.

    Every field is optional; the switch endpoint keeps the current value for
    any section that is omitted.
    """

    # Target strategy name; None keeps the active strategy.
    active_strategy: str | None = Field(default=None)
    # Replacement grayscale section, if any.
    grayscale: GrayscaleConfigSchema | None = Field(default=None)
    # Replacement mode-router section, if any.
    mode_router: ModeRouterConfigSchema | None = Field(default=None)
    # Replacement reranker section, if any.
    reranker: RerankerConfigSchema | None = Field(default=None)
class SwitchResponse(BaseModel):
    """Response body returned after a strategy switch."""

    # True when the switch was applied.
    success: bool
    # Strategy value before the switch.
    previous_strategy: str
    # Strategy value after the switch.
    current_strategy: str
    # Human-readable summary of the switch.
    message: str
class ValidationRequest(BaseModel):
    """Request body for validating a strategy configuration."""

    # Free-form configuration mapping to validate; defaults to an empty dict.
    config: dict[str, Any] = Field(default_factory=lambda: {})
class ValidationResponse(BaseModel):
    """Response body returned by strategy validation."""

    # True when no errors were collected.
    valid: bool
    # Problems that make the configuration invalid.
    errors: list[str] = Field(default_factory=lambda: [])
    # Non-fatal findings.
    warnings: list[str] = Field(default_factory=lambda: [])
    # Condensed view of the validated configuration.
    config_summary: dict[str, Any] = Field(default_factory=lambda: {})
class RollbackResponse(BaseModel):
    """Response body returned by a strategy rollback."""

    # True when a rollback was performed.
    success: bool
    # Strategy value before the rollback.
    previous_strategy: str
    # Strategy value after the rollback.
    current_strategy: str
    # Trigger name supplied by the caller.
    trigger: str
    # Reason recorded for the rollback.
    reason: str
    # Identifier of the audit record, when one was produced.
    audit_log_id: str | None = Field(default=None)
@router.get(
    "/current",
    response_model=StrategyStatusResponse,
    summary="获取当前策略状态",
    description="【AC-AISVC-RES-01】 获取当前活跃的检索策略配置状态。",
)
async def get_current_strategy() -> StrategyStatusResponse:
    """Return the currently active retrieval strategy status (AC-AISVC-RES-01).

    Reads the configuration held by the strategy router and copies each
    section field by field into its API schema.

    Returns:
        StrategyStatusResponse mirroring the router's current configuration.
    """
    cfg = get_strategy_router().get_config()
    active_name = cfg.active_strategy.value

    gray = cfg.grayscale
    grayscale_view = GrayscaleConfigSchema(
        enabled=gray.enabled,
        percentage=gray.percentage,
        allowlist=gray.allowlist,
    )

    pipe = cfg.pipeline
    pipeline_view = PipelineConfigSchema(
        top_k=pipe.top_k,
        score_threshold=pipe.score_threshold,
        min_hits=pipe.min_hits,
        two_stage_enabled=pipe.two_stage_enabled,
    )

    rerank = cfg.reranker
    reranker_view = RerankerConfigSchema(
        enabled=rerank.enabled,
        model=rerank.model,
        top_k_after_rerank=rerank.top_k_after_rerank,
        min_score_threshold=rerank.min_score_threshold,
    )

    mode = cfg.mode_router
    mode_router_view = ModeRouterConfigSchema(
        # The schema carries the enum's string value, not the enum member.
        runtime_mode=mode.runtime_mode.value,
        react_trigger_confidence_threshold=mode.react_trigger_confidence_threshold,
        react_trigger_complexity_score=mode.react_trigger_complexity_score,
        react_max_steps=mode.react_max_steps,
        direct_fallback_on_low_confidence=mode.direct_fallback_on_low_confidence,
    )

    return StrategyStatusResponse(
        active_strategy=active_name,
        grayscale=grayscale_view,
        pipeline=pipeline_view,
        reranker=reranker_view,
        mode_router=mode_router_view,
        performance_thresholds=cfg.performance_thresholds,
    )
@router.post(
    "/switch",
    response_model=SwitchResponse,
    summary="切换策略",
    description="【AC-AISVC-RES-02, AC-AISVC-RES-03】 切换检索策略, 支持灰度发布配置。",
)
async def switch_strategy(request: SwitchRequest) -> SwitchResponse:
    """Switch the retrieval strategy (AC-AISVC-RES-02, AC-AISVC-RES-03).

    Builds a new configuration from the current one, replacing only the
    sections present in the request (active strategy, grayscale, mode router,
    reranker). Pipeline, metadata inference and performance thresholds are
    always carried over unchanged.

    Args:
        request: Sections to replace; omitted sections keep their current value.

    Returns:
        SwitchResponse with the previous and the new strategy values.

    Raises:
        HTTPException: 400 when ``active_strategy`` or
            ``mode_router.runtime_mode`` is not a valid enum value.
    """
    strategy_router = get_strategy_router()
    current_config = strategy_router.get_config()
    previous_strategy = current_config.active_strategy.value

    if request.active_strategy:
        try:
            new_active_strategy = StrategyType(request.active_strategy)
        except ValueError as exc:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid strategy type: {request.active_strategy}. Valid values are: default, enhanced",
            ) from exc
    else:
        new_active_strategy = current_config.active_strategy

    # A fresh GrayscaleConfig is always built, from the request when given,
    # otherwise from the current configuration.
    grayscale_source = (
        request.grayscale if request.grayscale is not None else current_config.grayscale
    )
    new_grayscale = GrayscaleConfig(
        enabled=grayscale_source.enabled,
        percentage=grayscale_source.percentage,
        allowlist=grayscale_source.allowlist,
    )

    requested_mode = request.mode_router
    if requested_mode is not None:
        if requested_mode.runtime_mode:
            try:
                runtime_mode = RuntimeMode(requested_mode.runtime_mode)
            except ValueError as exc:
                # Previously this escaped as an unhandled ValueError (HTTP 500).
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Invalid runtime mode: {requested_mode.runtime_mode}",
                ) from exc
        else:
            runtime_mode = current_config.mode_router.runtime_mode
        new_mode_router = ModeRouterConfig(
            runtime_mode=runtime_mode,
            react_trigger_confidence_threshold=requested_mode.react_trigger_confidence_threshold,
            react_trigger_complexity_score=requested_mode.react_trigger_complexity_score,
            react_max_steps=requested_mode.react_max_steps,
            direct_fallback_on_low_confidence=requested_mode.direct_fallback_on_low_confidence,
        )
    else:
        new_mode_router = current_config.mode_router

    requested_reranker = request.reranker
    if requested_reranker is not None:
        new_reranker = RerankerConfig(
            enabled=requested_reranker.enabled,
            model=requested_reranker.model,
            top_k_after_rerank=requested_reranker.top_k_after_rerank,
            min_score_threshold=requested_reranker.min_score_threshold,
        )
    else:
        new_reranker = current_config.reranker

    new_config = RetrievalStrategyConfig(
        active_strategy=new_active_strategy,
        grayscale=new_grayscale,
        mode_router=new_mode_router,
        reranker=new_reranker,
        pipeline=current_config.pipeline,
        metadata_inference=current_config.metadata_inference,
        performance_thresholds=current_config.performance_thresholds,
    )
    strategy_router.update_config(new_config)

    logger.info(
        f"[AC-AISVC-RES-02] Strategy switched: {previous_strategy} -> {new_config.active_strategy.value}"
    )
    return SwitchResponse(
        success=True,
        previous_strategy=previous_strategy,
        current_strategy=new_config.active_strategy.value,
        message=f"Strategy switched from {previous_strategy} to {new_config.active_strategy.value}",
    )
def _is_real_number(value: Any) -> bool:
    """Return True for int/float values, excluding bool."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _as_section(value: Any) -> dict[str, Any]:
    """Return ``value`` when it is a dict, otherwise an empty dict."""
    return value if isinstance(value, dict) else {}


@router.post(
    "/validate",
    response_model=ValidationResponse,
    summary="验证策略配置",
    description="【AC-AISVC-RES-06, AC-AISVC-RES-08】 验证策略配置的完整性与一致性。",
)
async def validate_strategy(request: ValidationRequest) -> ValidationResponse:
    """Validate a strategy configuration (AC-AISVC-RES-06, AC-AISVC-RES-08).

    Only the sections present in ``request.config`` are checked. Malformed
    input (a section that is not an object, a missing or non-numeric value) is
    reported in ``errors``/``warnings`` instead of raising, so the endpoint
    never fails with KeyError/TypeError on partial configurations.

    Args:
        request: Wrapper around the free-form configuration mapping.

    Returns:
        ValidationResponse with ``valid`` set when no errors were found, the
        collected errors and warnings, and a summary of the key settings.
    """
    errors: list[str] = []
    warnings: list[str] = []
    config = request.config

    if "active_strategy" in config:
        if config["active_strategy"] not in ["default", "enhanced"]:
            errors.append(f"Invalid active_strategy: {config['active_strategy']}")

    if "grayscale" in config:
        grayscale = config["grayscale"]
        if not isinstance(grayscale, dict):
            errors.append("Invalid grayscale: expected an object")
        else:
            # A missing or null percentage is simply not checked.
            percentage = grayscale.get("percentage")
            if percentage is not None and not (
                _is_real_number(percentage) and 0 <= percentage <= 100
            ):
                errors.append(f"Invalid grayscale percentage: {percentage}")

    if "mode_router" in config:
        mode_router = config["mode_router"]
        if not isinstance(mode_router, dict):
            errors.append("Invalid mode_router: expected an object")
        elif "runtime_mode" in mode_router:
            if mode_router["runtime_mode"] not in ["direct", "react", "auto"]:
                errors.append(f"Invalid runtime_mode: {mode_router['runtime_mode']}")

    if "reranker" in config:
        reranker = config["reranker"]
        if not isinstance(reranker, dict):
            errors.append("Invalid reranker: expected an object")
        elif reranker.get("enabled"):
            top_k = reranker.get("top_k_after_rerank")
            # Check both bounds so the check matches the warning text.
            if _is_real_number(top_k) and not (1 <= top_k <= 20):
                warnings.append(
                    f"top_k_after_rerank should be between 1 and 20, current: {top_k}"
                )

    if "performance_thresholds" in config:
        thresholds = config["performance_thresholds"]
        if not isinstance(thresholds, dict):
            errors.append("Invalid performance_thresholds: expected an object")
        else:
            max_latency = thresholds.get("max_latency_ms")
            if _is_real_number(max_latency) and max_latency < 100:
                warnings.append(f"max_latency_ms seems too low: {max_latency}ms")

    return ValidationResponse(
        valid=not errors,
        errors=errors,
        warnings=warnings,
        config_summary={
            "active_strategy": config.get("active_strategy", "default"),
            "grayscale_enabled": _as_section(config.get("grayscale")).get("enabled", False),
            "runtime_mode": _as_section(config.get("mode_router")).get("runtime_mode", "direct"),
            "reranker_enabled": _as_section(config.get("reranker")).get("enabled", False),
            "performance_thresholds": config.get("performance_thresholds", {}),
        },
    )
@router.post(
    "/rollback",
    response_model=RollbackResponse,
    summary="回退策略",
    description="【AC-AISVC-RES-07】 回退到默认策略。",
)
async def rollback_strategy(
    trigger: str = "manual",
    reason: str = "",
) -> RollbackResponse:
    """Roll the active strategy back to the default one (AC-AISVC-RES-07).

    When the default strategy is already active nothing is changed and a
    response with ``success=False`` is returned. Otherwise the configuration
    is rebuilt with only ``active_strategy`` replaced, pushed to both the
    strategy router and the rollback manager, and an audit record is written.

    Args:
        trigger: Caller-supplied trigger name, echoed in the audit details and
            in the response.
        reason: Optional reason; an empty value is recorded as
            "Manual rollback".

    Returns:
        RollbackResponse describing the outcome.
    """
    router_instance = get_strategy_router()
    active_config = router_instance.get_config()
    outgoing = active_config.active_strategy

    # Nothing to roll back when the default strategy is already active.
    if outgoing == StrategyType.DEFAULT:
        return RollbackResponse(
            success=False,
            previous_strategy=outgoing.value,
            current_strategy=outgoing.value,
            trigger=trigger,
            reason="Already on default strategy",
            audit_log_id=None,
        )

    effective_reason = reason or "Manual rollback"
    outgoing_name = outgoing.value
    default_name = StrategyType.DEFAULT.value

    reverted_config = RetrievalStrategyConfig(
        active_strategy=StrategyType.DEFAULT,
        grayscale=active_config.grayscale,
        mode_router=active_config.mode_router,
        reranker=active_config.reranker,
        pipeline=active_config.pipeline,
        metadata_inference=active_config.metadata_inference,
        performance_thresholds=active_config.performance_thresholds,
    )
    router_instance.update_config(reverted_config)

    manager = get_rollback_manager()
    manager.update_config(reverted_config)
    audit_record = manager.record_audit(
        action="rollback",
        details={
            "from_strategy": outgoing_name,
            "to_strategy": default_name,
            "trigger": trigger,
            "reason": effective_reason,
        },
    )

    logger.info(
        f"[AC-AISVC-RES-07] Strategy rolled back: {outgoing_name} -> default"
    )

    # The audit record's timestamp doubles as its identifier.
    if audit_record:
        audit_id = str(audit_record.timestamp)
    else:
        audit_id = None

    return RollbackResponse(
        success=True,
        previous_strategy=outgoing_name,
        current_strategy=default_name,
        trigger=trigger,
        reason=effective_reason,
        audit_log_id=audit_id,
    )

View File

@ -0,0 +1,300 @@
"""
Strategy Audit Service for AI Service.
[AC-AISVC-RES-07] Audit logging for strategy operations.
"""
import json
import logging
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from app.schemas.retrieval_strategy import StrategyAuditLog
logger = logging.getLogger(__name__)
@dataclass
class AuditEntry:
"""
Internal audit entry structure.
"""
timestamp: str
operation: str
previous_strategy: str | None = None
new_strategy: str | None = None
previous_react_mode: str | None = None
new_react_mode: str | None = None
reason: str | None = None
operator: str | None = None
tenant_id: str | None = None
metadata: dict[str, Any] | None = None
class StrategyAuditService:
    """Audit service for strategy operations (AC-AISVC-RES-07).

    Keeps a bounded in-memory trail of ``AuditEntry`` records (oldest entries
    are dropped once ``max_entries`` is reached) and emits every entry both as
    a readable line on the module logger and as JSON on the
    ``audit.strategy`` logger.
    """

    def __init__(self, max_entries: int = 1000):
        """Create the service.

        Args:
            max_entries: Maximum number of entries retained in memory.
        """
        self._audit_log: deque[AuditEntry] = deque(maxlen=max_entries)
        self._max_entries = max_entries

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Same format as the deprecated ``datetime.utcnow().isoformat()``.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    def log(
        self,
        operation: str,
        previous_strategy: str | None = None,
        new_strategy: str | None = None,
        previous_react_mode: str | None = None,
        new_react_mode: str | None = None,
        reason: str | None = None,
        operator: str | None = None,
        tenant_id: str | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> None:
        """Record a strategy operation (AC-AISVC-RES-07).

        Args:
            operation: Operation type (switch, rollback, validate).
            previous_strategy: Previous strategy value.
            new_strategy: New strategy value.
            previous_react_mode: Previous react mode.
            new_react_mode: New react mode.
            reason: Reason for the operation.
            operator: Operator who performed the operation.
            tenant_id: Tenant ID if applicable.
            metadata: Additional metadata; must be JSON-serializable because
                it is passed to ``json.dumps``.
        """
        entry = AuditEntry(
            timestamp=self._utc_timestamp(),
            operation=operation,
            previous_strategy=previous_strategy,
            new_strategy=new_strategy,
            previous_react_mode=previous_react_mode,
            new_react_mode=new_react_mode,
            reason=reason,
            operator=operator,
            tenant_id=tenant_id,
            metadata=metadata,
        )
        self._audit_log.append(entry)

        log_data = {
            "audit_type": "strategy_operation",
            "timestamp": entry.timestamp,
            "operation": entry.operation,
            "previous_strategy": entry.previous_strategy,
            "new_strategy": entry.new_strategy,
            "previous_react_mode": entry.previous_react_mode,
            "new_react_mode": entry.new_react_mode,
            "reason": entry.reason,
            "operator": entry.operator,
            "tenant_id": entry.tenant_id,
            "metadata": entry.metadata,
        }
        logger.info(
            f"[AC-AISVC-RES-07] Strategy audit: operation={operation}, "
            f"from={previous_strategy} -> to={new_strategy}, "
            f"operator={operator}, reason={reason}"
        )
        # Structured copy on a dedicated logger for log aggregation.
        audit_logger = logging.getLogger("audit.strategy")
        audit_logger.info(json.dumps(log_data, ensure_ascii=False))

    def log_switch(
        self,
        previous_strategy: str,
        new_strategy: str,
        previous_react_mode: str | None = None,
        new_react_mode: str | None = None,
        reason: str | None = None,
        operator: str | None = None,
        tenant_id: str | None = None,
        rollout_config: dict[str, Any] | None = None,
    ) -> None:
        """Record a strategy switch.

        Args:
            previous_strategy: Previous strategy value.
            new_strategy: New strategy value.
            previous_react_mode: Previous react mode.
            new_react_mode: New react mode.
            reason: Reason for the switch.
            operator: Operator who performed the switch.
            tenant_id: Tenant ID if applicable.
            rollout_config: Rollout configuration; stored under the
                ``rollout_config`` metadata key when non-empty.
        """
        self.log(
            operation="switch",
            previous_strategy=previous_strategy,
            new_strategy=new_strategy,
            previous_react_mode=previous_react_mode,
            new_react_mode=new_react_mode,
            reason=reason,
            operator=operator,
            tenant_id=tenant_id,
            metadata={"rollout_config": rollout_config} if rollout_config else None,
        )

    def log_rollback(
        self,
        previous_strategy: str,
        new_strategy: str,
        previous_react_mode: str | None = None,
        new_react_mode: str | None = None,
        reason: str | None = None,
        operator: str | None = None,
        tenant_id: str | None = None,
    ) -> None:
        """Record a strategy rollback.

        Args:
            previous_strategy: Previous strategy value.
            new_strategy: Strategy rolled back to.
            previous_react_mode: Previous react mode.
            new_react_mode: React mode rolled back to.
            reason: Reason for the rollback; defaults to "Manual rollback".
            operator: Operator who performed the rollback.
            tenant_id: Tenant ID if applicable.
        """
        self.log(
            operation="rollback",
            previous_strategy=previous_strategy,
            new_strategy=new_strategy,
            previous_react_mode=previous_react_mode,
            new_react_mode=new_react_mode,
            reason=reason or "Manual rollback",
            operator=operator,
            tenant_id=tenant_id,
        )

    def log_validation(
        self,
        strategy: str,
        react_mode: str | None = None,
        checks: list[str] | None = None,
        passed: bool = False,
        operator: str | None = None,
        tenant_id: str | None = None,
    ) -> None:
        """Record a strategy validation.

        Args:
            strategy: Strategy being validated.
            react_mode: React mode being validated.
            checks: List of checks performed.
            passed: Whether validation passed.
            operator: Operator who performed the validation.
            tenant_id: Tenant ID if applicable.
        """
        self.log(
            operation="validate",
            new_strategy=strategy,
            new_react_mode=react_mode,
            operator=operator,
            tenant_id=tenant_id,
            metadata={
                "checks": checks,
                "passed": passed,
            },
        )

    def get_audit_log(
        self,
        limit: int = 100,
        operation: str | None = None,
        tenant_id: str | None = None,
    ) -> list[StrategyAuditLog]:
        """Return the most recent audit entries, oldest first.

        Args:
            limit: Maximum number of entries to return; values <= 0 yield an
                empty list.
            operation: Keep only entries with this operation type.
            tenant_id: Keep only entries for this tenant.

        Returns:
            Up to ``limit`` matching entries converted to StrategyAuditLog.
        """
        # ``entries[-0:]`` would return everything and a negative limit would
        # slice from the front, so handle non-positive limits explicitly.
        if limit <= 0:
            return []

        entries = list(self._audit_log)
        if operation:
            entries = [e for e in entries if e.operation == operation]
        if tenant_id:
            entries = [e for e in entries if e.tenant_id == tenant_id]

        return [
            StrategyAuditLog(
                timestamp=e.timestamp,
                operation=e.operation,
                previous_strategy=e.previous_strategy,
                new_strategy=e.new_strategy,
                previous_react_mode=e.previous_react_mode,
                new_react_mode=e.new_react_mode,
                reason=e.reason,
                operator=e.operator,
                tenant_id=e.tenant_id,
                metadata=e.metadata,
            )
            for e in entries[-limit:]
        ]

    def get_audit_stats(self) -> dict[str, Any]:
        """Return summary statistics about the in-memory audit trail.

        Returns:
            Dict with the entry count, the configured capacity, a per-operation
            count, and the timestamps of the oldest and newest entries (None
            when the trail is empty).
        """
        entries = list(self._audit_log)
        operation_counts: dict[str, int] = {}
        for entry in entries:
            operation_counts[entry.operation] = operation_counts.get(entry.operation, 0) + 1
        return {
            "total_entries": len(entries),
            "max_entries": self._max_entries,
            "operation_counts": operation_counts,
            "oldest_entry": entries[0].timestamp if entries else None,
            "newest_entry": entries[-1].timestamp if entries else None,
        }

    def clear_audit_log(self) -> int:
        """Remove every entry from the in-memory audit trail.

        Returns:
            Number of entries removed.
        """
        count = len(self._audit_log)
        self._audit_log.clear()
        logger.info(f"[AC-AISVC-RES-07] Audit log cleared: {count} entries removed")
        return count
_audit_service: StrategyAuditService | None = None


def get_audit_service() -> StrategyAuditService:
    """Return the module-wide StrategyAuditService, creating it on first use."""
    global _audit_service
    if _audit_service is not None:
        return _audit_service
    _audit_service = StrategyAuditService()
    return _audit_service

View File

@ -0,0 +1,452 @@
"""
Strategy Metrics Service for AI Service.
[AC-AISVC-RES-03, AC-AISVC-RES-08] Metrics collection for strategy operations.
"""
import json
import logging
import time
from collections import defaultdict
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from app.schemas.retrieval_strategy import (
ReactMode,
StrategyMetrics,
StrategyType,
)
logger = logging.getLogger(__name__)
@dataclass
class LatencyTracker:
    """Bounded collection of latency samples for a single operation.

    When the sample list reaches ``max_samples`` the older half is discarded
    before the new sample is appended, so the list never grows past
    ``max(max_samples, 1)`` entries.
    """

    # Recorded latency samples, in insertion order.
    latencies: list[float] = field(default_factory=list)
    # Upper bound on the number of retained samples.
    max_samples: int = 1000

    def record(self, latency_ms: float) -> None:
        """Record a latency sample, trimming old samples when full."""
        if len(self.latencies) >= self.max_samples:
            # Keep the newest half (rounded up, as before) but always leave
            # room for the new sample. The start index is computed explicitly
            # because ``lst[-0:]`` would keep the entire list.
            keep = max(0, min((self.max_samples + 1) // 2, self.max_samples - 1))
            self.latencies = self.latencies[len(self.latencies) - keep:]
        self.latencies.append(latency_ms)

    def get_percentile(self, percentile: float) -> float:
        """Return the latency at ``percentile`` (clamped to 0-100).

        Returns 0.0 when no samples have been recorded.
        """
        if not self.latencies:
            return 0.0
        # Clamp so a negative percentile cannot become a negative index that
        # silently wraps around to the largest samples.
        bounded = min(max(percentile, 0.0), 100.0)
        sorted_latencies = sorted(self.latencies)
        index = int(len(sorted_latencies) * bounded / 100)
        index = min(index, len(sorted_latencies) - 1)
        return sorted_latencies[index]

    def get_avg(self) -> float:
        """Return the mean of the retained samples, or 0.0 when empty."""
        if not self.latencies:
            return 0.0
        return sum(self.latencies) / len(self.latencies)
@dataclass
class StrategyMetricsData:
    """Mutable counters and latency samples kept per strategy or route mode."""

    # Request counters.
    total_requests: int = field(default=0)
    successful_requests: int = field(default=0)
    failed_requests: int = field(default=0)
    # Latency samples; each instance gets its own tracker.
    latency_tracker: LatencyTracker = field(default_factory=LatencyTracker)
    # Per-route-mode request counters.
    direct_route_count: int = field(default=0)
    react_route_count: int = field(default=0)
    auto_route_count: int = field(default=0)
    # Number of requests flagged as fallback.
    fallback_count: int = field(default=0)
    # Timestamp string of the most recent update, if any.
    last_updated: str | None = field(default=None)
class StrategyMetricsService:
    """Metrics service for strategy operations (AC-AISVC-RES-03, AC-AISVC-RES-08).

    Tracks request counts, latency samples and fallback counts per strategy,
    a parallel set of counters per route mode, and emits JSON events for
    strategy switches and grayscale requests.
    """

    def __init__(self):
        # Both maps create an empty StrategyMetricsData on first access.
        self._metrics: dict[str, StrategyMetricsData] = defaultdict(StrategyMetricsData)
        self._route_metrics: dict[str, StrategyMetricsData] = defaultdict(StrategyMetricsData)
        self._current_strategy: StrategyType = StrategyType.DEFAULT
        self._current_react_mode: ReactMode = ReactMode.NON_REACT

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a naive ISO-8601 string.

        Same format as the deprecated ``datetime.utcnow().isoformat()``.
        """
        from datetime import timezone

        return datetime.now(timezone.utc).replace(tzinfo=None).isoformat()

    def set_current_strategy(
        self,
        strategy: StrategyType,
        react_mode: ReactMode,
    ) -> None:
        """Set the strategy and react mode used for default attribution.

        Args:
            strategy: Current active strategy.
            react_mode: Current react mode.
        """
        self._current_strategy = strategy
        self._current_react_mode = react_mode

    def record_request(
        self,
        latency_ms: float,
        success: bool = True,
        route_mode: str | None = None,
        fallback: bool = False,
        strategy: StrategyType | None = None,
    ) -> None:
        """Record a retrieval request (AC-AISVC-RES-03, AC-AISVC-RES-08).

        Args:
            latency_ms: Request latency in milliseconds.
            success: Whether the request was successful.
            route_mode: Route mode used (direct, react, auto).
            fallback: Whether fallback to default occurred.
            strategy: Strategy used (defaults to the current strategy).
        """
        effective_strategy = strategy or self._current_strategy
        key = effective_strategy.value
        metrics = self._metrics[key]

        metrics.total_requests += 1
        if success:
            metrics.successful_requests += 1
        else:
            metrics.failed_requests += 1
        metrics.latency_tracker.record(latency_ms)
        metrics.last_updated = self._utc_timestamp()
        if fallback:
            metrics.fallback_count += 1

        if route_mode:
            # Pass the effective strategy so route counters land on the same
            # strategy as the request counters above.
            self._record_route_metric(
                route_mode, latency_ms, success, strategy=effective_strategy
            )

        logger.debug(
            f"[AC-AISVC-RES-08] Request recorded: strategy={key}, "
            f"latency={latency_ms:.2f}ms, success={success}, route={route_mode}"
        )

    def _record_route_metric(
        self,
        route_mode: str,
        latency_ms: float,
        success: bool,
        strategy: StrategyType | None = None,
    ) -> None:
        """Record metrics for a route mode.

        Args:
            route_mode: Route mode (direct, react, auto).
            latency_ms: Request latency.
            success: Whether successful.
            strategy: Strategy whose per-route counter is incremented;
                defaults to the current strategy.
        """
        metrics = self._route_metrics[route_mode]
        metrics.total_requests += 1
        if success:
            metrics.successful_requests += 1
        else:
            metrics.failed_requests += 1
        metrics.latency_tracker.record(latency_ms)
        metrics.last_updated = self._utc_timestamp()

        owner = strategy if strategy is not None else self._current_strategy
        if route_mode == "direct":
            self._metrics[owner.value].direct_route_count += 1
        elif route_mode == "react":
            self._metrics[owner.value].react_route_count += 1
        elif route_mode == "auto":
            self._metrics[owner.value].auto_route_count += 1

    def record_strategy_switch(
        self,
        from_strategy: str,
        to_strategy: str,
    ) -> None:
        """Emit a strategy switch event.

        Args:
            from_strategy: Previous strategy.
            to_strategy: New strategy.
        """
        metrics_logger = logging.getLogger("metrics.strategy")
        metrics_logger.info(
            json.dumps(
                {
                    "event": "strategy_switch",
                    "from_strategy": from_strategy,
                    "to_strategy": to_strategy,
                    "timestamp": self._utc_timestamp(),
                },
                ensure_ascii=False,
            )
        )
        logger.info(
            f"[AC-AISVC-RES-03] Strategy switch recorded: {from_strategy} -> {to_strategy}"
        )

    def record_grayscale_request(
        self,
        tenant_id: str,
        strategy_used: str,
        in_grayscale: bool,
    ) -> None:
        """Emit a grayscale request event (AC-AISVC-RES-03).

        Args:
            tenant_id: Tenant ID.
            strategy_used: Strategy used for the request.
            in_grayscale: Whether the request was in the grayscale group.
        """
        metrics_logger = logging.getLogger("metrics.grayscale")
        metrics_logger.info(
            json.dumps(
                {
                    "event": "grayscale_request",
                    "tenant_id": tenant_id,
                    "strategy_used": strategy_used,
                    "in_grayscale": in_grayscale,
                    "timestamp": self._utc_timestamp(),
                },
                ensure_ascii=False,
            )
        )

    def get_metrics(self, strategy: StrategyType | None = None) -> StrategyMetrics:
        """Return metrics for a strategy.

        Args:
            strategy: Strategy to report on (defaults to the current one).

        Returns:
            StrategyMetrics for the strategy. ``react_mode`` is always the
            currently configured react mode, whichever strategy is requested.
        """
        effective_strategy = strategy or self._current_strategy
        data = self._metrics[effective_strategy.value]
        return StrategyMetrics(
            strategy=effective_strategy,
            react_mode=self._current_react_mode,
            total_requests=data.total_requests,
            successful_requests=data.successful_requests,
            failed_requests=data.failed_requests,
            avg_latency_ms=round(data.latency_tracker.get_avg(), 2),
            p99_latency_ms=round(data.latency_tracker.get_percentile(99), 2),
            direct_route_count=data.direct_route_count,
            react_route_count=data.react_route_count,
            auto_route_count=data.auto_route_count,
            fallback_count=data.fallback_count,
            last_updated=data.last_updated,
        )

    def get_all_metrics(self) -> dict[str, StrategyMetrics]:
        """Return metrics for every member of StrategyType.

        Returns:
            Dictionary of strategy value to metrics.
        """
        return {strategy.value: self.get_metrics(strategy) for strategy in StrategyType}

    def get_route_metrics(self) -> dict[str, dict[str, Any]]:
        """Return metrics grouped by route mode.

        Returns:
            Dictionary of route mode to a plain-dict metrics summary.
        """
        return {
            route_mode: {
                "total_requests": data.total_requests,
                "successful_requests": data.successful_requests,
                "failed_requests": data.failed_requests,
                "avg_latency_ms": round(data.latency_tracker.get_avg(), 2),
                "p99_latency_ms": round(data.latency_tracker.get_percentile(99), 2),
                "last_updated": data.last_updated,
            }
            for route_mode, data in self._route_metrics.items()
        }

    def get_performance_summary(self) -> dict[str, Any]:
        """Return a performance summary for monitoring (AC-AISVC-RES-08).

        Returns:
            Dictionary with overall totals, success rate, latency figures, the
            current strategy/react mode, and per-strategy and per-route
            breakdowns.
        """
        all_metrics = self.get_all_metrics()
        total_requests = sum(m.total_requests for m in all_metrics.values())
        total_success = sum(m.successful_requests for m in all_metrics.values())
        total_failed = sum(m.failed_requests for m in all_metrics.values())

        # Average over the retained samples of all strategies. A plain mean of
        # the per-strategy averages would over-weight low-traffic strategies.
        latency_sum = 0.0
        latency_count = 0
        for strategy in StrategyType:
            samples = self._metrics[strategy.value].latency_tracker.latencies
            latency_sum += sum(samples)
            latency_count += len(samples)
        overall_avg_latency = latency_sum / latency_count if latency_count else 0.0

        p99_latencies = [
            m.p99_latency_ms for m in all_metrics.values() if m.p99_latency_ms > 0
        ]
        overall_p99_latency = max(p99_latencies) if p99_latencies else 0.0

        return {
            "total_requests": total_requests,
            "successful_requests": total_success,
            "failed_requests": total_failed,
            "success_rate": round(total_success / total_requests, 4) if total_requests > 0 else 0.0,
            "avg_latency_ms": round(overall_avg_latency, 2),
            "p99_latency_ms": round(overall_p99_latency, 2),
            "current_strategy": self._current_strategy.value,
            "current_react_mode": self._current_react_mode.value,
            "strategies": {
                name: {
                    "total_requests": m.total_requests,
                    "success_rate": round(
                        m.successful_requests / m.total_requests, 4
                    )
                    if m.total_requests > 0
                    else 0.0,
                    "avg_latency_ms": m.avg_latency_ms,
                    "p99_latency_ms": m.p99_latency_ms,
                }
                for name, m in all_metrics.items()
            },
            "routes": self.get_route_metrics(),
        }

    def reset_metrics(self, strategy: StrategyType | None = None) -> None:
        """Reset metrics for one strategy or for everything.

        Args:
            strategy: Strategy to reset. None clears all strategy metrics and
                all route metrics; a single-strategy reset leaves route
                metrics untouched.
        """
        if strategy is not None:
            self._metrics[strategy.value] = StrategyMetricsData()
            logger.info(f"[AC-AISVC-RES-08] Metrics reset for strategy: {strategy.value}")
        else:
            self._metrics.clear()
            self._route_metrics.clear()
            logger.info("[AC-AISVC-RES-08] All metrics reset")

    def check_performance_threshold(
        self,
        strategy: StrategyType,
        max_latency_ms: float = 5000.0,
        max_error_rate: float = 0.1,
    ) -> dict[str, Any]:
        """Check whether a strategy is within thresholds (AC-AISVC-RES-08).

        Args:
            strategy: Strategy to check.
            max_latency_ms: Maximum acceptable average latency.
            max_error_rate: Maximum acceptable error rate (0-1).

        Returns:
            Dictionary with the individual check results, the measured values
            and a textual recommendation.
        """
        metrics = self.get_metrics(strategy)
        latency_ok = metrics.avg_latency_ms <= max_latency_ms
        error_rate = (
            metrics.failed_requests / metrics.total_requests
            if metrics.total_requests > 0
            else 0.0
        )
        error_rate_ok = error_rate <= max_error_rate
        overall_ok = latency_ok and error_rate_ok
        return {
            "strategy": strategy.value,
            "latency_ok": latency_ok,
            "avg_latency_ms": metrics.avg_latency_ms,
            "max_latency_ms": max_latency_ms,
            "error_rate_ok": error_rate_ok,
            "error_rate": round(error_rate, 4),
            "max_error_rate": max_error_rate,
            "overall_ok": overall_ok,
            "recommendation": (
                "Performance within acceptable thresholds"
                if overall_ok
                else "Consider rollback or investigation"
            ),
        }
class MetricsContext:
    """Context manager that times an operation and records it as a request.

    On exit the elapsed time is reported to
    ``metrics_service.record_request``. The request counts as failed when the
    body raised an exception or when ``mark_failed()`` was called.
    Exceptions are never suppressed.
    """

    def __init__(
        self,
        metrics_service: StrategyMetricsService,
        route_mode: str | None = None,
        strategy: StrategyType | None = None,
    ):
        """Create the context.

        Args:
            metrics_service: Service that receives the recorded request.
            route_mode: Route mode forwarded to ``record_request``.
            strategy: Strategy forwarded to ``record_request``.
        """
        self._metrics_service = metrics_service
        self._route_mode = route_mode
        self._strategy = strategy
        self._start_time: float | None = None
        self._success = True

    def __enter__(self) -> "MetricsContext":
        # perf_counter is monotonic, so the measured duration cannot be
        # distorted by wall-clock adjustments.
        self._start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if self._start_time is None:
            return
        latency_ms = (time.perf_counter() - self._start_time) * 1000
        # Honor mark_failed(): previously the flag was set but never read.
        success = self._success and exc_type is None
        self._metrics_service.record_request(
            latency_ms=latency_ms,
            success=success,
            route_mode=self._route_mode,
            strategy=self._strategy,
        )

    def mark_failed(self) -> None:
        """Mark the operation as failed even if no exception is raised."""
        self._success = False
_metrics_service: StrategyMetricsService | None = None


def get_metrics_service() -> StrategyMetricsService:
    """Return the module-wide StrategyMetricsService, creating it on first use."""
    global _metrics_service
    if _metrics_service is not None:
        return _metrics_service
    _metrics_service = StrategyMetricsService()
    return _metrics_service

View File

@ -0,0 +1,484 @@
"""
Retrieval Strategy Service for AI Service.
[AC-AISVC-RES-01~15] Strategy management with grayscale and rollback support.
"""
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from app.schemas.retrieval_strategy import (
ReactMode,
RolloutConfig,
RolloutMode,
StrategyType,
RetrievalStrategyStatus,
RetrievalStrategySwitchRequest,
RetrievalStrategySwitchResponse,
RetrievalStrategyValidationRequest,
RetrievalStrategyValidationResponse,
RetrievalStrategyRollbackResponse,
ValidationResult,
)
logger = logging.getLogger(__name__)
@dataclass
class StrategyState:
    """Internal mutable state of the retrieval strategy service (AC-AISVC-RES-01)."""

    # Currently active strategy and react mode.
    active_strategy: StrategyType = field(default=StrategyType.DEFAULT)
    react_mode: ReactMode = field(default=ReactMode.NON_REACT)
    # Rollout settings: mode plus its percentage / allowlist parameters.
    rollout_mode: RolloutMode = field(default=RolloutMode.OFF)
    rollout_percentage: float = field(default=0.0)
    rollout_allowlist: list[str] = field(default_factory=list)
    # Values that were active before the most recent switch, if any.
    previous_strategy: StrategyType | None = field(default=None)
    previous_react_mode: ReactMode | None = field(default=None)
    # Records appended on each switch.
    switch_history: list[dict[str, Any]] = field(default_factory=list)
class RetrievalStrategyService:
    """
    [AC-AISVC-RES-01~15] Service for managing retrieval strategies.

    Keeps the active strategy, ReAct mode and rollout settings in an
    in-memory ``StrategyState`` and offers operations to inspect, switch,
    validate and roll back that state.

    Features:
    - Strategy switching with grayscale support
    - Rollback to previous/default strategy
    - Validation of strategy configuration
    - Audit logging integration
    """

    # Validation checks executed when a request does not name any.
    _DEFAULT_CHECKS: tuple[str, ...] = (
        "metadata_consistency",
        "embedding_prefix",
        "rrf_config",
        "performance_budget",
    )

    # Substrings that mark a query as combining several conditions.
    # NOTE(review): the original expression tested two additional markers
    # that were empty strings ("" in query), which is always True and forced
    # every auto-routed query to "react". They were presumably
    # single-character conjunctions lost to an encoding problem.
    # TODO: confirm and restore the intended markers here.
    _MULTI_CONDITION_MARKERS: tuple[str, ...] = ("以及",)

    # Auto-routing thresholds.
    _LOW_CONFIDENCE_THRESHOLD: float = 0.5
    _HIGH_CONFIDENCE_THRESHOLD: float = 0.7
    _SHORT_QUERY_THRESHOLD: int = 20

    def __init__(self) -> None:
        """Start from default state with no audit/metrics callbacks set."""
        self._state = StrategyState()
        self._audit_callback: Any = None
        self._metrics_callback: Any = None

    def set_audit_callback(self, callback: Any) -> None:
        """Set callback for audit logging.

        The callback is invoked with keyword arguments only (``operation``,
        ``previous_strategy``, ``new_strategy``, ``previous_react_mode``,
        ``new_react_mode``, ``reason``, ``operator``, ``tenant_id``).
        """
        self._audit_callback = callback

    def set_metrics_callback(self, callback: Any) -> None:
        """Set callback for metrics recording.

        The callback is invoked as ``callback(event_name, payload_dict)``.
        """
        self._metrics_callback = callback

    def get_current_status(self) -> RetrievalStrategyStatus:
        """
        [AC-AISVC-RES-01] Get current retrieval strategy status.

        Returns:
            RetrievalStrategyStatus with current configuration. The rollout
            percentage / allowlist are only reported when the matching
            rollout mode is active; otherwise they are ``None``.
        """
        state = self._state
        rollout = RolloutConfig(
            mode=state.rollout_mode,
            percentage=(
                state.rollout_percentage
                if state.rollout_mode == RolloutMode.PERCENTAGE
                else None
            ),
            allowlist=(
                state.rollout_allowlist
                if state.rollout_mode == RolloutMode.ALLOWLIST
                else None
            ),
        )
        status = RetrievalStrategyStatus(
            active_strategy=state.active_strategy,
            react_mode=state.react_mode,
            rollout=rollout,
        )
        logger.info(
            f"[AC-AISVC-RES-01] Current strategy: {state.active_strategy.value}, "
            f"react_mode={state.react_mode.value}, rollout={state.rollout_mode.value}"
        )
        return status

    def switch_strategy(
        self,
        request: RetrievalStrategySwitchRequest,
        operator: str | None = None,
        tenant_id: str | None = None,
    ) -> RetrievalStrategySwitchResponse:
        """
        [AC-AISVC-RES-02, AC-AISVC-RES-03, AC-AISVC-RES-05] Switch retrieval strategy.

        Args:
            request: Switch request with target strategy and options.
            operator: Operator who initiated the switch.
            tenant_id: Tenant ID for audit.

        Returns:
            RetrievalStrategySwitchResponse with previous and current status.
        """
        state = self._state
        previous_status = self.get_current_status()

        # Remember what was active so rollback_strategy() can restore it.
        state.previous_strategy = state.active_strategy
        state.previous_react_mode = state.react_mode

        state.active_strategy = request.target_strategy
        if request.react_mode:
            state.react_mode = request.react_mode

        if request.rollout:
            state.rollout_mode = request.rollout.mode
            # Only the setting that belongs to the requested mode is updated.
            if request.rollout.mode == RolloutMode.PERCENTAGE:
                state.rollout_percentage = request.rollout.percentage or 0.0
            elif request.rollout.mode == RolloutMode.ALLOWLIST:
                state.rollout_allowlist = request.rollout.allowlist or []

        state.switch_history.append(
            {
                "timestamp": datetime.utcnow().isoformat(),
                "from_strategy": state.previous_strategy.value,
                "to_strategy": state.active_strategy.value,
                "react_mode": state.react_mode.value,
                "rollout_mode": state.rollout_mode.value,
                "reason": request.reason,
                "operator": operator,
            }
        )

        current_status = self.get_current_status()
        logger.info(
            f"[AC-AISVC-RES-02] Strategy switched: {state.previous_strategy.value} -> "
            f"{state.active_strategy.value}, react_mode={state.react_mode.value}"
        )

        if self._audit_callback:
            self._audit_callback(
                operation="switch",
                previous_strategy=state.previous_strategy.value,
                new_strategy=state.active_strategy.value,
                previous_react_mode=(
                    state.previous_react_mode.value
                    if state.previous_react_mode
                    else None
                ),
                new_react_mode=state.react_mode.value,
                reason=request.reason,
                operator=operator,
                tenant_id=tenant_id,
            )
        if self._metrics_callback:
            self._metrics_callback(
                "strategy_switch",
                {
                    "from_strategy": state.previous_strategy.value,
                    "to_strategy": state.active_strategy.value,
                },
            )

        return RetrievalStrategySwitchResponse(
            previous=previous_status,
            current=current_status,
        )

    def validate_strategy(
        self,
        request: RetrievalStrategyValidationRequest,
    ) -> RetrievalStrategyValidationResponse:
        """
        [AC-AISVC-RES-04, AC-AISVC-RES-06, AC-AISVC-RES-08] Validate strategy configuration.

        Args:
            request: Validation request with strategy and checks. When
                ``request.checks`` is empty or missing, all default checks run.

        Returns:
            RetrievalStrategyValidationResponse with check results; ``passed``
            is True only if every individual check passed.
        """
        checks_to_run = request.checks if request.checks else self._DEFAULT_CHECKS
        results: list[ValidationResult] = [
            self._run_validation_check(check, request.strategy, request.react_mode)
            for check in checks_to_run
        ]
        all_passed = all(r.passed for r in results)
        logger.info(
            f"[AC-AISVC-RES-06] Strategy validation: strategy={request.strategy.value}, "
            f"checks={len(results)}, passed={all_passed}"
        )
        return RetrievalStrategyValidationResponse(
            passed=all_passed,
            results=results,
        )

    def _run_validation_check(
        self,
        check: str,
        strategy: StrategyType,
        react_mode: ReactMode | None,
    ) -> ValidationResult:
        """
        Run a single validation check.

        Args:
            check: Check name.
            strategy: Strategy to validate.
            react_mode: ReAct mode to validate.

        Returns:
            ValidationResult for the check; an unknown check name yields a
            failed result instead of raising.
        """
        if check == "metadata_consistency":
            return self._check_metadata_consistency(strategy)
        if check == "embedding_prefix":
            return self._check_embedding_prefix(strategy)
        if check == "rrf_config":
            return self._check_rrf_config(strategy)
        if check == "performance_budget":
            return self._check_performance_budget(strategy, react_mode)
        return ValidationResult(
            check=check,
            passed=False,
            message=f"Unknown check type: {check}",
        )

    def _check_metadata_consistency(self, strategy: StrategyType) -> ValidationResult:
        """
        [AC-AISVC-RES-04] Check metadata consistency between strategies.

        Placeholder: no inspection is performed yet, so the check reports
        success unless reading ``strategy.value`` itself fails.
        """
        try:
            passed = True
            message = "Metadata consistency check passed"
            logger.debug(
                f"[AC-AISVC-RES-04] Metadata consistency check: strategy={strategy.value}, passed={passed}"
            )
            return ValidationResult(check="metadata_consistency", passed=passed, message=message)
        except Exception as e:
            return ValidationResult(check="metadata_consistency", passed=False, message=str(e))

    def _check_embedding_prefix(self, strategy: StrategyType) -> ValidationResult:
        """
        Check embedding prefix configuration.

        Placeholder: no inspection is performed yet, so the check reports
        success unless reading ``strategy.value`` itself fails.
        """
        try:
            passed = True
            message = "Embedding prefix configuration valid"
            logger.debug(
                f"[AC-AISVC-RES-04] Embedding prefix check: strategy={strategy.value}, passed={passed}"
            )
            return ValidationResult(check="embedding_prefix", passed=passed, message=message)
        except Exception as e:
            return ValidationResult(check="embedding_prefix", passed=False, message=str(e))

    def _check_rrf_config(self, strategy: StrategyType) -> ValidationResult:
        """
        [AC-AISVC-RES-02] Check RRF (Reciprocal Rank Fusion) configuration.

        For the enhanced strategy this requires ``rag_hybrid_enabled`` to be
        truthy and ``rag_rrf_k`` to be positive in the application settings.
        Other strategies pass unconditionally. Any error while loading or
        reading the settings is reported as a failed check.
        """
        try:
            # Imported lazily, inside the try, so an import/settings failure
            # becomes a failed check rather than an exception.
            from app.core.config import get_settings

            settings = get_settings()
            if strategy == StrategyType.ENHANCED:
                if not settings.rag_hybrid_enabled:
                    return ValidationResult(
                        check="rrf_config",
                        passed=False,
                        message="Hybrid retrieval not enabled for enhanced strategy",
                    )
                if settings.rag_rrf_k <= 0:
                    return ValidationResult(
                        check="rrf_config",
                        passed=False,
                        message="RRF K parameter must be positive",
                    )
            return ValidationResult(check="rrf_config", passed=True, message="RRF configuration valid")
        except Exception as e:
            return ValidationResult(check="rrf_config", passed=False, message=str(e))

    def _check_performance_budget(
        self,
        strategy: StrategyType,
        react_mode: ReactMode | None,
    ) -> ValidationResult:
        """
        [AC-AISVC-RES-08] Check performance budget constraints.

        Currently only selects the latency budget (10000 ms for enhanced +
        ReAct, 5000 ms otherwise) and reports it; nothing is measured, so
        the check passes unless building the message fails.
        """
        try:
            max_latency_ms = 5000
            if strategy == StrategyType.ENHANCED and react_mode == ReactMode.REACT:
                max_latency_ms = 10000
            message = f"Performance budget check passed (max_latency={max_latency_ms}ms)"
            logger.debug(
                f"[AC-AISVC-RES-08] Performance budget check: strategy={strategy.value}, "
                f"react_mode={react_mode}, max_latency={max_latency_ms}ms"
            )
            return ValidationResult(check="performance_budget", passed=True, message=message)
        except Exception as e:
            return ValidationResult(check="performance_budget", passed=False, message=str(e))

    def rollback_strategy(
        self,
        operator: str | None = None,
        tenant_id: str | None = None,
    ) -> RetrievalStrategyRollbackResponse:
        """
        [AC-AISVC-RES-07] Rollback to previous or default strategy.

        Restores the strategy / ReAct mode recorded by the last switch (or
        the defaults when none was recorded) and resets rollout to OFF.

        Args:
            operator: Operator who initiated the rollback.
            tenant_id: Tenant ID for audit.

        Returns:
            RetrievalStrategyRollbackResponse with current and rollback status.
        """
        state = self._state
        current_status = self.get_current_status()

        rollback_to_strategy = state.previous_strategy or StrategyType.DEFAULT
        rollback_to_react_mode = state.previous_react_mode or ReactMode.NON_REACT
        old_strategy = state.active_strategy
        old_react_mode = state.react_mode

        state.active_strategy = rollback_to_strategy
        state.react_mode = rollback_to_react_mode
        state.rollout_mode = RolloutMode.OFF
        state.rollout_percentage = 0.0
        state.rollout_allowlist = []

        rollback_status = self.get_current_status()

        # Same keys as the records written by switch_strategy(), so history
        # consumers can treat every entry uniformly.
        state.switch_history.append(
            {
                "timestamp": datetime.utcnow().isoformat(),
                "from_strategy": old_strategy.value,
                "to_strategy": rollback_to_strategy.value,
                "react_mode": rollback_to_react_mode.value,
                "rollout_mode": state.rollout_mode.value,
                "reason": "Manual rollback",
                "operator": operator,
            }
        )

        logger.info(
            f"[AC-AISVC-RES-07] Strategy rolled back: {old_strategy.value} -> "
            f"{rollback_to_strategy.value}, react_mode={rollback_to_react_mode.value}"
        )

        if self._audit_callback:
            self._audit_callback(
                operation="rollback",
                previous_strategy=old_strategy.value,
                new_strategy=rollback_to_strategy.value,
                previous_react_mode=old_react_mode.value,
                new_react_mode=rollback_to_react_mode.value,
                reason="Manual rollback",
                operator=operator,
                tenant_id=tenant_id,
            )
        if self._metrics_callback:
            self._metrics_callback(
                "strategy_rollback",
                {
                    "from_strategy": old_strategy.value,
                    "to_strategy": rollback_to_strategy.value,
                },
            )

        return RetrievalStrategyRollbackResponse(
            current=current_status,
            rollback_to=rollback_status,
        )

    def should_use_enhanced_strategy(self, tenant_id: str | None = None) -> bool:
        """
        [AC-AISVC-RES-03] Determine if enhanced strategy should be used based on rollout config.

        Args:
            tenant_id: Tenant ID for allowlist check.

        Returns:
            True if enhanced strategy should be used.
        """
        state = self._state
        if state.active_strategy == StrategyType.DEFAULT:
            return False
        if state.rollout_mode == RolloutMode.OFF:
            return state.active_strategy == StrategyType.ENHANCED
        if state.rollout_mode == RolloutMode.ALLOWLIST:
            if tenant_id and tenant_id in state.rollout_allowlist:
                return True
            return False
        if state.rollout_mode == RolloutMode.PERCENTAGE:
            import random

            # A fresh random draw is made on every call, so the decision is
            # not sticky for a given tenant.
            return random.random() * 100 < state.rollout_percentage
        return False

    def get_route_mode(
        self,
        query: str,
        confidence: float | None = None,
    ) -> str:
        """
        [AC-AISVC-RES-09~15] Determine route mode based on query and confidence.

        Args:
            query: User query.
            confidence: Confidence score from metadata inference.

        Returns:
            "react" or "direct". A fixed ReAct mode maps straight to its
            route; any other mode is resolved by the auto-routing heuristic.
        """
        if self._state.react_mode == ReactMode.REACT:
            return "react"
        if self._state.react_mode == ReactMode.NON_REACT:
            return "direct"
        return self._auto_route(query, confidence)

    def _auto_route(self, query: str, confidence: float | None = None) -> str:
        """
        [AC-AISVC-RES-11~14] Auto route based on query complexity and confidence.

        Rules, evaluated in order:
        1. confidence below the low threshold -> "react"
        2. query contains a multi-condition marker -> "react"
        3. otherwise -> "direct" (short, high-confidence queries are logged
           separately but route the same way)
        """
        if confidence is not None and confidence < self._LOW_CONFIDENCE_THRESHOLD:
            logger.info(
                f"[AC-AISVC-RES-13] Auto route to react: low confidence={confidence}"
            )
            return "react"

        # Empty markers are skipped explicitly: "" is a substring of every
        # string and would make this test unconditionally true.
        has_multiple_conditions = any(
            marker in query for marker in self._MULTI_CONDITION_MARKERS if marker
        )
        if has_multiple_conditions:
            logger.info(
                "[AC-AISVC-RES-13] Auto route to react: multiple conditions detected"
            )
            return "react"

        if (
            len(query) < self._SHORT_QUERY_THRESHOLD
            and confidence is not None
            and confidence > self._HIGH_CONFIDENCE_THRESHOLD
        ):
            logger.info(
                "[AC-AISVC-RES-12] Auto route to direct: short query, high confidence"
            )
            return "direct"
        return "direct"

    def get_switch_history(self, limit: int = 10) -> list[dict[str, Any]]:
        """
        Get recent switch history.

        Args:
            limit: Maximum number of records to return. Zero or a negative
                value yields an empty list.

        Returns:
            List of the most recent switch/rollback records, oldest first.
        """
        # history[-0:] would be the whole list and a negative limit would
        # slice from the front, so non-positive limits are handled up front.
        if limit <= 0:
            return []
        return self._state.switch_history[-limit:]
# Process-wide service instance; created on first request.
_strategy_service: RetrievalStrategyService | None = None


def get_strategy_service() -> RetrievalStrategyService:
    """Return the shared RetrievalStrategyService, building it on first use."""
    global _strategy_service
    if _strategy_service is not None:
        return _strategy_service
    _strategy_service = RetrievalStrategyService()
    return _strategy_service