485 lines
17 KiB
Python
485 lines
17 KiB
Python
|
|
"""
Retrieval Strategy Service for AI Service.

[AC-AISVC-RES-01~15] Strategy management with grayscale and rollback support.
"""
|
||
|
|
|
||
|
|
import logging
|
||
|
|
from dataclasses import dataclass, field
|
||
|
|
from datetime import datetime
|
||
|
|
from typing import Any
|
||
|
|
|
||
|
|
from app.schemas.retrieval_strategy import (
|
||
|
|
ReactMode,
|
||
|
|
RolloutConfig,
|
||
|
|
RolloutMode,
|
||
|
|
StrategyType,
|
||
|
|
RetrievalStrategyStatus,
|
||
|
|
RetrievalStrategySwitchRequest,
|
||
|
|
RetrievalStrategySwitchResponse,
|
||
|
|
RetrievalStrategyValidationRequest,
|
||
|
|
RetrievalStrategyValidationResponse,
|
||
|
|
RetrievalStrategyRollbackResponse,
|
||
|
|
ValidationResult,
|
||
|
|
)
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)
|
||
|
|
|
||
|
|
|
||
|
|
@dataclass
class StrategyState:
    """
    [AC-AISVC-RES-01] Mutable in-memory snapshot of the retrieval strategy configuration.

    Groups the active strategy and ReAct mode, the grayscale rollout settings,
    the values captured before the most recent switch (the rollback target),
    and a list of transition records.
    """

    # Currently selected strategy and ReAct mode.
    active_strategy: StrategyType = StrategyType.DEFAULT
    react_mode: ReactMode = ReactMode.NON_REACT

    # Grayscale rollout: the mode plus the parameter used by PERCENTAGE / ALLOWLIST.
    rollout_mode: RolloutMode = RolloutMode.OFF
    rollout_percentage: float = 0.0
    rollout_allowlist: list[str] = field(default_factory=list)

    # Values saved just before the latest switch; None until a switch has happened.
    previous_strategy: StrategyType | None = None
    previous_react_mode: ReactMode | None = None

    # Records (plain dicts) appended on each switch / rollback.
    switch_history: list[dict[str, Any]] = field(default_factory=list)
