350 lines
13 KiB
Python
350 lines
13 KiB
Python
"""
|
||
Retrieval Strategy API Endpoints.
|
||
[AC-AISVC-RES-01~15] 策略管理 API。
|
||
|
||
Endpoints:
|
||
- GET /strategy/retrieval/current - 获取当前策略状态
|
||
- POST /strategy/retrieval/switch - 切换策略
|
||
- POST /strategy/retrieval/validate - 验证策略配置
|
||
- POST /strategy/retrieval/rollback - 回退策略
|
||
"""
|
||
|
||
import logging
|
||
from typing import Any
|
||
|
||
from fastapi import APIRouter, HTTPException, status
|
||
from pydantic import BaseModel, Field
|
||
|
||
from app.services.retrieval.strategy.config import (
|
||
GrayscaleConfig,
|
||
ModeRouterConfig,
|
||
PipelineConfig,
|
||
RetrievalStrategyConfig,
|
||
RerankerConfig,
|
||
RuntimeMode,
|
||
StrategyType,
|
||
)
|
||
from app.services.retrieval.strategy.rollback_manager import (
|
||
RollbackManager,
|
||
RollbackResult,
|
||
RollbackTrigger,
|
||
get_rollback_manager,
|
||
)
|
||
from app.services.retrieval.strategy.strategy_router import (
|
||
StrategyRouter,
|
||
get_strategy_router,
|
||
set_strategy_router,
|
||
)
|
||
|
||
# Logger named after this module.
logger = logging.getLogger(__name__)

# Router for the strategy endpoints: every route path below is prefixed with
# "/strategy/retrieval" and grouped under the "strategy" OpenAPI tag.
router = APIRouter(tags=["strategy"], prefix="/strategy/retrieval")
|
||
|
||
|
||
class GrayscaleConfigSchema(BaseModel):
    """灰度配置 Schema。"""

    # API schema for the grayscale (staged rollout) settings. The docstring above
    # ("Grayscale config schema") is kept verbatim because pydantic publishes a
    # model's docstring as its JSON-schema / OpenAPI description.

    # Grayscale toggle; off unless the client turns it on.
    enabled: bool = Field(default=False)
    # Rollout percentage; pydantic rejects values outside [0, 100].
    percentage: float = Field(0.0, ge=0.0, le=100.0)
    # Allowlisted identifiers; each instance gets its own fresh empty list by default.
    allowlist: list[str] = Field(default_factory=list)
||
|
||
|
||
class RerankerConfigSchema(BaseModel):
    """重排器配置 Schema。"""

    # API schema for the reranker settings ("Reranker config schema"). The
    # docstring is left verbatim since pydantic emits it as the OpenAPI
    # description. No range constraints are enforced on these fields.

    enabled: bool = Field(default=False)
    model: str = Field(default="cross-encoder")
    top_k_after_rerank: int = Field(default=5)
    min_score_threshold: float = Field(default=0.3)
|
||
|
||
|
||
class ModeRouterConfigSchema(BaseModel):
    """模式路由配置 Schema。"""

    # API schema for the mode-router settings ("Mode router config schema").
    # The docstring is kept verbatim because pydantic emits it as the OpenAPI
    # description.

    # Plain string on the wire; the schema itself does not restrict its values.
    runtime_mode: str = Field(default="direct")
    react_trigger_confidence_threshold: float = Field(default=0.6)
    react_trigger_complexity_score: float = Field(default=0.5)
    react_max_steps: int = Field(default=5)
    direct_fallback_on_low_confidence: bool = Field(default=True)
|
||
|
||
|
||
class PipelineConfigSchema(BaseModel):
    """Pipeline 配置 Schema。"""

    # API schema for the retrieval pipeline settings ("Pipeline config schema").
    # The docstring is kept verbatim because pydantic emits it as the OpenAPI
    # description. SwitchRequest has no pipeline field, so over this API the
    # schema is only used to report the current pipeline settings.

    top_k: int = Field(default=5)
    score_threshold: float = Field(default=0.01)
    min_hits: int = Field(default=1)
    two_stage_enabled: bool = Field(default=True)
