ai-robot-core/ai-service/app/api/admin/retrieval_strategy.py

350 lines
13 KiB
Python
Raw Normal View History

"""
Retrieval Strategy API Endpoints.
[AC-AISVC-RES-01~15] 策略管理 API
Endpoints:
- GET /strategy/retrieval/current - 获取当前策略状态
- POST /strategy/retrieval/switch - 切换策略
- POST /strategy/retrieval/validate - 验证策略配置
- POST /strategy/retrieval/rollback - 回退策略
"""
import logging
from typing import Any
from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel, Field
from app.services.retrieval.strategy.config import (
GrayscaleConfig,
ModeRouterConfig,
PipelineConfig,
RetrievalStrategyConfig,
RerankerConfig,
RuntimeMode,
StrategyType,
)
from app.services.retrieval.strategy.rollback_manager import (
RollbackManager,
RollbackResult,
RollbackTrigger,
get_rollback_manager,
)
from app.services.retrieval.strategy.strategy_router import (
StrategyRouter,
get_strategy_router,
set_strategy_router,
)
# Module-level logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Every route declared below is served under the /strategy/retrieval prefix
# and carries the "strategy" tag.
router = APIRouter(prefix="/strategy/retrieval", tags=["strategy"])
class GrayscaleConfigSchema(BaseModel):
    """Grayscale (gradual rollout) configuration schema."""

    # Whether grayscale rollout is switched on; off by default.
    enabled: bool = False
    # Rollout percentage; field validation restricts it to 0.0-100.0 inclusive.
    percentage: float = Field(default=0.0, ge=0.0, le=100.0)
    # Allowlist entries. NOTE(review): presumably user/tenant identifiers --
    # confirm against the strategy router, which is not visible here.
    allowlist: list[str] = Field(default_factory=list)
class RerankerConfigSchema(BaseModel):
    """Reranker configuration schema."""

    # Whether reranking is switched on; off by default.
    enabled: bool = False
    # Reranker model name.
    model: str = "cross-encoder"
    # Top-k value applied after reranking (validate_strategy warns on values
    # above 20 when the reranker is enabled).
    top_k_after_rerank: int = 5
    # Minimum score threshold. NOTE(review): presumably results scoring below
    # this are dropped -- confirm in the reranker implementation.
    min_score_threshold: float = 0.3
class ModeRouterConfigSchema(BaseModel):
    """Mode router configuration schema."""

    # Runtime mode as a plain string; switch_strategy converts it with
    # RuntimeMode(...), and validate_strategy accepts "direct", "react", "auto".
    runtime_mode: str = "direct"
    # Confidence threshold associated with triggering ReAct mode.
    # NOTE(review): exact comparison semantics live in the mode router -- confirm.
    react_trigger_confidence_threshold: float = 0.6
    # Complexity score associated with triggering ReAct mode (same caveat).
    react_trigger_complexity_score: float = 0.5
    # Maximum number of ReAct steps.
    react_max_steps: int = 5
    # Whether to fall back to direct mode on low confidence.
    direct_fallback_on_low_confidence: bool = True
class PipelineConfigSchema(BaseModel):
    """Pipeline configuration schema."""

    # Number of top results to retrieve.
    top_k: int = 5
    # Score threshold for retrieved results.
    # NOTE(review): presumably a lower bound on hit score -- confirm in the pipeline.
    score_threshold: float = 0.01
    # Minimum number of hits.
    min_hits: int = 1
    # Whether two-stage retrieval is enabled; on by default.
    two_stage_enabled: bool = True
class StrategyStatusResponse(BaseModel):
    """Strategy status response returned by GET /current."""

    # String value of the active strategy (filled from
    # config.active_strategy.value in get_current_strategy).
    active_strategy: str
    # Snapshot of the grayscale rollout settings.
    grayscale: GrayscaleConfigSchema
    # Snapshot of the retrieval pipeline settings.
    pipeline: PipelineConfigSchema
    # Snapshot of the reranker settings.
    reranker: RerankerConfigSchema
    # Snapshot of the mode router settings.
    mode_router: ModeRouterConfigSchema
    # Performance thresholds keyed by name; empty dict when none are supplied.
    performance_thresholds: dict[str, float] = Field(default_factory=dict)
class SwitchRequest(BaseModel):
    """Strategy switch request.

    Every field is optional: switch_strategy keeps the corresponding part of
    the current configuration for any field left as None.
    """

    # Target strategy as a string; converted with StrategyType(...) by the endpoint.
    active_strategy: str | None = None
    # Replacement grayscale settings, or None to keep the current ones.
    grayscale: GrayscaleConfigSchema | None = None
    # Replacement mode router settings, or None to keep the current ones.
    mode_router: ModeRouterConfigSchema | None = None
    # Replacement reranker settings, or None to keep the current ones.
    reranker: RerankerConfigSchema | None = None
