350 lines
13 KiB
Python
350 lines
13 KiB
Python
|
|
"""
|
|||
|
|
Retrieval Strategy API Endpoints.
|
|||
|
|
[AC-AISVC-RES-01~15] 策略管理 API。
|
|||
|
|
|
|||
|
|
Endpoints:
|
|||
|
|
- GET /strategy/retrieval/current - 获取当前策略状态
|
|||
|
|
- POST /strategy/retrieval/switch - 切换策略
|
|||
|
|
- POST /strategy/retrieval/validate - 验证策略配置
|
|||
|
|
- POST /strategy/retrieval/rollback - 回退策略
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
import logging
|
|||
|
|
from typing import Any
|
|||
|
|
|
|||
|
|
from fastapi import APIRouter, HTTPException, status
|
|||
|
|
from pydantic import BaseModel, Field
|
|||
|
|
|
|||
|
|
from app.services.retrieval.strategy.config import (
|
|||
|
|
GrayscaleConfig,
|
|||
|
|
ModeRouterConfig,
|
|||
|
|
PipelineConfig,
|
|||
|
|
RetrievalStrategyConfig,
|
|||
|
|
RerankerConfig,
|
|||
|
|
RuntimeMode,
|
|||
|
|
StrategyType,
|
|||
|
|
)
|
|||
|
|
from app.services.retrieval.strategy.rollback_manager import (
|
|||
|
|
RollbackManager,
|
|||
|
|
RollbackResult,
|
|||
|
|
RollbackTrigger,
|
|||
|
|
get_rollback_manager,
|
|||
|
|
)
|
|||
|
|
from app.services.retrieval.strategy.strategy_router import (
|
|||
|
|
StrategyRouter,
|
|||
|
|
get_strategy_router,
|
|||
|
|
set_strategy_router,
|
|||
|
|
)
|
|||
|
|
|
|||
|
|
# Module-level logger named after this module.
logger: logging.Logger = logging.getLogger(name=__name__)

# Every endpoint below is mounted under /strategy/retrieval and grouped under
# the "strategy" tag in the generated OpenAPI documentation.
router = APIRouter(tags=["strategy"], prefix="/strategy/retrieval")
|
|||
|
|
|
|||
|
|
|
|||
|
|
class GrayscaleConfigSchema(BaseModel):
    """Grayscale-release settings as exchanged over the API."""

    # Whether grayscale routing is switched on.
    enabled: bool = Field(default=False)
    # Traffic share in percent; Pydantic enforces the inclusive range [0, 100].
    percentage: float = Field(0.0, ge=0.0, le=100.0)
    # Explicit allowlist entries (their meaning is defined by the strategy
    # router, which is not visible in this module — TODO confirm).
    allowlist: list[str] = Field(default_factory=list)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RerankerConfigSchema(BaseModel):
    """Reranker settings as exchanged over the API."""

    # Whether the reranking stage is switched on.
    enabled: bool = Field(default=False)
    # Reranker model identifier.
    model: str = Field(default="cross-encoder")
    # Number of hits kept after reranking (not range-checked here; the
    # /validate endpoint only warns about out-of-range values).
    top_k_after_rerank: int = Field(default=5)
    # Minimum reranker score threshold.
    min_score_threshold: float = Field(default=0.3)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ModeRouterConfigSchema(BaseModel):
    """Runtime-mode routing settings as exchanged over the API."""

    # Mode name as a plain string; the switch endpoint parses it into RuntimeMode.
    runtime_mode: str = Field(default="direct")
    # Confidence threshold used when deciding whether to trigger ReAct mode.
    react_trigger_confidence_threshold: float = Field(default=0.6)
    # Complexity score used when deciding whether to trigger ReAct mode.
    react_trigger_complexity_score: float = Field(default=0.5)
    # Upper bound on ReAct steps.
    react_max_steps: int = Field(default=5)
    # Whether direct mode falls back on low confidence.
    direct_fallback_on_low_confidence: bool = Field(default=True)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class PipelineConfigSchema(BaseModel):
    """Retrieval pipeline settings, reported by the /current endpoint.

    The switch endpoint in this module does not accept pipeline changes; it
    carries the current pipeline config over unchanged.
    """

    # Number of hits to retrieve.
    top_k: int = Field(default=5)
    # Minimum retrieval score for a hit.
    score_threshold: float = Field(default=0.01)
    # Minimum number of hits.
    min_hits: int = Field(default=1)
    # Whether two-stage retrieval is enabled.
    two_stage_enabled: bool = Field(default=True)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class StrategyStatusResponse(BaseModel):
    """Snapshot of the active retrieval strategy configuration."""

    # String value of the active StrategyType enum member.
    active_strategy: str = Field(...)
    grayscale: GrayscaleConfigSchema = Field(...)
    pipeline: PipelineConfigSchema = Field(...)
    reranker: RerankerConfigSchema = Field(...)
    mode_router: ModeRouterConfigSchema = Field(...)
    # Named numeric thresholds (for example ``max_latency_ms``).
    performance_thresholds: dict[str, float] = Field(default_factory=dict)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SwitchRequest(BaseModel):
    """Body of a strategy-switch request.

    Every field is optional; a field left as ``None`` keeps the corresponding
    current setting.
    """

    # Target strategy name; must match a StrategyType value.
    active_strategy: str | None = Field(default=None)
    grayscale: GrayscaleConfigSchema | None = Field(default=None)
    mode_router: ModeRouterConfigSchema | None = Field(default=None)
    reranker: RerankerConfigSchema | None = Field(default=None)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class SwitchResponse(BaseModel):
    """Result of a strategy switch."""

    success: bool = Field(...)
    # Strategy values before and after the switch.
    previous_strategy: str = Field(...)
    current_strategy: str = Field(...)
    # Human-readable summary of the switch.
    message: str = Field(...)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ValidationRequest(BaseModel):
    """Body of a strategy-validation request."""

    # Free-form candidate configuration; keys such as ``active_strategy``,
    # ``grayscale``, ``mode_router``, ``reranker`` and ``performance_thresholds``
    # are inspected by the validate endpoint. Defaults to a fresh empty dict.
    config: dict[str, Any] = Field(default_factory=lambda: {})
|
|||
|
|
|
|||
|
|
|
|||
|
|
class ValidationResponse(BaseModel):
    """Outcome of validating a candidate strategy configuration."""

    # True when no errors were found (warnings do not affect validity).
    valid: bool = Field(...)
    errors: list[str] = Field(default_factory=list)
    warnings: list[str] = Field(default_factory=list)
    # Condensed view of the key settings from the submitted config.
    config_summary: dict[str, Any] = Field(default_factory=dict)
|
|||
|
|
|
|||
|
|
|
|||
|
|
class RollbackResponse(BaseModel):
    """Result of a strategy rollback."""

    success: bool = Field(...)
    # Strategy values before and after the rollback.
    previous_strategy: str = Field(...)
    current_strategy: str = Field(...)
    # Echo of the caller-supplied trigger label.
    trigger: str = Field(...)
    reason: str = Field(...)
    # String form of the audit record's timestamp, or None when no record exists.
    audit_log_id: str | None = Field(default=None)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.get(
    "/current",
    response_model=StrategyStatusResponse,
    summary="获取当前策略状态",
    description="【AC-AISVC-RES-01】 获取当前活跃的检索策略配置状态。",
)
async def get_current_strategy() -> StrategyStatusResponse:
    """
    [AC-AISVC-RES-01] Report the currently active retrieval strategy.

    Reads the config held by the strategy router and converts each section
    into its API schema. Enum-valued fields are emitted as their ``.value``.
    """
    config = get_strategy_router().get_config()
    gray = config.grayscale
    pipe = config.pipeline
    rerank = config.reranker
    modes = config.mode_router

    grayscale_view = GrayscaleConfigSchema(
        enabled=gray.enabled,
        percentage=gray.percentage,
        allowlist=gray.allowlist,
    )
    pipeline_view = PipelineConfigSchema(
        top_k=pipe.top_k,
        score_threshold=pipe.score_threshold,
        min_hits=pipe.min_hits,
        two_stage_enabled=pipe.two_stage_enabled,
    )
    reranker_view = RerankerConfigSchema(
        enabled=rerank.enabled,
        model=rerank.model,
        top_k_after_rerank=rerank.top_k_after_rerank,
        min_score_threshold=rerank.min_score_threshold,
    )
    mode_router_view = ModeRouterConfigSchema(
        runtime_mode=modes.runtime_mode.value,
        react_trigger_confidence_threshold=modes.react_trigger_confidence_threshold,
        react_trigger_complexity_score=modes.react_trigger_complexity_score,
        react_max_steps=modes.react_max_steps,
        direct_fallback_on_low_confidence=modes.direct_fallback_on_low_confidence,
    )

    return StrategyStatusResponse(
        active_strategy=config.active_strategy.value,
        grayscale=grayscale_view,
        pipeline=pipeline_view,
        reranker=reranker_view,
        mode_router=mode_router_view,
        performance_thresholds=config.performance_thresholds,
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post(
    "/switch",
    response_model=SwitchResponse,
    summary="切换策略",
    description="【AC-AISVC-RES-02, AC-AISVC-RES-03】 切换检索策略, 支持灰度发布配置。",
)
async def switch_strategy(request: SwitchRequest) -> SwitchResponse:
    """
    [AC-AISVC-RES-02, AC-AISVC-RES-03] Switch the active retrieval strategy.

    Supports grayscale-release settings (percentage / allowlist).

    Merge rules, section by section:

    * ``active_strategy``: omitted or empty keeps the current strategy.
    * ``grayscale`` / ``mode_router`` / ``reranker``: an omitted section keeps
      the current config object. A supplied section replaces it wholesale, so
      any field left out of the supplied section takes the schema default, not
      the current value.
    * ``mode_router.runtime_mode`` set to an empty string keeps the current
      mode. Omitting it yields the schema default ``"direct"``.
    * ``pipeline``, ``metadata_inference`` and ``performance_thresholds`` are
      always carried over unchanged.

    Raises:
        HTTPException: 400 if ``active_strategy`` or ``mode_router.runtime_mode``
            is not a value of the corresponding enum.
    """
    strategy_router = get_strategy_router()
    current_config = strategy_router.get_config()
    previous_strategy = current_config.active_strategy.value

    new_active_strategy = current_config.active_strategy
    if request.active_strategy:
        try:
            new_active_strategy = StrategyType(request.active_strategy)
        except ValueError as e:
            # Derive the list from the enum so the message never drifts from it.
            valid_values = ", ".join(member.value for member in StrategyType)
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail=f"Invalid strategy type: {request.active_strategy}. Valid values are: {valid_values}",
            ) from e

    grayscale_config = current_config.grayscale
    if request.grayscale is not None:
        grayscale_config = GrayscaleConfig(
            enabled=request.grayscale.enabled,
            percentage=request.grayscale.percentage,
            allowlist=request.grayscale.allowlist,
        )

    mode_router_config = current_config.mode_router
    if request.mode_router is not None:
        runtime_mode = current_config.mode_router.runtime_mode
        if request.mode_router.runtime_mode:
            try:
                runtime_mode = RuntimeMode(request.mode_router.runtime_mode)
            except ValueError as e:
                # Previously this escaped as an unhandled ValueError (HTTP 500).
                valid_modes = ", ".join(member.value for member in RuntimeMode)
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Invalid runtime mode: {request.mode_router.runtime_mode}. Valid values are: {valid_modes}",
                ) from e
        mode_router_config = ModeRouterConfig(
            runtime_mode=runtime_mode,
            react_trigger_confidence_threshold=request.mode_router.react_trigger_confidence_threshold,
            react_trigger_complexity_score=request.mode_router.react_trigger_complexity_score,
            react_max_steps=request.mode_router.react_max_steps,
            direct_fallback_on_low_confidence=request.mode_router.direct_fallback_on_low_confidence,
        )

    reranker_config = current_config.reranker
    if request.reranker is not None:
        reranker_config = RerankerConfig(
            enabled=request.reranker.enabled,
            model=request.reranker.model,
            top_k_after_rerank=request.reranker.top_k_after_rerank,
            min_score_threshold=request.reranker.min_score_threshold,
        )

    new_config = RetrievalStrategyConfig(
        active_strategy=new_active_strategy,
        grayscale=grayscale_config,
        mode_router=mode_router_config,
        reranker=reranker_config,
        pipeline=current_config.pipeline,
        metadata_inference=current_config.metadata_inference,
        performance_thresholds=current_config.performance_thresholds,
    )

    strategy_router.update_config(new_config)
    # NOTE(review): rollback_strategy also pushes its new config to
    # get_rollback_manager().update_config(...) and records an audit entry;
    # this endpoint does neither — confirm whether that is intended.

    current_strategy = new_config.active_strategy.value
    logger.info(
        f"[AC-AISVC-RES-02] Strategy switched: {previous_strategy} -> {current_strategy}"
    )

    return SwitchResponse(
        success=True,
        previous_strategy=previous_strategy,
        current_strategy=current_strategy,
        message=f"Strategy switched from {previous_strategy} to {current_strategy}",
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _is_number(value: Any) -> bool:
    """Return True for real ``int``/``float`` values; ``bool`` is rejected."""
    return isinstance(value, (int, float)) and not isinstance(value, bool)


def _get_section(config: dict[str, Any], key: str, errors: list[str]) -> dict[str, Any]:
    """Return ``config[key]`` as a dict so that later checks cannot crash.

    A missing key or an explicit ``None`` means "not provided" and yields an
    empty dict. Any other non-dict value is recorded in ``errors`` and also
    yields an empty dict.
    """
    value = config.get(key)
    if value is None:
        return {}
    if not isinstance(value, dict):
        errors.append(f"Invalid {key}: expected an object, got {type(value).__name__}")
        return {}
    return value


@router.post(
    "/validate",
    response_model=ValidationResponse,
    summary="验证策略配置",
    description="【AC-AISVC-RES-06, AC-AISVC-RES-08】 验证策略配置的完整性与一致性。",
)
async def validate_strategy(request: ValidationRequest) -> ValidationResponse:
    """
    [AC-AISVC-RES-06, AC-AISVC-RES-08] Validate a candidate strategy config.

    Checks are performed without applying anything:

    * ``active_strategy`` must be a ``StrategyType`` value.
    * ``grayscale.percentage``, when given, must be a number in [0, 100].
    * ``mode_router.runtime_mode``, when given, must be a ``RuntimeMode`` value.
    * ``reranker.top_k_after_rerank`` (only when the reranker is enabled) must be
      numeric; values outside [1, 20] produce a warning.
    * ``performance_thresholds.max_latency_ms``, when given, must be numeric;
      values below 100 produce a warning.

    Malformed input (non-object sections, non-numeric values) is reported as an
    error instead of failing the request with a 500.

    Returns:
        ValidationResponse with ``valid`` true iff no errors were collected.
    """
    errors: list[str] = []
    warnings: list[str] = []
    config = request.config

    # Accept exactly the values switch_strategy can parse.
    valid_strategies = [member.value for member in StrategyType]
    if "active_strategy" in config and config["active_strategy"] not in valid_strategies:
        errors.append(f"Invalid active_strategy: {config['active_strategy']}")

    grayscale = _get_section(config, "grayscale", errors)
    percentage = grayscale.get("percentage")
    if percentage is not None and not (_is_number(percentage) and 0 <= percentage <= 100):
        errors.append(f"Invalid grayscale percentage: {percentage}")

    mode_router = _get_section(config, "mode_router", errors)
    valid_modes = [member.value for member in RuntimeMode]
    if "runtime_mode" in mode_router and mode_router["runtime_mode"] not in valid_modes:
        errors.append(f"Invalid runtime_mode: {mode_router['runtime_mode']}")

    reranker = _get_section(config, "reranker", errors)
    top_k = reranker.get("top_k_after_rerank")
    if reranker.get("enabled") and top_k is not None:
        if not _is_number(top_k):
            errors.append(f"Invalid top_k_after_rerank: {top_k}")
        elif not 1 <= top_k <= 20:
            # The message promises a lower bound too, so check both ends.
            warnings.append(f"top_k_after_rerank should be between 1 and 20, current: {top_k}")

    thresholds = _get_section(config, "performance_thresholds", errors)
    max_latency = thresholds.get("max_latency_ms")
    if max_latency is not None:
        if not _is_number(max_latency):
            errors.append(f"Invalid max_latency_ms: {max_latency}")
        elif max_latency < 100:
            warnings.append(f"max_latency_ms seems too low: {max_latency}ms")

    return ValidationResponse(
        valid=not errors,
        errors=errors,
        warnings=warnings,
        config_summary={
            "active_strategy": config.get("active_strategy", "default"),
            "grayscale_enabled": grayscale.get("enabled", False),
            "runtime_mode": mode_router.get("runtime_mode", "direct"),
            "reranker_enabled": reranker.get("enabled", False),
            "performance_thresholds": thresholds,
        },
    )
|
|||
|
|
|
|||
|
|
|
|||
|
|
@router.post(
    "/rollback",
    response_model=RollbackResponse,
    summary="回退策略",
    description="【AC-AISVC-RES-07】 回退到默认策略。",
)
async def rollback_strategy(
    trigger: str = "manual",
    reason: str = "",
) -> RollbackResponse:
    """
    [AC-AISVC-RES-07] Roll the active retrieval strategy back to ``DEFAULT``.

    Intended for both manual and automatic triggers (performance degradation,
    errors). FastAPI binds ``trigger`` and ``reason`` as query parameters.

    Args:
        trigger: Free-form label of what initiated the rollback. It is echoed
            and audited verbatim; it is not validated here.
        reason: Human-readable reason. An empty string becomes "Manual rollback".

    Returns:
        ``success=False`` and no state change when already on ``DEFAULT``.
        Otherwise ``success=True``, after the strategy router and the rollback
        manager both receive the new config and an audit record is written.
    """
    active_router = get_strategy_router()
    config_before = active_router.get_config()
    from_strategy = config_before.active_strategy

    # Guard clause: nothing to undo, so report failure and touch no state.
    if from_strategy == StrategyType.DEFAULT:
        unchanged = from_strategy.value
        return RollbackResponse(
            success=False,
            previous_strategy=unchanged,
            current_strategy=unchanged,
            trigger=trigger,
            reason="Already on default strategy",
            audit_log_id=None,
        )

    effective_reason = reason or "Manual rollback"

    # Same config as before, except the active strategy is forced to DEFAULT.
    default_config = RetrievalStrategyConfig(
        active_strategy=StrategyType.DEFAULT,
        grayscale=config_before.grayscale,
        mode_router=config_before.mode_router,
        reranker=config_before.reranker,
        pipeline=config_before.pipeline,
        metadata_inference=config_before.metadata_inference,
        performance_thresholds=config_before.performance_thresholds,
    )
    active_router.update_config(default_config)

    manager = get_rollback_manager()
    manager.update_config(default_config)
    audit_entry = manager.record_audit(
        action="rollback",
        details={
            "from_strategy": from_strategy.value,
            "to_strategy": StrategyType.DEFAULT.value,
            "trigger": trigger,
            "reason": effective_reason,
        },
    )

    logger.info(
        f"[AC-AISVC-RES-07] Strategy rolled back: {from_strategy.value} -> default"
    )

    audit_id = None
    if audit_entry:
        audit_id = str(audit_entry.timestamp)

    return RollbackResponse(
        success=True,
        previous_strategy=from_strategy.value,
        current_strategy=StrategyType.DEFAULT.value,
        trigger=trigger,
        reason=effective_reason,
        audit_log_id=audit_id,
    )
|