|
||
|
||
|
||
class StrategyStatusResponse(BaseModel):
    """策略状态响应。"""

    # Response body of GET /current ("Strategy status response"). The docstring
    # is kept verbatim because pydantic emits it as the OpenAPI description.

    # Enum value of the active strategy, as a string.
    active_strategy: str
    grayscale: GrayscaleConfigSchema
    pipeline: PipelineConfigSchema
    reranker: RerankerConfigSchema
    mode_router: ModeRouterConfigSchema
    # Named numeric thresholds; an independent empty dict when not supplied.
    performance_thresholds: dict[str, float] = Field(default_factory=dict)
|
||
|
||
|
||
class SwitchRequest(BaseModel):
    """策略切换请求。"""

    # Request body of POST /switch ("Strategy switch request"). The docstring
    # is kept verbatim because pydantic emits it as the OpenAPI description.
    # Every section is optional; leaving one as None keeps the corresponding
    # part of the currently active configuration.

    active_strategy: str | None = Field(default=None)
    grayscale: GrayscaleConfigSchema | None = Field(default=None)
    mode_router: ModeRouterConfigSchema | None = Field(default=None)
    reranker: RerankerConfigSchema | None = Field(default=None)
|
||
|
||
|
||
class SwitchResponse(BaseModel):
    """策略切换响应。"""

    # Response body of POST /switch ("Strategy switch response"). The docstring
    # is kept verbatim because pydantic emits it as the OpenAPI description.
    # All fields are required.

    success: bool
    # Strategy enum values before and after the switch.
    previous_strategy: str
    current_strategy: str
    # Human-readable summary of the switch.
    message: str
|
||
|
||
|
||
class ValidationRequest(BaseModel):
    """策略验证请求。"""

    # Request body of POST /validate ("Strategy validation request"). The
    # docstring is kept verbatim because pydantic emits it as the OpenAPI
    # description. The config is an untyped dict that the validation endpoint
    # inspects by hand; it defaults to a fresh empty dict per request.
    config: dict[str, Any] = Field(default_factory=dict)
|
||
|
||
|
||
class ValidationResponse(BaseModel):
    """策略验证响应。"""

    # Response body of POST /validate ("Strategy validation response"). The
    # docstring is kept verbatim because pydantic emits it as the OpenAPI
    # description.

    # True when no errors were found (warnings do not affect validity).
    valid: bool
    # Each list/dict default is a fresh, per-instance container.
    errors: list[str] = Field(default_factory=list)
    warnings: list[str] = Field(default_factory=list)
    config_summary: dict[str, Any] = Field(default_factory=dict)
|
||
|
||
|
||
class RollbackResponse(BaseModel):
    """策略回退响应。"""

    # Response body of POST /rollback ("Strategy rollback response"). The
    # docstring is kept verbatim because pydantic emits it as the OpenAPI
    # description.

    success: bool
    # Strategy enum values before and after the rollback.
    previous_strategy: str
    current_strategy: str
    # Caller-supplied trigger label and the effective reason text.
    trigger: str
    reason: str
    # Identifier of the recorded audit entry; None when nothing was recorded.
    audit_log_id: str | None = Field(default=None)