class SwitchResponse(BaseModel):
    """Strategy switch response."""

    # Whether the switch was applied.
    success: bool
    # Strategy value that was active before the switch.
    previous_strategy: str
    # Strategy value that is active after the switch.
    current_strategy: str
    # Human-readable summary of the switch.
    message: str
class ValidationRequest(BaseModel):
    """Strategy validation request."""

    # Free-form configuration to validate. validate_strategy inspects the keys
    # "active_strategy", "grayscale", "mode_router", "reranker" and
    # "performance_thresholds"; defaults to an empty dict.
    config: dict[str, Any] = Field(default_factory=dict)
class ValidationResponse(BaseModel):
    """Strategy validation response."""

    # True when no errors were found (warnings do not affect validity).
    valid: bool
    # Blocking problems found in the submitted configuration.
    errors: list[str] = Field(default_factory=list)
    # Non-blocking observations about the submitted configuration.
    warnings: list[str] = Field(default_factory=list)
    # Condensed view of the submitted configuration, with defaults filled in
    # for absent keys.
    config_summary: dict[str, Any] = Field(default_factory=dict)
class RollbackResponse(BaseModel):
    """Strategy rollback response."""

    # False when the service was already on the default strategy.
    success: bool
    # Strategy value that was active before the rollback request.
    previous_strategy: str
    # Strategy value that is active after the rollback request.
    current_strategy: str
    # Trigger label echoed from the request (defaults to "manual").
    trigger: str
    # Reason recorded for the rollback, or why no rollback happened.
    reason: str
    # Identifier of the audit entry; rollback_strategy fills it with the
    # stringified audit timestamp, or None when no entry is available.
    audit_log_id: str | None = None
@router.get(
    "/current",
    response_model=StrategyStatusResponse,
    summary="获取当前策略状态",
    description="【AC-AISVC-RES-01】 获取当前活跃的检索策略配置状态。",
)
async def get_current_strategy() -> StrategyStatusResponse:
    """
    [AC-AISVC-RES-01] Report the retrieval strategy configuration in effect.

    Reads the configuration held by the shared strategy router and copies it,
    field by field, into the API response schemas.

    Returns:
        StrategyStatusResponse mirroring the router's current configuration.
    """
    config = get_strategy_router().get_config()

    active_strategy = config.active_strategy.value

    grayscale_cfg = config.grayscale
    grayscale = GrayscaleConfigSchema(
        enabled=grayscale_cfg.enabled,
        percentage=grayscale_cfg.percentage,
        allowlist=grayscale_cfg.allowlist,
    )

    pipeline_cfg = config.pipeline
    pipeline = PipelineConfigSchema(
        top_k=pipeline_cfg.top_k,
        score_threshold=pipeline_cfg.score_threshold,
        min_hits=pipeline_cfg.min_hits,
        two_stage_enabled=pipeline_cfg.two_stage_enabled,
    )

    reranker_cfg = config.reranker
    reranker = RerankerConfigSchema(
        enabled=reranker_cfg.enabled,
        model=reranker_cfg.model,
        top_k_after_rerank=reranker_cfg.top_k_after_rerank,
        min_score_threshold=reranker_cfg.min_score_threshold,
    )

    mode_cfg = config.mode_router
    mode_router = ModeRouterConfigSchema(
        # The schema carries the mode as a plain string, so unwrap the enum.
        runtime_mode=mode_cfg.runtime_mode.value,
        react_trigger_confidence_threshold=mode_cfg.react_trigger_confidence_threshold,
        react_trigger_complexity_score=mode_cfg.react_trigger_complexity_score,
        react_max_steps=mode_cfg.react_max_steps,
        direct_fallback_on_low_confidence=mode_cfg.direct_fallback_on_low_confidence,
    )

    return StrategyStatusResponse(
        active_strategy=active_strategy,
        grayscale=grayscale,
        pipeline=pipeline,
        reranker=reranker,
        mode_router=mode_router,
        performance_thresholds=config.performance_thresholds,
    )
