# ai-robot-core/ai-service/app/api/admin/strategy.py
#
# 343 lines
# 14 KiB
# Python

"""
Retrieval Strategy API Endpoints.
[AC-AISVC-RES-01~15] API for strategy management and configuration.
Endpoints:
- GET /strategy/retrieval/current - Get current strategy configuration
- POST /strategy/retrieval/switch - Switch strategy configuration
- POST /strategy/retrieval/validate - Validate strategy configuration
- POST /strategy/retrieval/rollback - Rollback to default strategy
"""
import logging
from typing import Any
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.tenant import get_tenant_id
from app.services.retrieval.routing_config import (
RagRuntimeMode,
StrategyType,
RoutingConfig,
)
from app.services.retrieval.strategy_router import (
get_strategy_router,
RollbackRecord,
)
from app.services.retrieval.mode_router import get_mode_router
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Router that groups every endpoint below under the /strategy/retrieval prefix.
router = APIRouter(
    prefix="/strategy/retrieval",
    tags=["Retrieval Strategy"],
)
class RoutingConfigRequest(BaseModel):
    """Request model for routing configuration."""

    # Every field is optional and defaults to None, so a caller may send only
    # the settings it wants to set; the endpoints decide what replaces None.

    # Routing switch and strategy choice.
    enabled: bool | None = Field(None, description="Enable strategy routing")
    strategy: StrategyType | None = Field(None, description="Retrieval strategy")

    # Grayscale settings; the percentage is constrained to [0.0, 1.0].
    grayscale_percentage: float | None = Field(
        None,
        ge=0.0,
        le=1.0,
        description="Grayscale percentage",
    )
    grayscale_allowlist: list[str] | None = Field(
        None,
        description="Grayscale allowlist",
    )

    rag_runtime_mode: RagRuntimeMode | None = Field(
        None,
        description="RAG runtime mode",
    )

    # ReAct settings: both trigger values are bounded to [0.0, 1.0] and the
    # step count to 3..10 inclusive.
    react_trigger_confidence_threshold: float | None = Field(None, ge=0.0, le=1.0)
    react_trigger_complexity_score: float | None = Field(None, ge=0.0, le=1.0)
    react_max_steps: int | None = Field(None, ge=3, le=10)

    # Direct-fallback settings; the threshold is bounded to [0.0, 1.0].
    direct_fallback_on_low_confidence: bool | None = Field(None)
    direct_fallback_confidence_threshold: float | None = Field(None, ge=0.0, le=1.0)

    # Performance settings: the budget must be at least 1000 and the
    # degradation threshold is bounded to [0.0, 1.0].
    performance_budget_ms: int | None = Field(None, ge=1000)
    performance_degradation_threshold: float | None = Field(None, ge=0.0, le=1.0)
class RoutingConfigResponse(BaseModel):
    """Response model for routing configuration."""

    # Mirrors the field names of RoutingConfigRequest, but every field is
    # required here: a response always carries a complete configuration.

    # Routing switch and strategy choice.
    enabled: bool
    strategy: StrategyType

    # Grayscale settings.
    grayscale_percentage: float
    grayscale_allowlist: list[str]

    rag_runtime_mode: RagRuntimeMode

    # ReAct settings.
    react_trigger_confidence_threshold: float
    react_trigger_complexity_score: float
    react_max_steps: int

    # Direct-fallback settings.
    direct_fallback_on_low_confidence: bool
    direct_fallback_confidence_threshold: float

    # Performance settings.
    performance_budget_ms: int
    performance_degradation_threshold: float
class ValidationResponse(BaseModel):
    """Response model for configuration validation."""

    # Overall verdict together with the error messages that explain it.
    is_valid: bool
    errors: list[str]

    # Advisory messages; a fresh empty list is created when none are supplied.
    warnings: list[str] = Field(default_factory=list)
class RollbackRequest(BaseModel):
    """Request model for strategy rollback."""

    # Required: the ellipsis default marks the field as mandatory.
    reason: str = Field(
        ...,
        description="Reason for rollback",
    )

    # Optional; defaults to None when the caller omits it.
    tenant_id: str | None = Field(
        None,
        description="Optional tenant ID for audit",
    )
class RollbackResponse(BaseModel):
    """Response model for strategy rollback."""

    success: bool

    # Strategy before and after the rollback, plus the caller-supplied reason.
    previous_strategy: StrategyType
    current_strategy: StrategyType
    reason: str

    # Rollback records as plain dictionaries; defaults to a fresh empty list.
    rollback_records: list[dict[str, Any]] = Field(default_factory=list)
class RollbackRecordResponse(BaseModel):
    """Response model for rollback record."""

    timestamp: float

    # The strategy transition the record describes and the reason given for it.
    from_strategy: StrategyType
    to_strategy: StrategyType
    reason: str

    # Both identifiers may be None, but the fields themselves have no default.
    tenant_id: str | None
    request_id: str | None
class CurrentStrategyResponse(BaseModel):
    """Response model for current strategy."""

    # Full routing configuration and the strategy currently in effect.
    config: RoutingConfigResponse
    current_strategy: StrategyType

    # Rollback records; defaults to a fresh empty list.
    rollback_records: list[RollbackRecordResponse] = Field(default_factory=list)