|
||
|
||
|
||
@router.get(
    "/current",
    response_model=StrategyStatusResponse,
    summary="获取当前策略状态",
    description="【AC-AISVC-RES-01】 获取当前活跃的检索策略配置状态。",
)
async def get_current_strategy() -> StrategyStatusResponse:
    """
    [AC-AISVC-RES-01] Return the current retrieval strategy status.

    Reads the configuration held by the global strategy router and maps each
    section onto its API schema. Enum-valued fields (the active strategy and
    the mode router's runtime mode) are exposed as their ``.value``.
    """
    cfg = get_strategy_router().get_config()

    active_value = cfg.active_strategy.value

    gray = cfg.grayscale
    grayscale_view = GrayscaleConfigSchema(
        enabled=gray.enabled,
        percentage=gray.percentage,
        allowlist=gray.allowlist,
    )

    pipe = cfg.pipeline
    pipeline_view = PipelineConfigSchema(
        top_k=pipe.top_k,
        score_threshold=pipe.score_threshold,
        min_hits=pipe.min_hits,
        two_stage_enabled=pipe.two_stage_enabled,
    )

    rerank = cfg.reranker
    reranker_view = RerankerConfigSchema(
        enabled=rerank.enabled,
        model=rerank.model,
        top_k_after_rerank=rerank.top_k_after_rerank,
        min_score_threshold=rerank.min_score_threshold,
    )

    mode = cfg.mode_router
    mode_router_view = ModeRouterConfigSchema(
        runtime_mode=mode.runtime_mode.value,
        react_trigger_confidence_threshold=mode.react_trigger_confidence_threshold,
        react_trigger_complexity_score=mode.react_trigger_complexity_score,
        react_max_steps=mode.react_max_steps,
        direct_fallback_on_low_confidence=mode.direct_fallback_on_low_confidence,
    )

    return StrategyStatusResponse(
        active_strategy=active_value,
        grayscale=grayscale_view,
        pipeline=pipeline_view,
        reranker=reranker_view,
        mode_router=mode_router_view,
        performance_thresholds=cfg.performance_thresholds,
    )
|
||
|
||
|
||
@router.post(
    "/switch",
    response_model=SwitchResponse,
    summary="切换策略",
    description="【AC-AISVC-RES-02, AC-AISVC-RES-03】 切换检索策略, 支持灰度发布配置。",
)
async def switch_strategy(request: SwitchRequest) -> SwitchResponse:
    """
    [AC-AISVC-RES-02, AC-AISVC-RES-03] Switch the active retrieval strategy.

    Supports grayscale rollout configuration (percentage / allowlist).

    Section handling:
        - ``grayscale``, ``mode_router`` and ``reranker`` each replace the
          matching section of the current config when present.
        - Omitted sections are carried over unchanged.
        - A supplied section is applied wholesale, so any field the client
          left out takes the schema default, not the currently active value.
        - ``pipeline``, ``metadata_inference`` and ``performance_thresholds``
          are always carried over from the current config.

    The new config is pushed to the strategy router and to the rollback
    manager, which keeps the manager in step with the router the same way
    ``rollback_strategy`` does.

    Raises:
        HTTPException: 400 if ``active_strategy`` or ``mode_router.runtime_mode``
            is not a valid enum value.
    """
    strategy_router = get_strategy_router()
    current_config = strategy_router.get_config()
    previous_strategy = current_config.active_strategy.value

    # Convert enum-valued inputs explicitly so that bad client input yields a
    # 400 rather than an unhandled ValueError (HTTP 500) mid-construction.
    new_active_strategy = current_config.active_strategy
    if request.active_strategy:
        try:
            new_active_strategy = StrategyType(request.active_strategy)
        except ValueError as e:
            valid_values = ", ".join(member.value for member in StrategyType)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid strategy type: {request.active_strategy}. Valid values are: {valid_values}",
            ) from e

    grayscale = current_config.grayscale
    if request.grayscale is not None:
        grayscale = GrayscaleConfig(
            enabled=request.grayscale.enabled,
            percentage=request.grayscale.percentage,
            allowlist=request.grayscale.allowlist,
        )

    mode_router = current_config.mode_router
    if request.mode_router is not None:
        requested_mode = request.mode_router
        runtime_mode = current_config.mode_router.runtime_mode
        # An empty runtime_mode string keeps the current runtime mode.
        if requested_mode.runtime_mode:
            try:
                runtime_mode = RuntimeMode(requested_mode.runtime_mode)
            except ValueError as e:
                valid_modes = ", ".join(member.value for member in RuntimeMode)
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Invalid runtime mode: {requested_mode.runtime_mode}. Valid values are: {valid_modes}",
                ) from e
        mode_router = ModeRouterConfig(
            runtime_mode=runtime_mode,
            react_trigger_confidence_threshold=requested_mode.react_trigger_confidence_threshold,
            react_trigger_complexity_score=requested_mode.react_trigger_complexity_score,
            react_max_steps=requested_mode.react_max_steps,
            direct_fallback_on_low_confidence=requested_mode.direct_fallback_on_low_confidence,
        )

    reranker = current_config.reranker
    if request.reranker is not None:
        reranker = RerankerConfig(
            enabled=request.reranker.enabled,
            model=request.reranker.model,
            top_k_after_rerank=request.reranker.top_k_after_rerank,
            min_score_threshold=request.reranker.min_score_threshold,
        )

    new_config = RetrievalStrategyConfig(
        active_strategy=new_active_strategy,
        grayscale=grayscale,
        mode_router=mode_router,
        reranker=reranker,
        pipeline=current_config.pipeline,
        metadata_inference=current_config.metadata_inference,
        performance_thresholds=current_config.performance_thresholds,
    )

    strategy_router.update_config(new_config)
    # Keep the rollback manager's view in sync with the router. Without this,
    # the manager would keep the pre-switch config until the next rollback.
    get_rollback_manager().update_config(new_config)

    current_strategy = new_config.active_strategy.value
    logger.info(
        "[AC-AISVC-RES-02] Strategy switched: %s -> %s",
        previous_strategy,
        current_strategy,
    )

    return SwitchResponse(
        success=True,
        previous_strategy=previous_strategy,
        current_strategy=current_strategy,
        message=f"Strategy switched from {previous_strategy} to {current_strategy}",
    )