|
||
|
|
|
||
|
|
|
||
|
|
class RetrievalStrategyService:
    """
    [AC-AISVC-RES-01~15] Service for managing retrieval strategies.

    Features:
    - Strategy switching with grayscale support
    - Rollback to previous/default strategy
    - Validation of strategy configuration
    - Audit logging integration

    All state lives in memory on the instance (a ``StrategyState``); this
    class does not persist it anywhere.
    """

    # Checks run by validate_strategy() when the request does not name any.
    _DEFAULT_VALIDATION_CHECKS = (
        "metadata_consistency",
        "embedding_prefix",
        "rrf_config",
        "performance_budget",
    )

    def __init__(self):
        self._state = StrategyState()
        # Optional hooks; each is skipped while unset (None).
        self._audit_callback: Any = None
        self._metrics_callback: Any = None

    def set_audit_callback(self, callback: Any) -> None:
        """
        Set callback for audit logging.

        The callback is invoked with keyword arguments only: ``operation``,
        ``previous_strategy``, ``new_strategy``, ``previous_react_mode``,
        ``new_react_mode``, ``reason``, ``operator`` and ``tenant_id``.
        """
        self._audit_callback = callback

    def set_metrics_callback(self, callback: Any) -> None:
        """
        Set callback for metrics recording.

        The callback is invoked as ``callback(event_name, payload)`` where
        ``event_name`` is ``"strategy_switch"`` or ``"strategy_rollback"`` and
        ``payload`` is a dict with ``from_strategy`` / ``to_strategy``.
        """
        self._metrics_callback = callback

    def get_current_status(self) -> RetrievalStrategyStatus:
        """
        [AC-AISVC-RES-01] Get current retrieval strategy status.

        Only the rollout parameter matching the active rollout mode is exposed
        (percentage for PERCENTAGE, allowlist for ALLOWLIST); the other one is
        reported as None. The status is logged at INFO level on every call.

        Returns:
            RetrievalStrategyStatus with current configuration.
        """
        state = self._state
        rollout = RolloutConfig(
            mode=state.rollout_mode,
            percentage=state.rollout_percentage if state.rollout_mode == RolloutMode.PERCENTAGE else None,
            allowlist=state.rollout_allowlist if state.rollout_mode == RolloutMode.ALLOWLIST else None,
        )

        status = RetrievalStrategyStatus(
            active_strategy=state.active_strategy,
            react_mode=state.react_mode,
            rollout=rollout,
        )

        logger.info(
            f"[AC-AISVC-RES-01] Current strategy: {state.active_strategy.value}, "
            f"react_mode={state.react_mode.value}, rollout={state.rollout_mode.value}"
        )

        return status

    def switch_strategy(
        self,
        request: RetrievalStrategySwitchRequest,
        operator: str | None = None,
        tenant_id: str | None = None,
    ) -> RetrievalStrategySwitchResponse:
        """
        [AC-AISVC-RES-02, AC-AISVC-RES-03, AC-AISVC-RES-05] Switch retrieval strategy.

        The currently active strategy and ReAct mode are saved as the rollback
        target before the new values are applied. ``react_mode`` and
        ``rollout`` are only changed when the request provides them. A history
        record is appended and the audit / metrics callbacks (if set) are
        notified.

        Args:
            request: Switch request with target strategy and options.
            operator: Operator who initiated the switch.
            tenant_id: Tenant ID for audit.

        Returns:
            RetrievalStrategySwitchResponse with previous and current status.
        """
        previous_status = self.get_current_status()

        state = self._state
        from_strategy = state.active_strategy
        from_react_mode = state.react_mode

        # Remember what was active so rollback_strategy() can restore it.
        state.previous_strategy = from_strategy
        state.previous_react_mode = from_react_mode

        state.active_strategy = request.target_strategy

        if request.react_mode is not None:
            state.react_mode = request.react_mode

        if request.rollout is not None:
            self._apply_rollout(request.rollout)

        self._record_history(
            from_strategy=from_strategy,
            to_strategy=state.active_strategy,
            reason=request.reason,
            operator=operator,
        )

        current_status = self.get_current_status()

        logger.info(
            f"[AC-AISVC-RES-02] Strategy switched: {from_strategy.value} -> "
            f"{state.active_strategy.value}, react_mode={state.react_mode.value}"
        )

        self._notify_audit(
            operation="switch",
            previous_strategy=from_strategy.value,
            new_strategy=state.active_strategy.value,
            previous_react_mode=from_react_mode.value if from_react_mode else None,
            new_react_mode=state.react_mode.value,
            reason=request.reason,
            operator=operator,
            tenant_id=tenant_id,
        )
        self._notify_metrics("strategy_switch", from_strategy, state.active_strategy)

        return RetrievalStrategySwitchResponse(
            previous=previous_status,
            current=current_status,
        )

    def _apply_rollout(self, rollout: RolloutConfig) -> None:
        """
        Copy a requested rollout mode and its mode-specific parameter into state.

        Parameters belonging to other modes are left as they were;
        get_current_status() only exposes the one matching the active mode.
        """
        state = self._state
        state.rollout_mode = rollout.mode
        if rollout.mode == RolloutMode.PERCENTAGE:
            state.rollout_percentage = rollout.percentage or 0.0
        elif rollout.mode == RolloutMode.ALLOWLIST:
            # Copy so later mutation of the request object cannot alter state.
            state.rollout_allowlist = list(rollout.allowlist or [])

    def _record_history(
        self,
        from_strategy: StrategyType,
        to_strategy: StrategyType,
        reason: Any,
        operator: str | None,
    ) -> None:
        """
        Append one transition record to the switch history.

        ``react_mode`` and ``rollout_mode`` are read from the current state, so
        this must be called after the transition has been applied. Switch and
        rollback records share the same set of keys.
        """
        state = self._state
        state.switch_history.append(
            {
                # NOTE: naive UTC timestamp (no offset), kept for compatibility
                # with existing consumers of the history format.
                "timestamp": datetime.utcnow().isoformat(),
                "from_strategy": from_strategy.value,
                "to_strategy": to_strategy.value,
                "react_mode": state.react_mode.value,
                "rollout_mode": state.rollout_mode.value,
                "reason": reason,
                "operator": operator,
            }
        )

    def _notify_audit(self, **fields: Any) -> None:
        """Forward keyword fields to the audit callback, if one is set."""
        if self._audit_callback:
            self._audit_callback(**fields)

    def _notify_metrics(self, event: str, from_strategy: StrategyType, to_strategy: StrategyType) -> None:
        """Report a strategy transition to the metrics callback, if one is set."""
        if self._metrics_callback:
            self._metrics_callback(event, {
                "from_strategy": from_strategy.value,
                "to_strategy": to_strategy.value,
            })

    def validate_strategy(
        self,
        request: RetrievalStrategyValidationRequest,
    ) -> RetrievalStrategyValidationResponse:
        """
        [AC-AISVC-RES-04, AC-AISVC-RES-06, AC-AISVC-RES-08] Validate strategy configuration.

        Runs the checks named in ``request.checks`` (or the default set when
        none are given). The overall result passes only if every check passes.

        Args:
            request: Validation request with strategy and checks.

        Returns:
            RetrievalStrategyValidationResponse with check results.
        """
        checks_to_run = request.checks or self._DEFAULT_VALIDATION_CHECKS

        results: list[ValidationResult] = [
            self._run_validation_check(check, request.strategy, request.react_mode)
            for check in checks_to_run
        ]

        all_passed = all(r.passed for r in results)

        logger.info(
            f"[AC-AISVC-RES-06] Strategy validation: strategy={request.strategy.value}, "
            f"checks={len(results)}, passed={all_passed}"
        )

        return RetrievalStrategyValidationResponse(
            passed=all_passed,
            results=results,
        )

    def _run_validation_check(
        self,
        check: str,
        strategy: StrategyType,
        react_mode: ReactMode | None,
    ) -> ValidationResult:
        """
        Run a single validation check.

        Args:
            check: Check name.
            strategy: Strategy to validate.
            react_mode: ReAct mode to validate.

        Returns:
            ValidationResult for the check; an unknown check name yields a
            failed result rather than an exception.
        """
        if check == "metadata_consistency":
            return self._check_metadata_consistency(strategy)
        if check == "embedding_prefix":
            return self._check_embedding_prefix(strategy)
        if check == "rrf_config":
            return self._check_rrf_config(strategy)
        if check == "performance_budget":
            return self._check_performance_budget(strategy, react_mode)
        return ValidationResult(
            check=check,
            passed=False,
            message=f"Unknown check type: {check}",
        )

    def _check_metadata_consistency(self, strategy: StrategyType) -> ValidationResult:
        """
        [AC-AISVC-RES-04] Check metadata consistency between strategies.

        NOTE(review): currently performs no real check and always reports passed.
        """
        try:
            passed = True
            message = "Metadata consistency check passed"

            logger.debug(f"[AC-AISVC-RES-04] Metadata consistency check: strategy={strategy.value}, passed={passed}")
            return ValidationResult(check="metadata_consistency", passed=passed, message=message)
        except Exception as e:
            return ValidationResult(check="metadata_consistency", passed=False, message=str(e))

    def _check_embedding_prefix(self, strategy: StrategyType) -> ValidationResult:
        """
        Check embedding prefix configuration.

        NOTE(review): currently performs no real check and always reports passed.
        """
        try:
            passed = True
            message = "Embedding prefix configuration valid"

            logger.debug(f"[AC-AISVC-RES-04] Embedding prefix check: strategy={strategy.value}, passed={passed}")
            return ValidationResult(check="embedding_prefix", passed=passed, message=message)
        except Exception as e:
            return ValidationResult(check="embedding_prefix", passed=False, message=str(e))

    def _check_rrf_config(self, strategy: StrategyType) -> ValidationResult:
        """
        [AC-AISVC-RES-02] Check RRF (Reciprocal Rank Fusion) configuration.

        The enhanced strategy requires ``settings.rag_hybrid_enabled``; every
        strategy requires a positive ``settings.rag_rrf_k``. Any error while
        loading settings is reported as a failed check.
        """
        try:
            # Imported lazily, as in the original, so settings load only when this check runs.
            from app.core.config import get_settings

            settings = get_settings()

            if strategy == StrategyType.ENHANCED and not settings.rag_hybrid_enabled:
                return ValidationResult(
                    check="rrf_config",
                    passed=False,
                    message="Hybrid retrieval not enabled for enhanced strategy",
                )

            if settings.rag_rrf_k <= 0:
                return ValidationResult(
                    check="rrf_config",
                    passed=False,
                    message="RRF K parameter must be positive",
                )

            return ValidationResult(check="rrf_config", passed=True, message="RRF configuration valid")
        except Exception as e:
            return ValidationResult(check="rrf_config", passed=False, message=str(e))

    def _check_performance_budget(
        self,
        strategy: StrategyType,
        react_mode: ReactMode | None,
    ) -> ValidationResult:
        """
        [AC-AISVC-RES-08] Check performance budget constraints.

        The budget is 5000 ms, raised to 10000 ms for ENHANCED + REACT. The
        check currently only reports the budget and always passes.
        """
        try:
            max_latency_ms = 5000

            if strategy == StrategyType.ENHANCED and react_mode == ReactMode.REACT:
                max_latency_ms = 10000

            message = f"Performance budget check passed (max_latency={max_latency_ms}ms)"

            logger.debug(
                f"[AC-AISVC-RES-08] Performance budget check: strategy={strategy.value}, "
                f"react_mode={react_mode}, max_latency={max_latency_ms}ms"
            )
            return ValidationResult(check="performance_budget", passed=True, message=message)
        except Exception as e:
            return ValidationResult(check="performance_budget", passed=False, message=str(e))

    def rollback_strategy(
        self,
        operator: str | None = None,
        tenant_id: str | None = None,
    ) -> RetrievalStrategyRollbackResponse:
        """
        [AC-AISVC-RES-07] Rollback to previous or default strategy.

        Restores the strategy / ReAct mode saved by the last switch, or
        DEFAULT / NON_REACT if no switch has happened, and turns grayscale
        rollout off. The saved previous values are not modified, so calling
        this again without an intervening switch re-applies the same target.

        Args:
            operator: Operator who initiated the rollback.
            tenant_id: Tenant ID for audit.

        Returns:
            RetrievalStrategyRollbackResponse with current and rollback status.
        """
        current_status = self.get_current_status()

        state = self._state
        rollback_to_strategy = (
            state.previous_strategy if state.previous_strategy is not None else StrategyType.DEFAULT
        )
        rollback_to_react_mode = (
            state.previous_react_mode if state.previous_react_mode is not None else ReactMode.NON_REACT
        )

        old_strategy = state.active_strategy
        old_react_mode = state.react_mode

        state.active_strategy = rollback_to_strategy
        state.react_mode = rollback_to_react_mode
        # A rollback always disables grayscale rollout.
        state.rollout_mode = RolloutMode.OFF
        state.rollout_percentage = 0.0
        state.rollout_allowlist = []

        rollback_status = self.get_current_status()

        self._record_history(
            from_strategy=old_strategy,
            to_strategy=rollback_to_strategy,
            reason="Manual rollback",
            operator=operator,
        )

        logger.info(
            f"[AC-AISVC-RES-07] Strategy rolled back: {old_strategy.value} -> "
            f"{rollback_to_strategy.value}, react_mode={rollback_to_react_mode.value}"
        )

        self._notify_audit(
            operation="rollback",
            previous_strategy=old_strategy.value,
            new_strategy=rollback_to_strategy.value,
            previous_react_mode=old_react_mode.value,
            new_react_mode=rollback_to_react_mode.value,
            reason="Manual rollback",
            operator=operator,
            tenant_id=tenant_id,
        )
        self._notify_metrics("strategy_rollback", old_strategy, rollback_to_strategy)

        return RetrievalStrategyRollbackResponse(
            current=current_status,
            rollback_to=rollback_status,
        )

    def should_use_enhanced_strategy(self, tenant_id: str | None = None) -> bool:
        """
        [AC-AISVC-RES-03] Determine if enhanced strategy should be used based on rollout config.

        - Active strategy is not ENHANCED: never.
        - Rollout OFF: always.
        - Rollout ALLOWLIST: only for tenants on the allowlist.
        - Rollout PERCENTAGE: tenants are bucketed deterministically, so a
          given tenant gets a stable answer across calls; without a tenant id
          a random draw is made per call.

        Args:
            tenant_id: Tenant ID for allowlist / percentage bucketing.

        Returns:
            True if enhanced strategy should be used.
        """
        state = self._state

        # Rollout settings only ever gate the enhanced strategy.
        if state.active_strategy != StrategyType.ENHANCED:
            return False

        if state.rollout_mode == RolloutMode.OFF:
            return True

        if state.rollout_mode == RolloutMode.ALLOWLIST:
            return bool(tenant_id) and tenant_id in state.rollout_allowlist

        if state.rollout_mode == RolloutMode.PERCENTAGE:
            return self._rollout_bucket(tenant_id) < state.rollout_percentage

        return False

    @staticmethod
    def _rollout_bucket(tenant_id: str | None) -> float:
        """
        Map a tenant to a rollout bucket in [0, 100).

        A SHA-256 digest is used instead of ``hash()`` because ``hash()`` of a
        str is salted per process, which would make buckets differ between
        workers and restarts. With no tenant id there is nothing to key on,
        so a random bucket is drawn per call.
        """
        if not tenant_id:
            import random

            return random.random() * 100

        import hashlib

        digest = hashlib.sha256(tenant_id.encode("utf-8")).digest()
        # Integer modulo keeps the bucket exact in 0.00 .. 99.99, so 100% always passes.
        return int.from_bytes(digest[:8], "big") % 10_000 / 100

    def get_route_mode(
        self,
        query: str,
        confidence: float | None = None,
    ) -> str:
        """
        [AC-AISVC-RES-09~15] Determine route mode based on query and confidence.

        REACT mode always routes to "react" and NON_REACT to "direct"; any
        other ReAct mode falls through to automatic routing.

        Args:
            query: User query.
            confidence: Confidence score from metadata inference.

        Returns:
            Route mode: "direct", "react", or "auto".
        """
        if self._state.react_mode == ReactMode.REACT:
            return "react"
        if self._state.react_mode == ReactMode.NON_REACT:
            return "direct"
        return self._auto_route(query, confidence)

    def _auto_route(self, query: str, confidence: float | None = None) -> str:
        """
        [AC-AISVC-RES-11~14] Auto route based on query complexity and confidence.

        Routes to "react" when confidence is below 0.5 or the query contains a
        conjunction marker ("和", "或", "以及"); otherwise routes to "direct".
        """
        low_confidence_threshold = 0.5
        short_query_threshold = 20

        has_multiple_conditions = any(marker in query for marker in ("和", "或", "以及"))

        if confidence is not None and confidence < low_confidence_threshold:
            logger.info(
                f"[AC-AISVC-RES-13] Auto route to react: low confidence={confidence}"
            )
            return "react"

        if has_multiple_conditions:
            logger.info(
                "[AC-AISVC-RES-13] Auto route to react: multiple conditions detected"
            )
            return "react"

        if len(query) < short_query_threshold and confidence and confidence > 0.7:
            logger.info(
                "[AC-AISVC-RES-12] Auto route to direct: short query, high confidence"
            )
            return "direct"

        return "direct"

    def get_switch_history(self, limit: int = 10) -> list[dict[str, Any]]:
        """
        Get recent switch history.

        Args:
            limit: Maximum number of records to return. A non-positive value
                returns an empty list (a bare ``[-0:]`` slice would otherwise
                return the whole history).

        Returns:
            Up to ``limit`` most recent records, oldest first, as shallow
            copies so callers cannot mutate the stored history.
        """
        if limit <= 0:
            return []
        return [dict(record) for record in self._state.switch_history[-limit:]]
|
||
|
|
|
||
|
|
|
||
|
|
# Lazily created, module-wide service instance (see get_strategy_service()).
_strategy_service: RetrievalStrategyService | None = None


def get_strategy_service() -> RetrievalStrategyService:
    """
    Return the shared RetrievalStrategyService, constructing it on first call.

    Every later call returns the same instance, so strategy state is shared by
    all callers in this process.
    """
    global _strategy_service
    if _strategy_service is not None:
        return _strategy_service
    _strategy_service = RetrievalStrategyService()
    return _strategy_service
|