@router.get(
    "/current",
    operation_id="getCurrentRetrievalStrategy",
    summary="Get current retrieval strategy configuration",
    description="[AC-AISVC-RES-01] Returns the current strategy configuration and recent rollback records.",
    response_model=CurrentStrategyResponse,
)
async def get_current_strategy() -> CurrentStrategyResponse:
    """
    [AC-AISVC-RES-01] Report the active retrieval strategy configuration.

    Returns:
        The strategy router's configuration, its current strategy, and the
        records returned by ``get_rollback_records(limit=5)``.
    """
    strategy_router = get_strategy_router()
    active = strategy_router.config
    recent_records = strategy_router.get_rollback_records(limit=5)

    # Copy the router configuration field by field into the response model.
    config_view = RoutingConfigResponse(
        enabled=active.enabled,
        strategy=active.strategy,
        grayscale_percentage=active.grayscale_percentage,
        grayscale_allowlist=active.grayscale_allowlist,
        rag_runtime_mode=active.rag_runtime_mode,
        react_trigger_confidence_threshold=active.react_trigger_confidence_threshold,
        react_trigger_complexity_score=active.react_trigger_complexity_score,
        react_max_steps=active.react_max_steps,
        direct_fallback_on_low_confidence=active.direct_fallback_on_low_confidence,
        direct_fallback_confidence_threshold=active.direct_fallback_confidence_threshold,
        performance_budget_ms=active.performance_budget_ms,
        performance_degradation_threshold=active.performance_degradation_threshold,
    )
    strategy_in_effect = strategy_router.current_strategy

    # Convert each rollback record into its response representation.
    history = []
    for record in recent_records:
        history.append(
            RollbackRecordResponse(
                timestamp=record.timestamp,
                from_strategy=record.from_strategy,
                to_strategy=record.to_strategy,
                reason=record.reason,
                tenant_id=record.tenant_id,
                request_id=record.request_id,
            )
        )

    return CurrentStrategyResponse(
        config=config_view,
        current_strategy=strategy_in_effect,
        rollback_records=history,
    )
@router.post(
    "/switch",
    operation_id="switchRetrievalStrategy",
    summary="Switch retrieval strategy configuration",
    description="[AC-AISVC-RES-02, AC-AISVC-RES-03] Update strategy configuration with hot reload support.",
    response_model=RoutingConfigResponse,
)
async def switch_strategy(
    request: RoutingConfigRequest,
    session: AsyncSession = Depends(get_session),
) -> RoutingConfigResponse:
    """
    [AC-AISVC-RES-02, AC-AISVC-RES-03, AC-AISVC-RES-15]
    Apply a partial update to the retrieval routing configuration.

    Each request field that is not ``None`` overrides the matching value of
    the strategy router's current configuration; the merged configuration is
    validated and then pushed to both the strategy router and the mode router.

    Args:
        request: Partial configuration; ``None`` fields keep the current value.
        session: Injected database session (not used in this function body).

    Returns:
        The merged configuration that was applied.

    Raises:
        HTTPException: Status 400 with ``{"errors": [...]}`` when the merged
            configuration does not pass ``RoutingConfig.validate()``.
    """
    strategy_router = get_strategy_router()
    mode_router = get_mode_router()
    active = strategy_router.config

    # Field names shared by the request, the config and the response model.
    field_names = (
        "enabled",
        "strategy",
        "grayscale_percentage",
        "grayscale_allowlist",
        "rag_runtime_mode",
        "react_trigger_confidence_threshold",
        "react_trigger_complexity_score",
        "react_max_steps",
        "direct_fallback_on_low_confidence",
        "direct_fallback_confidence_threshold",
        "performance_budget_ms",
        "performance_degradation_threshold",
    )

    # Requested value wins; the current value is only read when it is needed.
    merged: dict[str, Any] = {}
    for name in field_names:
        requested = getattr(request, name)
        merged[name] = getattr(active, name) if requested is None else requested
    new_config = RoutingConfig(**merged)

    is_valid, errors = new_config.validate()
    if not is_valid:
        raise HTTPException(status_code=400, detail={"errors": errors})

    # Both routers receive the same configuration object.
    strategy_router.update_config(new_config)
    mode_router.update_config(new_config)

    logger.info(
        f"[AC-AISVC-RES-02, AC-AISVC-RES-15] Strategy switched: "
        f"strategy={new_config.strategy.value}, mode={new_config.rag_runtime_mode.value}"
    )

    applied = {name: getattr(new_config, name) for name in field_names}
    return RoutingConfigResponse(**applied)