|
||
|
||
|
||
def _is_number(value: object) -> bool:
    """Return True for real numeric values (int or float), excluding bool."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _section(config: dict[str, Any], key: str) -> dict[str, Any]:
    """Return ``config[key]`` when it is a dict, otherwise an empty dict."""
    value = config.get(key)
    return value if isinstance(value, dict) else {}


@router.post(
    "/validate",
    response_model=ValidationResponse,
    summary="验证策略配置",
    description="【AC-AISVC-RES-06, AC-AISVC-RES-08】 验证策略配置的完整性与一致性。",
)
async def validate_strategy(request: ValidationRequest) -> ValidationResponse:
    """
    [AC-AISVC-RES-06, AC-AISVC-RES-08] Validate a strategy configuration.

    Inspects the free-form ``request.config`` dict without applying it. The
    checks below never raise on malformed input; every problem is reported in
    the response instead.

    Errors (they make ``valid`` False):
        - ``active_strategy`` is not a ``StrategyType`` value.
        - ``grayscale``, ``mode_router``, ``reranker`` or
          ``performance_thresholds`` is present but not an object.
        - ``grayscale.percentage`` is not a number in [0, 100].
        - ``mode_router.runtime_mode`` is not a ``RuntimeMode`` value.
        - A present ``reranker.top_k_after_rerank`` (with the reranker
          enabled) or ``performance_thresholds.max_latency_ms`` is not a number.

    Warnings:
        - ``reranker.top_k_after_rerank`` outside 1..20 while the reranker
          is enabled.
        - ``performance_thresholds.max_latency_ms`` below 100.

    Returns:
        A ``ValidationResponse`` with errors, warnings and a short summary of
        the submitted config.
    """
    errors: list[str] = []
    warnings: list[str] = []
    config = request.config

    # Derive the accepted values from the enums that switch_strategy converts
    # into, so the two endpoints cannot disagree.
    valid_strategies = [member.value for member in StrategyType]
    valid_runtime_modes = [member.value for member in RuntimeMode]

    if "active_strategy" in config and config["active_strategy"] not in valid_strategies:
        errors.append(f"Invalid active_strategy: {config['active_strategy']}")

    # Nested sections must be JSON objects; anything else would make the
    # field lookups below crash with AttributeError (HTTP 500).
    for key in ("grayscale", "mode_router", "reranker", "performance_thresholds"):
        if key in config and not isinstance(config[key], dict):
            errors.append(
                f"Invalid {key}: expected an object, got {type(config[key]).__name__}"
            )

    grayscale = _section(config, "grayscale")
    percentage = grayscale.get("percentage")
    if percentage is not None and not (_is_number(percentage) and 0 <= percentage <= 100):
        errors.append(f"Invalid grayscale percentage: {percentage}")

    mode_router = _section(config, "mode_router")
    if "runtime_mode" in mode_router and mode_router["runtime_mode"] not in valid_runtime_modes:
        errors.append(f"Invalid runtime_mode: {mode_router['runtime_mode']}")

    reranker = _section(config, "reranker")
    top_k_after_rerank = reranker.get("top_k_after_rerank")
    if reranker.get("enabled") and top_k_after_rerank is not None:
        if not _is_number(top_k_after_rerank):
            errors.append(f"Invalid top_k_after_rerank: {top_k_after_rerank}")
        elif not 1 <= top_k_after_rerank <= 20:
            warnings.append(
                f"top_k_after_rerank should be between 1 and 20, current: {top_k_after_rerank}"
            )

    thresholds = _section(config, "performance_thresholds")
    max_latency_ms = thresholds.get("max_latency_ms")
    if max_latency_ms is not None:
        if not _is_number(max_latency_ms):
            errors.append(f"Invalid max_latency_ms: {max_latency_ms}")
        elif max_latency_ms < 100:
            warnings.append(f"max_latency_ms seems too low: {max_latency_ms}ms")

    return ValidationResponse(
        valid=not errors,
        errors=errors,
        warnings=warnings,
        config_summary={
            "active_strategy": config.get("active_strategy", "default"),
            "grayscale_enabled": grayscale.get("enabled", False),
            "runtime_mode": mode_router.get("runtime_mode", "direct"),
            "reranker_enabled": reranker.get("enabled", False),
            "performance_thresholds": config.get("performance_thresholds", {}),
        },
    )
|
||
|
||
|
||
@router.post(
    "/rollback",
    response_model=RollbackResponse,
    summary="回退策略",
    description="【AC-AISVC-RES-07】 回退到默认策略。",
)
async def rollback_strategy(
    trigger: str = "manual",
    reason: str = "",
) -> RollbackResponse:
    """
    [AC-AISVC-RES-07] Roll the active retrieval strategy back to DEFAULT.

    Supports both manual and automatic triggers (performance degradation,
    errors); the caller labels the trigger via ``trigger``.

    If the router is already on the default strategy, nothing changes and a
    response with ``success=False`` is returned. Otherwise:
        - The current config is copied with only ``active_strategy`` set to
          DEFAULT.
        - The copy is pushed to the strategy router and the rollback manager.
        - An audit entry is recorded.
        - The audit entry's timestamp is returned as ``audit_log_id``.

    An empty ``reason`` is reported as "Manual rollback".
    """
    active_router = get_strategy_router()
    config_now = active_router.get_config()
    from_strategy = config_now.active_strategy

    if from_strategy == StrategyType.DEFAULT:
        # Nothing to roll back: no state change and no audit entry.
        return RollbackResponse(
            success=False,
            previous_strategy=from_strategy.value,
            current_strategy=from_strategy.value,
            trigger=trigger,
            reason="Already on default strategy",
            audit_log_id=None,
        )

    effective_reason = reason or "Manual rollback"

    # Same config as before, except the strategy is forced back to DEFAULT.
    default_config = RetrievalStrategyConfig(
        active_strategy=StrategyType.DEFAULT,
        grayscale=config_now.grayscale,
        mode_router=config_now.mode_router,
        reranker=config_now.reranker,
        pipeline=config_now.pipeline,
        metadata_inference=config_now.metadata_inference,
        performance_thresholds=config_now.performance_thresholds,
    )
    active_router.update_config(default_config)

    manager = get_rollback_manager()
    manager.update_config(default_config)
    audit_entry = manager.record_audit(
        action="rollback",
        details={
            "from_strategy": from_strategy.value,
            "to_strategy": StrategyType.DEFAULT.value,
            "trigger": trigger,
            "reason": effective_reason,
        },
    )

    logger.info(
        f"[AC-AISVC-RES-07] Strategy rolled back: {from_strategy.value} -> default"
    )

    audit_log_id = str(audit_entry.timestamp) if audit_entry else None
    return RollbackResponse(
        success=True,
        previous_strategy=from_strategy.value,
        current_strategy=StrategyType.DEFAULT.value,
        trigger=trigger,
        reason=effective_reason,
        audit_log_id=audit_log_id,
    )
|