@router.post(
    "/switch",
    response_model=SwitchResponse,
    summary="切换策略",
    description="【AC-AISVC-RES-02, AC-AISVC-RES-03】 切换检索策略, 支持灰度发布配置。",
)
async def switch_strategy(request: SwitchRequest) -> SwitchResponse:
    """
    [AC-AISVC-RES-02, AC-AISVC-RES-03] Switch the active retrieval strategy.

    Builds a new configuration from the current one, overriding only the parts
    present in the request (active strategy, grayscale percentage/allowlist,
    mode router, reranker), then installs it on the shared strategy router.
    Pipeline, metadata inference and performance thresholds are carried over
    unchanged.

    Args:
        request: Partial configuration; any field left as None keeps the
            current value.

    Returns:
        SwitchResponse naming the previous and the new active strategy.

    Raises:
        HTTPException: 400 when ``active_strategy`` or
            ``mode_router.runtime_mode`` is not a recognised value.
    """
    strategy_router = get_strategy_router()
    current_config = strategy_router.get_config()
    previous_strategy = current_config.active_strategy.value

    # An omitted (or empty) strategy name means "keep the current strategy".
    if request.active_strategy:
        try:
            new_active_strategy = StrategyType(request.active_strategy)
        except ValueError as exc:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid strategy type: {request.active_strategy}. Valid values are: default, enhanced",
            ) from exc
    else:
        new_active_strategy = current_config.active_strategy

    # Grayscale settings are always copied into a fresh GrayscaleConfig, taken
    # either from the request or from the current configuration.
    grayscale_source = (
        request.grayscale if request.grayscale is not None else current_config.grayscale
    )
    grayscale = GrayscaleConfig(
        enabled=grayscale_source.enabled,
        percentage=grayscale_source.percentage,
        allowlist=grayscale_source.allowlist,
    )

    mode_router = current_config.mode_router
    if request.mode_router is not None:
        requested_mode = request.mode_router.runtime_mode
        if requested_mode:
            # Reject unknown modes as a client error instead of letting the
            # enum's ValueError surface as a 500.
            try:
                runtime_mode = RuntimeMode(requested_mode)
            except ValueError as exc:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Invalid runtime mode: {requested_mode}. Valid values are: direct, react, auto",
                ) from exc
        else:
            # Empty mode string: keep the mode currently in effect.
            runtime_mode = current_config.mode_router.runtime_mode
        mode_router = ModeRouterConfig(
            runtime_mode=runtime_mode,
            react_trigger_confidence_threshold=request.mode_router.react_trigger_confidence_threshold,
            react_trigger_complexity_score=request.mode_router.react_trigger_complexity_score,
            react_max_steps=request.mode_router.react_max_steps,
            direct_fallback_on_low_confidence=request.mode_router.direct_fallback_on_low_confidence,
        )

    reranker = current_config.reranker
    if request.reranker is not None:
        reranker = RerankerConfig(
            enabled=request.reranker.enabled,
            model=request.reranker.model,
            top_k_after_rerank=request.reranker.top_k_after_rerank,
            min_score_threshold=request.reranker.min_score_threshold,
        )

    new_config = RetrievalStrategyConfig(
        active_strategy=new_active_strategy,
        grayscale=grayscale,
        mode_router=mode_router,
        reranker=reranker,
        pipeline=current_config.pipeline,
        metadata_inference=current_config.metadata_inference,
        performance_thresholds=current_config.performance_thresholds,
    )
    strategy_router.update_config(new_config)

    current_strategy = new_config.active_strategy.value
    logger.info(
        "[AC-AISVC-RES-02] Strategy switched: %s -> %s",
        previous_strategy,
        current_strategy,
    )
    return SwitchResponse(
        success=True,
        previous_strategy=previous_strategy,
        current_strategy=current_strategy,
        message=f"Strategy switched from {previous_strategy} to {current_strategy}",
    )
def _validation_section(
    config: dict[str, Any], key: str, errors: list[str]
) -> dict[str, Any]:
    """Return ``config[key]`` if it is a dict, else ``{}`` (recording an error when present but not a dict)."""
    if key not in config:
        return {}
    section = config[key]
    if isinstance(section, dict):
        return section
    errors.append(f"Invalid {key}: expected an object")
    return {}


def _is_number(value: Any) -> bool:
    """Return True when *value* can be range-compared as a number."""
    return isinstance(value, (int, float))


