# ai-robot-core/ai-service/app/api/admin/retrieval_strategy.py
#
# NOTE: this file intentionally contains non-ASCII characters (CJK text and
# full-width brackets) in docstrings and in the route summary/description
# strings; some viewers flag them as "ambiguous Unicode characters".

"""
Retrieval Strategy API Endpoints.
[AC-AISVC-RES-01~15] Strategy management API.
Endpoints:
- GET /strategy/retrieval/current - Get the current strategy status
- POST /strategy/retrieval/switch - Switch the strategy
- POST /strategy/retrieval/validate - Validate a strategy configuration
- POST /strategy/retrieval/rollback - Roll back the strategy
"""
import logging
from typing import Any
from fastapi import APIRouter, HTTPException, status
from pydantic import BaseModel, Field
from app.services.retrieval.strategy.config import (
GrayscaleConfig,
ModeRouterConfig,
PipelineConfig,
RetrievalStrategyConfig,
RerankerConfig,
RuntimeMode,
StrategyType,
)
from app.services.retrieval.strategy.rollback_manager import (
RollbackManager,
RollbackResult,
RollbackTrigger,
get_rollback_manager,
)
from app.services.retrieval.strategy.strategy_router import (
StrategyRouter,
get_strategy_router,
set_strategy_router,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Router for the retrieval-strategy endpoints: routes registered on it are
# served under the "/strategy/retrieval" prefix and tagged "strategy".
router = APIRouter(prefix="/strategy/retrieval", tags=["strategy"])
class GrayscaleConfigSchema(BaseModel):
    """Grayscale (gradual rollout) configuration schema."""

    # Whether grayscale rollout is turned on; off by default.
    enabled: bool = False
    # Rollout percentage; validated to lie within [0.0, 100.0].
    percentage: float = Field(default=0.0, ge=0.0, le=100.0)
    # Allowlist entries as strings; empty by default.
    # NOTE(review): presumably user/tenant identifiers - confirm against the strategy router.
    allowlist: list[str] = Field(default_factory=list)
class RerankerConfigSchema(BaseModel):
    """Reranker configuration schema."""

    # Whether the reranker is turned on; off by default.
    enabled: bool = False
    # Reranker model name; defaults to "cross-encoder".
    model: str = "cross-encoder"
    # NOTE(review): presumably the number of hits kept after reranking - confirm.
    top_k_after_rerank: int = 5
    # NOTE(review): presumably the minimum score a hit needs to be kept - confirm.
    min_score_threshold: float = 0.3
class ModeRouterConfigSchema(BaseModel):
    """Mode router configuration schema."""

    # Runtime mode as a plain string; this schema itself does not restrict the value.
    runtime_mode: str = "direct"
    # NOTE(review): presumably the confidence level used to decide on ReAct mode - confirm.
    react_trigger_confidence_threshold: float = 0.6
    # NOTE(review): presumably the query-complexity score used to decide on ReAct mode - confirm.
    react_trigger_complexity_score: float = 0.5
    # NOTE(review): presumably the upper bound on ReAct steps - confirm.
    react_max_steps: int = 5
    # NOTE(review): presumably falls back to direct mode on low confidence - confirm.
    direct_fallback_on_low_confidence: bool = True
class PipelineConfigSchema(BaseModel):
    """Pipeline configuration schema."""

    # NOTE(review): presumably the number of hits the pipeline retrieves - confirm.
    top_k: int = 5
    # NOTE(review): presumably the minimum score for a hit to be kept - confirm.
    score_threshold: float = 0.01
    # NOTE(review): presumably the minimum number of hits required - confirm.
    min_hits: int = 1
    # Two-stage flag; enabled by default.
    two_stage_enabled: bool = True
class StrategyStatusResponse(BaseModel):
    """Strategy status response (response model of GET /current)."""

    # Value of the active strategy, as a string.
    active_strategy: str
    # Snapshot of each configuration section.
    grayscale: GrayscaleConfigSchema
    pipeline: PipelineConfigSchema
    reranker: RerankerConfigSchema
    mode_router: ModeRouterConfigSchema
    # Performance thresholds keyed by name; empty by default.
    performance_thresholds: dict[str, float] = Field(default_factory=dict)
class SwitchRequest(BaseModel):
    """Strategy switch request (request body of POST /switch)."""

    # Every field is optional and defaults to None.
    # Target strategy value as a string.
    active_strategy: str | None = None
    # Replacement grayscale section.
    grayscale: GrayscaleConfigSchema | None = None
    # Replacement mode-router section.
    mode_router: ModeRouterConfigSchema | None = None
    # Replacement reranker section.
    reranker: RerankerConfigSchema | None = None
class SwitchResponse(BaseModel):
    """Strategy switch response (response model of POST /switch)."""

    # Whether the switch was applied.
    success: bool
    # Strategy value that was active before the switch.
    previous_strategy: str
    # Strategy value that is active after the switch.
    current_strategy: str
    # Human-readable summary of the switch.
    message: str
class ValidationRequest(BaseModel):
    """Strategy validation request (request body of POST /validate)."""

    # Free-form configuration to validate; accepted as an untyped dict, empty by default.
    config: dict[str, Any] = Field(default_factory=dict)
class ValidationResponse(BaseModel):
    """Strategy validation response (response model of POST /validate)."""

    # Overall validation verdict.
    valid: bool
    # Validation error messages; empty by default.
    errors: list[str] = Field(default_factory=list)
    # Validation warning messages; empty by default.
    warnings: list[str] = Field(default_factory=list)
    # Summary of the submitted configuration; empty by default.
    config_summary: dict[str, Any] = Field(default_factory=dict)
class RollbackResponse(BaseModel):
    """Strategy rollback response (response model of POST /rollback)."""

    # Whether a rollback was performed.
    success: bool
    # Strategy value that was active before the rollback attempt.
    previous_strategy: str
    # Strategy value that is active after the rollback attempt.
    current_strategy: str
    # Trigger label for the rollback.
    trigger: str
    # Reason text for the rollback outcome.
    reason: str
    # Identifier of the audit record, or None when there is none.
    audit_log_id: str | None = None
@router.get(
    "/current",
    response_model=StrategyStatusResponse,
    summary="获取当前策略状态",
    description="【AC-AISVC-RES-01】 获取当前活跃的检索策略配置状态。",
)
async def get_current_strategy() -> StrategyStatusResponse:
    """
    [AC-AISVC-RES-01] Report the retrieval strategy configuration currently in effect.

    Fetches the configuration from the router returned by
    ``get_strategy_router()`` and copies it, section by section, into the
    response schema.

    Returns:
        StrategyStatusResponse: the active strategy value plus the grayscale,
        pipeline, reranker and mode-router sections and the performance
        thresholds.
    """
    active = get_strategy_router().get_config()
    strategy_name = active.active_strategy.value

    # Mirror each config section into its API schema counterpart.
    grayscale_cfg = active.grayscale
    grayscale_view = GrayscaleConfigSchema(
        enabled=grayscale_cfg.enabled,
        percentage=grayscale_cfg.percentage,
        allowlist=grayscale_cfg.allowlist,
    )

    pipeline_cfg = active.pipeline
    pipeline_view = PipelineConfigSchema(
        top_k=pipeline_cfg.top_k,
        score_threshold=pipeline_cfg.score_threshold,
        min_hits=pipeline_cfg.min_hits,
        two_stage_enabled=pipeline_cfg.two_stage_enabled,
    )

    reranker_cfg = active.reranker
    reranker_view = RerankerConfigSchema(
        enabled=reranker_cfg.enabled,
        model=reranker_cfg.model,
        top_k_after_rerank=reranker_cfg.top_k_after_rerank,
        min_score_threshold=reranker_cfg.min_score_threshold,
    )

    mode_cfg = active.mode_router
    mode_view = ModeRouterConfigSchema(
        # The schema carries the mode as a plain string, hence `.value`.
        runtime_mode=mode_cfg.runtime_mode.value,
        react_trigger_confidence_threshold=mode_cfg.react_trigger_confidence_threshold,
        react_trigger_complexity_score=mode_cfg.react_trigger_complexity_score,
        react_max_steps=mode_cfg.react_max_steps,
        direct_fallback_on_low_confidence=mode_cfg.direct_fallback_on_low_confidence,
    )

    return StrategyStatusResponse(
        active_strategy=strategy_name,
        grayscale=grayscale_view,
        pipeline=pipeline_view,
        reranker=reranker_view,
        mode_router=mode_view,
        performance_thresholds=active.performance_thresholds,
    )
@router.post(
    "/switch",
    response_model=SwitchResponse,
    summary="切换策略",
    description="【AC-AISVC-RES-02, AC-AISVC-RES-03】 切换检索策略, 支持灰度发布配置。",
)
async def switch_strategy(request: SwitchRequest) -> SwitchResponse:
    """
    [AC-AISVC-RES-02, AC-AISVC-RES-03] Switch the retrieval strategy.

    Builds a new configuration from the current one, overriding only the
    sections supplied in the request (active strategy, grayscale
    percentage/allowlist, mode router, reranker). The pipeline,
    metadata-inference and performance-threshold sections are always carried
    over unchanged. All request values are validated before the router's
    configuration is updated.

    Args:
        request: Requested overrides; any field left as ``None`` keeps the
            current value. An empty ``mode_router.runtime_mode`` string also
            keeps the current runtime mode.

    Returns:
        SwitchResponse: the previous and the new active strategy values.

    Raises:
        HTTPException: 400 when ``active_strategy`` or
            ``mode_router.runtime_mode`` is not a valid enum value.
    """
    strategy_router = get_strategy_router()
    current_config = strategy_router.get_config()
    previous_strategy = current_config.active_strategy.value

    # Resolve the target strategy; an unknown value is a client error.
    if request.active_strategy:
        try:
            new_active_strategy = StrategyType(request.active_strategy)
        except ValueError as e:
            raise HTTPException(
                status_code=400,
                detail=f"Invalid strategy type: {request.active_strategy}. Valid values are: default, enhanced",
            ) from e
    else:
        new_active_strategy = current_config.active_strategy

    # Grayscale: take all three fields from the request when provided,
    # otherwise copy them from the current configuration.
    grayscale_source = request.grayscale if request.grayscale is not None else current_config.grayscale
    new_grayscale = GrayscaleConfig(
        enabled=grayscale_source.enabled,
        percentage=grayscale_source.percentage,
        allowlist=grayscale_source.allowlist,
    )

    # Mode router: the enum conversion used to run unguarded, so an unknown
    # runtime mode surfaced as an unhandled ValueError (HTTP 500). It is now
    # reported as a 400, like an unknown strategy type.
    if request.mode_router is not None:
        requested_mode = request.mode_router.runtime_mode
        if requested_mode:
            try:
                new_runtime_mode = RuntimeMode(requested_mode)
            except ValueError as e:
                valid_modes = ", ".join(str(mode.value) for mode in RuntimeMode)
                raise HTTPException(
                    status_code=400,
                    detail=f"Invalid runtime mode: {requested_mode}. Valid values are: {valid_modes}",
                ) from e
        else:
            new_runtime_mode = current_config.mode_router.runtime_mode
        new_mode_router = ModeRouterConfig(
            runtime_mode=new_runtime_mode,
            react_trigger_confidence_threshold=request.mode_router.react_trigger_confidence_threshold,
            react_trigger_complexity_score=request.mode_router.react_trigger_complexity_score,
            react_max_steps=request.mode_router.react_max_steps,
            direct_fallback_on_low_confidence=request.mode_router.direct_fallback_on_low_confidence,
        )
    else:
        new_mode_router = current_config.mode_router

    # Reranker: replaced wholesale when provided, otherwise carried over.
    if request.reranker is not None:
        new_reranker = RerankerConfig(
            enabled=request.reranker.enabled,
            model=request.reranker.model,
            top_k_after_rerank=request.reranker.top_k_after_rerank,
            min_score_threshold=request.reranker.min_score_threshold,
        )
    else:
        new_reranker = current_config.reranker

    new_config = RetrievalStrategyConfig(
        active_strategy=new_active_strategy,
        grayscale=new_grayscale,
        mode_router=new_mode_router,
        reranker=new_reranker,
        pipeline=current_config.pipeline,
        metadata_inference=current_config.metadata_inference,
        performance_thresholds=current_config.performance_thresholds,
    )
    # NOTE(review): the rollback endpoint also pushes its new config to the
    # rollback manager; this endpoint does not - confirm that is intended.
    strategy_router.update_config(new_config)
    logger.info(
        f"[AC-AISVC-RES-02] Strategy switched: {previous_strategy} -> {new_config.active_strategy.value}"
    )
    return SwitchResponse(
        success=True,
        previous_strategy=previous_strategy,
        current_strategy=new_config.active_strategy.value,
        message=f"Strategy switched from {previous_strategy} to {new_config.active_strategy.value}",
    )
def _config_section(config: dict[str, Any], key: str, errors: list[str]) -> dict[str, Any]:
    """Return ``config[key]`` if it is a dict; otherwise return ``{}``.

    A missing or ``None`` section is treated as absent. Any other non-dict
    value is recorded in ``errors`` so the caller reports it instead of
    crashing on ``.get``.
    """
    section = config.get(key)
    if section is None:
        return {}
    if not isinstance(section, dict):
        errors.append(f"Invalid {key}: expected an object, got {type(section).__name__}")
        return {}
    return section


@router.post(
    "/validate",
    response_model=ValidationResponse,
    summary="验证策略配置",
    description="【AC-AISVC-RES-06, AC-AISVC-RES-08】 验证策略配置的完整性与一致性。",
)
async def validate_strategy(request: ValidationRequest) -> ValidationResponse:
    """
    [AC-AISVC-RES-06, AC-AISVC-RES-08] Validate a strategy configuration.

    Checks the free-form ``request.config`` dict without applying it.
    Malformed input (missing keys, non-dict sections, non-numeric values) is
    reported through ``errors`` rather than raising.

    Checks performed:
        - ``active_strategy`` must be "default" or "enhanced" (error).
        - ``grayscale.percentage``, when given, must be a number in
          [0, 100] (error).
        - ``mode_router.runtime_mode`` must be "direct", "react" or
          "auto" (error).
        - ``reranker.top_k_after_rerank``, when the reranker is enabled,
          should be within 1..20 (warning).
        - ``performance_thresholds.max_latency_ms`` below 100 (warning).

    Args:
        request: Wrapper around the configuration dict to validate.

    Returns:
        ValidationResponse: ``valid`` is True when no errors were found;
        ``config_summary`` echoes the key settings with defaults filled in.
    """
    errors: list[str] = []
    warnings: list[str] = []
    config = request.config

    if "active_strategy" in config and config["active_strategy"] not in ("default", "enhanced"):
        errors.append(f"Invalid active_strategy: {config['active_strategy']}")

    # A missing percentage used to raise KeyError: the old code called
    # .get(key, 1.0) and then indexed the key unconditionally.
    grayscale = _config_section(config, "grayscale", errors)
    percentage = grayscale.get("percentage")
    if percentage is not None:
        if not isinstance(percentage, (int, float)) or not 0 <= percentage <= 100:
            errors.append(f"Invalid grayscale percentage: {percentage}")

    mode_router = _config_section(config, "mode_router", errors)
    if "runtime_mode" in mode_router and mode_router["runtime_mode"] not in ("direct", "react", "auto"):
        errors.append(f"Invalid runtime_mode: {mode_router['runtime_mode']}")

    # The warning text promises a 1..20 range, so check the lower bound too.
    reranker = _config_section(config, "reranker", errors)
    top_k = reranker.get("top_k_after_rerank")
    if reranker.get("enabled") and top_k is not None:
        if not isinstance(top_k, (int, float)):
            errors.append(f"Invalid top_k_after_rerank: {top_k}")
        elif not 1 <= top_k <= 20:
            warnings.append(f"top_k_after_rerank should be between 1 and 20, current: {top_k}")

    # Same KeyError hazard as the grayscale percentage when the key is absent.
    thresholds = _config_section(config, "performance_thresholds", errors)
    max_latency_ms = thresholds.get("max_latency_ms")
    if max_latency_ms is not None:
        if not isinstance(max_latency_ms, (int, float)):
            errors.append(f"Invalid max_latency_ms: {max_latency_ms}")
        elif max_latency_ms < 100:
            warnings.append(f"max_latency_ms seems too low: {max_latency_ms}ms")

    return ValidationResponse(
        valid=not errors,
        errors=errors,
        warnings=warnings,
        config_summary={
            "active_strategy": config.get("active_strategy", "default"),
            "grayscale_enabled": grayscale.get("enabled", False),
            "runtime_mode": mode_router.get("runtime_mode", "direct"),
            "reranker_enabled": reranker.get("enabled", False),
            "performance_thresholds": thresholds,
        },
    )
@router.post(
    "/rollback",
    response_model=RollbackResponse,
    summary="回退策略",
    description="【AC-AISVC-RES-07】 回退到默认策略。",
)
async def rollback_strategy(
    trigger: str = "manual",
    reason: str = "",
) -> RollbackResponse:
    """
    [AC-AISVC-RES-07] Roll the retrieval strategy back to the default.

    If the default strategy is already active, nothing is changed and an
    unsuccessful response is returned. Otherwise a configuration identical to
    the current one except for ``active_strategy`` is pushed to both the
    strategy router and the rollback manager, and an audit entry is recorded.

    Args:
        trigger: Label of what triggered the rollback; it is only echoed into
            the audit details and the response.
        reason: Free-text reason; an empty value is reported as
            "Manual rollback".

    Returns:
        RollbackResponse: outcome of the rollback, including the stringified
        timestamp of the audit entry when one was returned.
    """
    live_router = get_strategy_router()
    active_config = live_router.get_config()
    outgoing = active_config.active_strategy
    outgoing_name = outgoing.value

    # Guard clause: nothing to roll back when already on the default strategy.
    if outgoing == StrategyType.DEFAULT:
        return RollbackResponse(
            success=False,
            previous_strategy=outgoing_name,
            current_strategy=outgoing_name,
            trigger=trigger,
            reason="Already on default strategy",
            audit_log_id=None,
        )

    default_name = StrategyType.DEFAULT.value
    effective_reason = reason or "Manual rollback"

    # Same configuration as before, with only the active strategy reset.
    reverted_config = RetrievalStrategyConfig(
        active_strategy=StrategyType.DEFAULT,
        grayscale=active_config.grayscale,
        mode_router=active_config.mode_router,
        reranker=active_config.reranker,
        pipeline=active_config.pipeline,
        metadata_inference=active_config.metadata_inference,
        performance_thresholds=active_config.performance_thresholds,
    )
    live_router.update_config(reverted_config)

    manager = get_rollback_manager()
    manager.update_config(reverted_config)
    audit_details = {
        "from_strategy": outgoing_name,
        "to_strategy": default_name,
        "trigger": trigger,
        "reason": effective_reason,
    }
    audit_entry = manager.record_audit(action="rollback", details=audit_details)

    logger.info(
        f"[AC-AISVC-RES-07] Strategy rolled back: {outgoing_name} -> default"
    )

    # The audit entry's timestamp doubles as its identifier in the response.
    audit_id = str(audit_entry.timestamp) if audit_entry else None
    return RollbackResponse(
        success=True,
        previous_strategy=outgoing_name,
        current_strategy=default_name,
        trigger=trigger,
        reason=effective_reason,
        audit_log_id=audit_id,
    )