@router.post(
    "/validate",
    operation_id="validateRetrievalStrategy",
    summary="Validate retrieval strategy configuration",
    description="[AC-AISVC-RES-06] Validate strategy configuration for completeness and consistency.",
    response_model=ValidationResponse,
)
async def validate_strategy(
    request: RoutingConfigRequest,
) -> ValidationResponse:
    """
    [AC-AISVC-RES-06] Dry-run validation of a routing configuration.

    Missing request fields are filled with the fixed defaults listed below
    (not with the live router configuration). The result of
    ``RoutingConfig.validate()`` is returned together with advisory warnings;
    nothing is applied to any router.

    Args:
        request: Partial configuration to check.

    Returns:
        Validity flag, validation errors, and advisory warnings.
    """

    def _or_default(value: Any, default: Any) -> Any:
        """Return ``default`` only when ``value`` is ``None``."""
        return default if value is None else value

    candidate = RoutingConfig(
        enabled=_or_default(request.enabled, True),
        strategy=_or_default(request.strategy, StrategyType.DEFAULT),
        grayscale_percentage=_or_default(request.grayscale_percentage, 0.0),
        # Unlike the other fields, an empty allowlist is also replaced by a new list.
        grayscale_allowlist=request.grayscale_allowlist or [],
        rag_runtime_mode=_or_default(request.rag_runtime_mode, RagRuntimeMode.AUTO),
        react_trigger_confidence_threshold=_or_default(request.react_trigger_confidence_threshold, 0.6),
        react_trigger_complexity_score=_or_default(request.react_trigger_complexity_score, 0.5),
        react_max_steps=_or_default(request.react_max_steps, 5),
        direct_fallback_on_low_confidence=_or_default(request.direct_fallback_on_low_confidence, True),
        direct_fallback_confidence_threshold=_or_default(request.direct_fallback_confidence_threshold, 0.4),
        performance_budget_ms=_or_default(request.performance_budget_ms, 5000),
        performance_degradation_threshold=_or_default(request.performance_degradation_threshold, 0.2),
    )

    is_valid, errors = candidate.validate()

    # Advisory checks: (condition, message) pairs, reported in this order.
    advisory_checks = (
        (
            candidate.grayscale_percentage > 0.5,
            "grayscale_percentage > 50% may affect production stability",
        ),
        (
            candidate.react_max_steps > 7,
            "react_max_steps > 7 may cause high latency",
        ),
        (
            candidate.direct_fallback_confidence_threshold > candidate.react_trigger_confidence_threshold,
            "direct_fallback_confidence_threshold > react_trigger_confidence_threshold "
            "may cause frequent fallbacks",
        ),
        (
            candidate.performance_budget_ms < 3000,
            "performance_budget_ms < 3000ms may be too aggressive for complex queries",
        ),
    )
    warnings = [message for triggered, message in advisory_checks if triggered]

    logger.info(
        f"[AC-AISVC-RES-06] Strategy validation: is_valid={is_valid}, "
        f"errors={len(errors)}, warnings={len(warnings)}"
    )

    return ValidationResponse(is_valid=is_valid, errors=errors, warnings=warnings)
@router.post(
    "/rollback",
    operation_id="rollbackRetrievalStrategy",
    summary="Rollback to default retrieval strategy",
    description="[AC-AISVC-RES-07] Force rollback to default strategy with audit logging.",
    response_model=RollbackResponse,
)
async def rollback_strategy(
    request: RollbackRequest,
    session: AsyncSession = Depends(get_session),
) -> RollbackResponse:
    """
    [AC-AISVC-RES-07] Rollback to default strategy.

    Asks the strategy router to roll back, pushes a default-strategy /
    auto-mode configuration to the mode router, and returns the strategy
    router's most recent rollback records.

    Args:
        request: Rollback reason and optional tenant ID, both forwarded to
            ``strategy_router.rollback``.
        session: Injected database session (not used in this function body).

    Returns:
        The strategy in effect before the rollback, the default strategy as
        the current one, the reason, and up to five rollback records. Each
        record carries the same fields as the records returned by the
        ``/current`` endpoint, including ``request_id``.
    """
    strategy_router = get_strategy_router()
    mode_router = get_mode_router()

    # Capture the strategy before the rollback so it can be reported back.
    previous_strategy = strategy_router.current_strategy

    strategy_router.rollback(
        reason=request.reason,
        tenant_id=request.tenant_id,
    )

    # NOTE(review): every field other than strategy and mode takes
    # RoutingConfig's own defaults here, so any previously tuned values are
    # presumably dropped from the mode router -- confirm this is intended.
    new_config = RoutingConfig(
        strategy=StrategyType.DEFAULT,
        rag_runtime_mode=RagRuntimeMode.AUTO,
    )
    mode_router.update_config(new_config)

    rollback_records = strategy_router.get_rollback_records(limit=5)

    logger.info(
        f"[AC-AISVC-RES-07] Strategy rollback: "
        f"from={previous_strategy.value}, to=DEFAULT, reason={request.reason}"
    )

    return RollbackResponse(
        success=True,
        previous_strategy=previous_strategy,
        current_strategy=StrategyType.DEFAULT,
        reason=request.reason,
        rollback_records=[
            {
                "timestamp": r.timestamp,
                "from_strategy": r.from_strategy.value,
                "to_strategy": r.to_strategy.value,
                "reason": r.reason,
                "tenant_id": r.tenant_id,
                # Included for parity with the records exposed by /current.
                "request_id": r.request_id,
            }
            for r in rollback_records
        ],
    )