@router.post(
    "/validate",
    response_model=ValidationResponse,
    summary="验证策略配置",
    description="【AC-AISVC-RES-06, AC-AISVC-RES-08】 验证策略配置的完整性与一致性。",
)
async def validate_strategy(request: ValidationRequest) -> ValidationResponse:
    """
    [AC-AISVC-RES-06, AC-AISVC-RES-08] Validate a strategy configuration.

    Checks only the keys that are present in ``request.config``; absent keys
    and values explicitly set to None are skipped. Malformed input (a section
    that is not an object, or a non-numeric value where a number is expected)
    is reported in the response rather than raising.

    Args:
        request: Wrapper around the free-form configuration dict to check.

    Returns:
        ValidationResponse with ``valid`` set to True when no errors were
        found, the collected errors and warnings, and a summary of the
        submitted configuration with defaults for absent keys.
    """
    errors: list[str] = []
    warnings: list[str] = []
    config = request.config

    if "active_strategy" in config and config["active_strategy"] not in ("default", "enhanced"):
        errors.append(f"Invalid active_strategy: {config['active_strategy']}")

    grayscale = _validation_section(config, "grayscale", errors)
    # .get() without a default so a missing key is skipped instead of raising
    # KeyError on a later subscript.
    percentage = grayscale.get("percentage")
    if percentage is not None and not (_is_number(percentage) and 0 <= percentage <= 100):
        errors.append(f"Invalid grayscale percentage: {percentage}")

    mode_router = _validation_section(config, "mode_router", errors)
    if "runtime_mode" in mode_router and mode_router["runtime_mode"] not in ("direct", "react", "auto"):
        errors.append(f"Invalid runtime_mode: {mode_router['runtime_mode']}")

    reranker = _validation_section(config, "reranker", errors)
    top_k_after_rerank = reranker.get("top_k_after_rerank")
    if reranker.get("enabled") and top_k_after_rerank is not None:
        # Enforce both ends of the range that the warning text advertises.
        if not (_is_number(top_k_after_rerank) and 1 <= top_k_after_rerank <= 20):
            warnings.append(
                f"top_k_after_rerank should be between 1 and 20, current: {top_k_after_rerank}"
            )

    thresholds = _validation_section(config, "performance_thresholds", errors)
    max_latency_ms = thresholds.get("max_latency_ms")
    if max_latency_ms is not None:
        if not _is_number(max_latency_ms):
            errors.append(f"Invalid max_latency_ms: {max_latency_ms}")
        elif max_latency_ms < 100:
            warnings.append(f"max_latency_ms seems too low: {max_latency_ms}ms")

    return ValidationResponse(
        valid=not errors,
        errors=errors,
        warnings=warnings,
        config_summary={
            "active_strategy": config.get("active_strategy", "default"),
            "grayscale_enabled": grayscale.get("enabled", False),
            "runtime_mode": mode_router.get("runtime_mode", "direct"),
            "reranker_enabled": reranker.get("enabled", False),
            "performance_thresholds": thresholds,
        },
    )
@router.post(
    "/rollback",
    response_model=RollbackResponse,
    summary="回退策略",
    description="【AC-AISVC-RES-07】 回退到默认策略。",
)
async def rollback_strategy(
    trigger: str = "manual",
    reason: str = "",
) -> RollbackResponse:
    """
    [AC-AISVC-RES-07] Roll the active strategy back to the default.

    When a non-default strategy is active, installs a configuration identical
    to the current one except for ``active_strategy``, pushes it to both the
    strategy router and the rollback manager, and records an audit entry.

    Args:
        trigger: Label describing what initiated the rollback; echoed in the
            response and stored in the audit details.
        reason: Free-text reason; "Manual rollback" is used when empty.

    Returns:
        RollbackResponse. ``success`` is False, and nothing is changed, when
        the default strategy is already active.
    """
    strategy_router = get_strategy_router()
    current_config = strategy_router.get_config()
    previous_strategy = current_config.active_strategy

    # Guard clause: nothing to roll back when the default is already active.
    if previous_strategy == StrategyType.DEFAULT:
        unchanged = previous_strategy.value
        return RollbackResponse(
            success=False,
            previous_strategy=unchanged,
            current_strategy=unchanged,
            trigger=trigger,
            reason="Already on default strategy",
            audit_log_id=None,
        )

    effective_reason = reason or "Manual rollback"

    # Same configuration as before, with only the active strategy reset.
    rolled_back = RetrievalStrategyConfig(
        active_strategy=StrategyType.DEFAULT,
        grayscale=current_config.grayscale,
        mode_router=current_config.mode_router,
        reranker=current_config.reranker,
        pipeline=current_config.pipeline,
        metadata_inference=current_config.metadata_inference,
        performance_thresholds=current_config.performance_thresholds,
    )
    strategy_router.update_config(rolled_back)

    manager = get_rollback_manager()
    manager.update_config(rolled_back)
    audit_details = {
        "from_strategy": previous_strategy.value,
        "to_strategy": StrategyType.DEFAULT.value,
        "trigger": trigger,
        "reason": effective_reason,
    }
    audit_entry = manager.record_audit(action="rollback", details=audit_details)

    logger.info(
        f"[AC-AISVC-RES-07] Strategy rolled back: {previous_strategy.value} -> default"
    )

    # The audit entry's timestamp doubles as its identifier in the response.
    audit_log_id = None
    if audit_entry:
        audit_log_id = str(audit_entry.timestamp)

    return RollbackResponse(
        success=True,
        previous_strategy=previous_strategy.value,
        current_strategy=StrategyType.DEFAULT.value,
        trigger=trigger,
        reason=effective_reason,
        audit_log_id=audit_log_id,
    )