ai-robot-core/ai-service/app/services/retrieval/strategy_service.py

485 lines
17 KiB
Python
Raw Normal View History

"""
Retrieval Strategy Service for AI Service.
[AC-AISVC-RES-01~15] Strategy management with grayscale and rollback support.
"""
import logging
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any
from app.schemas.retrieval_strategy import (
ReactMode,
RolloutConfig,
RolloutMode,
StrategyType,
RetrievalStrategyStatus,
RetrievalStrategySwitchRequest,
RetrievalStrategySwitchResponse,
RetrievalStrategyValidationRequest,
RetrievalStrategyValidationResponse,
RetrievalStrategyRollbackResponse,
ValidationResult,
)
logger = logging.getLogger(__name__)
@dataclass
class StrategyState:
    """
    [AC-AISVC-RES-01] Mutable in-memory record of the retrieval strategy setup.

    Holds the active strategy and ReAct mode, the rollout (grayscale)
    settings, the values that were active before the most recent switch,
    and an append-only list of switch/rollback records.
    """

    # Strategy and ReAct mode currently in effect.
    active_strategy: StrategyType = field(default=StrategyType.DEFAULT)
    react_mode: ReactMode = field(default=ReactMode.NON_REACT)

    # Rollout settings; percentage and allowlist are stored independently of
    # which rollout mode is selected.
    rollout_mode: RolloutMode = field(default=RolloutMode.OFF)
    rollout_percentage: float = field(default=0.0)
    rollout_allowlist: list[str] = field(default_factory=list)

    # Values captured before the latest switch; None until a switch happens.
    previous_strategy: StrategyType | None = field(default=None)
    previous_react_mode: ReactMode | None = field(default=None)

    # One dict per recorded switch or rollback, oldest first.
    switch_history: list[dict[str, Any]] = field(default_factory=list)
class RetrievalStrategyService:
    """
    [AC-AISVC-RES-01~15] Service for managing retrieval strategies.

    Features:
    - Strategy switching with grayscale support
    - Rollback to previous/default strategy
    - Validation of strategy configuration
    - Audit logging integration

    All configuration is kept in a single in-memory ``StrategyState`` owned by
    the instance; this class itself does not persist it.
    """

    # Substrings that mark a query as carrying several conditions.
    # NOTE(review): the original expression tested three literals, two of which
    # were empty strings ("" in query is always True, which forced every AUTO
    # query onto the "react" route). They look like characters lost in an
    # encoding round-trip - restore the intended markers here once confirmed.
    _MULTI_CONDITION_MARKERS: tuple[str, ...] = ("以及",)

    def __init__(self):
        """Create a service with default state and no callbacks registered."""
        self._state = StrategyState()
        self._audit_callback: Any = None
        self._metrics_callback: Any = None

    def set_audit_callback(self, callback: Any) -> None:
        """Set callback for audit logging (invoked with keyword arguments)."""
        self._audit_callback = callback

    def set_metrics_callback(self, callback: Any) -> None:
        """Set callback for metrics recording (invoked as ``callback(name, labels)``)."""
        self._metrics_callback = callback

    @staticmethod
    def _utc_timestamp() -> str:
        """Return the current UTC time as a naive ISO-8601 string for history records."""
        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12. It is
        # kept so the stored timestamp format (no UTC offset suffix) is unchanged.
        return datetime.utcnow().isoformat()

    def get_current_status(self) -> RetrievalStrategyStatus:
        """
        [AC-AISVC-RES-01] Get current retrieval strategy status.

        Returns:
            RetrievalStrategyStatus with current configuration. Only the rollout
            field matching the active rollout mode is populated; the other is None.
        """
        state = self._state
        rollout = RolloutConfig(
            mode=state.rollout_mode,
            percentage=state.rollout_percentage if state.rollout_mode == RolloutMode.PERCENTAGE else None,
            allowlist=state.rollout_allowlist if state.rollout_mode == RolloutMode.ALLOWLIST else None,
        )
        status = RetrievalStrategyStatus(
            active_strategy=state.active_strategy,
            react_mode=state.react_mode,
            rollout=rollout,
        )
        logger.info(
            f"[AC-AISVC-RES-01] Current strategy: {state.active_strategy.value}, "
            f"react_mode={state.react_mode.value}, rollout={state.rollout_mode.value}"
        )
        return status

    def switch_strategy(
        self,
        request: RetrievalStrategySwitchRequest,
        operator: str | None = None,
        tenant_id: str | None = None,
    ) -> RetrievalStrategySwitchResponse:
        """
        [AC-AISVC-RES-02, AC-AISVC-RES-03, AC-AISVC-RES-05] Switch retrieval strategy.

        The values active before the call are remembered so that
        ``rollback_strategy`` can restore them. A history record is appended
        and the audit/metrics callbacks are invoked when registered.

        Args:
            request: Switch request with target strategy and options.
            operator: Operator who initiated the switch.
            tenant_id: Tenant ID for audit.

        Returns:
            RetrievalStrategySwitchResponse with previous and current status.
        """
        state = self._state
        previous_status = self.get_current_status()

        # Remember what was active so a later rollback can restore it.
        state.previous_strategy = state.active_strategy
        state.previous_react_mode = state.react_mode

        state.active_strategy = request.target_strategy
        # An omitted react_mode keeps the current one.
        if request.react_mode is not None:
            state.react_mode = request.react_mode

        if request.rollout:
            state.rollout_mode = request.rollout.mode
            if request.rollout.mode == RolloutMode.PERCENTAGE:
                state.rollout_percentage = request.rollout.percentage or 0.0
            elif request.rollout.mode == RolloutMode.ALLOWLIST:
                # Copy so later mutation of the request object cannot alter state.
                state.rollout_allowlist = list(request.rollout.allowlist or [])

        from_strategy = state.previous_strategy.value
        to_strategy = state.active_strategy.value

        state.switch_history.append({
            "timestamp": self._utc_timestamp(),
            "from_strategy": from_strategy,
            "to_strategy": to_strategy,
            "react_mode": state.react_mode.value,
            "rollout_mode": state.rollout_mode.value,
            "reason": request.reason,
            "operator": operator,
        })

        current_status = self.get_current_status()
        logger.info(
            f"[AC-AISVC-RES-02] Strategy switched: {from_strategy} -> "
            f"{to_strategy}, react_mode={state.react_mode.value}"
        )

        if self._audit_callback:
            self._audit_callback(
                operation="switch",
                previous_strategy=from_strategy,
                new_strategy=to_strategy,
                previous_react_mode=(
                    state.previous_react_mode.value
                    if state.previous_react_mode is not None
                    else None
                ),
                new_react_mode=state.react_mode.value,
                reason=request.reason,
                operator=operator,
                tenant_id=tenant_id,
            )
        if self._metrics_callback:
            self._metrics_callback("strategy_switch", {
                "from_strategy": from_strategy,
                "to_strategy": to_strategy,
            })

        return RetrievalStrategySwitchResponse(
            previous=previous_status,
            current=current_status,
        )

    def validate_strategy(
        self,
        request: RetrievalStrategyValidationRequest,
    ) -> RetrievalStrategyValidationResponse:
        """
        [AC-AISVC-RES-04, AC-AISVC-RES-06, AC-AISVC-RES-08] Validate strategy configuration.

        Runs the checks named in the request, or the full default set when the
        request names none. The overall result passes only if every check passes.

        Args:
            request: Validation request with strategy and checks.

        Returns:
            RetrievalStrategyValidationResponse with check results.
        """
        default_checks = [
            "metadata_consistency",
            "embedding_prefix",
            "rrf_config",
            "performance_budget",
        ]
        checks_to_run = request.checks if request.checks else default_checks

        results: list[ValidationResult] = [
            self._run_validation_check(check, request.strategy, request.react_mode)
            for check in checks_to_run
        ]
        all_passed = all(r.passed for r in results)

        logger.info(
            f"[AC-AISVC-RES-06] Strategy validation: strategy={request.strategy.value}, "
            f"checks={len(results)}, passed={all_passed}"
        )
        return RetrievalStrategyValidationResponse(
            passed=all_passed,
            results=results,
        )

    def _run_validation_check(
        self,
        check: str,
        strategy: StrategyType,
        react_mode: ReactMode | None,
    ) -> ValidationResult:
        """
        Run a single validation check.

        Args:
            check: Check name.
            strategy: Strategy to validate.
            react_mode: ReAct mode to validate.

        Returns:
            ValidationResult for the check; an unknown check name yields a
            failed result instead of raising.
        """
        runners = {
            "metadata_consistency": lambda: self._check_metadata_consistency(strategy),
            "embedding_prefix": lambda: self._check_embedding_prefix(strategy),
            "rrf_config": lambda: self._check_rrf_config(strategy),
            "performance_budget": lambda: self._check_performance_budget(strategy, react_mode),
        }
        runner = runners.get(check)
        if runner is None:
            return ValidationResult(
                check=check,
                passed=False,
                message=f"Unknown check type: {check}",
            )
        return runner()

    def _check_metadata_consistency(self, strategy: StrategyType) -> ValidationResult:
        """
        [AC-AISVC-RES-04] Check metadata consistency between strategies.

        Currently a placeholder that always reports success.
        """
        try:
            passed = True
            message = "Metadata consistency check passed"
            logger.debug(f"[AC-AISVC-RES-04] Metadata consistency check: strategy={strategy.value}, passed={passed}")
            return ValidationResult(check="metadata_consistency", passed=passed, message=message)
        except Exception as e:
            return ValidationResult(check="metadata_consistency", passed=False, message=str(e))

    def _check_embedding_prefix(self, strategy: StrategyType) -> ValidationResult:
        """
        Check embedding prefix configuration.

        Currently a placeholder that always reports success.
        """
        try:
            passed = True
            message = "Embedding prefix configuration valid"
            logger.debug(f"[AC-AISVC-RES-04] Embedding prefix check: strategy={strategy.value}, passed={passed}")
            return ValidationResult(check="embedding_prefix", passed=passed, message=message)
        except Exception as e:
            return ValidationResult(check="embedding_prefix", passed=False, message=str(e))

    def _check_rrf_config(self, strategy: StrategyType) -> ValidationResult:
        """
        [AC-AISVC-RES-02] Check RRF (Reciprocal Rank Fusion) configuration.

        For the enhanced strategy, hybrid retrieval must be enabled and the
        RRF K setting must be positive. Any error while loading or reading
        settings is reported as a failed check.
        """
        try:
            # Imported lazily, as in the original, so settings are only loaded on demand.
            from app.core.config import get_settings

            settings = get_settings()
            if strategy == StrategyType.ENHANCED:
                if not settings.rag_hybrid_enabled:
                    return ValidationResult(
                        check="rrf_config",
                        passed=False,
                        message="Hybrid retrieval not enabled for enhanced strategy",
                    )
                if settings.rag_rrf_k <= 0:
                    return ValidationResult(
                        check="rrf_config",
                        passed=False,
                        message="RRF K parameter must be positive",
                    )
            return ValidationResult(check="rrf_config", passed=True, message="RRF configuration valid")
        except Exception as e:
            return ValidationResult(check="rrf_config", passed=False, message=str(e))

    def _check_performance_budget(
        self,
        strategy: StrategyType,
        react_mode: ReactMode | None,
    ) -> ValidationResult:
        """
        [AC-AISVC-RES-08] Check performance budget constraints.

        Reports the latency budget that applies (5000 ms, or 10000 ms for the
        enhanced strategy in ReAct mode). No measurement is taken here, so the
        check always passes unless an exception occurs.
        """
        try:
            max_latency_ms = 5000
            if strategy == StrategyType.ENHANCED and react_mode == ReactMode.REACT:
                max_latency_ms = 10000
            message = f"Performance budget check passed (max_latency={max_latency_ms}ms)"
            logger.debug(
                f"[AC-AISVC-RES-08] Performance budget check: strategy={strategy.value}, "
                f"react_mode={react_mode}, max_latency={max_latency_ms}ms"
            )
            return ValidationResult(check="performance_budget", passed=True, message=message)
        except Exception as e:
            return ValidationResult(check="performance_budget", passed=False, message=str(e))

    def rollback_strategy(
        self,
        operator: str | None = None,
        tenant_id: str | None = None,
    ) -> RetrievalStrategyRollbackResponse:
        """
        [AC-AISVC-RES-07] Rollback to previous or default strategy.

        Restores the strategy and ReAct mode remembered by the last switch
        (falling back to DEFAULT / NON_REACT when none was recorded) and
        resets the rollout settings to OFF.

        Args:
            operator: Operator who initiated the rollback.
            tenant_id: Tenant ID for audit.

        Returns:
            RetrievalStrategyRollbackResponse with current and rollback status.
        """
        state = self._state
        current_status = self.get_current_status()

        rollback_to_strategy = (
            state.previous_strategy
            if state.previous_strategy is not None
            else StrategyType.DEFAULT
        )
        rollback_to_react_mode = (
            state.previous_react_mode
            if state.previous_react_mode is not None
            else ReactMode.NON_REACT
        )
        old_strategy = state.active_strategy
        old_react_mode = state.react_mode

        state.active_strategy = rollback_to_strategy
        state.react_mode = rollback_to_react_mode
        state.rollout_mode = RolloutMode.OFF
        state.rollout_percentage = 0.0
        state.rollout_allowlist = []

        rollback_status = self.get_current_status()
        reason = "Manual rollback"

        # Same keys as a switch record so history consumers see one schema.
        state.switch_history.append({
            "timestamp": self._utc_timestamp(),
            "from_strategy": old_strategy.value,
            "to_strategy": rollback_to_strategy.value,
            "react_mode": rollback_to_react_mode.value,
            "rollout_mode": state.rollout_mode.value,
            "reason": reason,
            "operator": operator,
        })

        logger.info(
            f"[AC-AISVC-RES-07] Strategy rolled back: {old_strategy.value} -> "
            f"{rollback_to_strategy.value}, react_mode={rollback_to_react_mode.value}"
        )

        if self._audit_callback:
            self._audit_callback(
                operation="rollback",
                previous_strategy=old_strategy.value,
                new_strategy=rollback_to_strategy.value,
                previous_react_mode=old_react_mode.value,
                new_react_mode=rollback_to_react_mode.value,
                reason=reason,
                operator=operator,
                tenant_id=tenant_id,
            )
        if self._metrics_callback:
            self._metrics_callback("strategy_rollback", {
                "from_strategy": old_strategy.value,
                "to_strategy": rollback_to_strategy.value,
            })

        return RetrievalStrategyRollbackResponse(
            current=current_status,
            rollback_to=rollback_status,
        )

    def should_use_enhanced_strategy(self, tenant_id: str | None = None) -> bool:
        """
        [AC-AISVC-RES-03] Determine if enhanced strategy should be used based on rollout config.

        Args:
            tenant_id: Tenant ID for allowlist check.

        Returns:
            True if enhanced strategy should be used.
        """
        state = self._state
        if state.active_strategy == StrategyType.DEFAULT:
            return False

        mode = state.rollout_mode
        if mode == RolloutMode.OFF:
            return state.active_strategy == StrategyType.ENHANCED
        if mode == RolloutMode.ALLOWLIST:
            return bool(tenant_id) and tenant_id in state.rollout_allowlist
        if mode == RolloutMode.PERCENTAGE:
            # A fresh random draw per call: the decision is not sticky per tenant.
            import random

            return random.random() * 100 < state.rollout_percentage
        return False

    def get_route_mode(
        self,
        query: str,
        confidence: float | None = None,
    ) -> str:
        """
        [AC-AISVC-RES-09~15] Determine route mode based on query and confidence.

        Args:
            query: User query.
            confidence: Confidence score from metadata inference.

        Returns:
            "react" or "direct". A fixed ReAct mode maps straight to its
            route; any other mode is resolved by ``_auto_route``.
        """
        react_mode = self._state.react_mode
        if react_mode == ReactMode.REACT:
            return "react"
        if react_mode == ReactMode.NON_REACT:
            return "direct"
        return self._auto_route(query, confidence)

    def _auto_route(self, query: str, confidence: float | None = None) -> str:
        """
        [AC-AISVC-RES-11~14] Auto route based on query complexity and confidence.

        Args:
            query: User query.
            confidence: Confidence score, or None when unavailable.

        Returns:
            "react" when confidence is below 0.5 or the query contains a
            multi-condition marker; otherwise "direct".
        """
        low_confidence_threshold = 0.5
        high_confidence_threshold = 0.7
        short_query_threshold = 20

        if confidence is not None and confidence < low_confidence_threshold:
            logger.info(
                f"[AC-AISVC-RES-13] Auto route to react: low confidence={confidence}"
            )
            return "react"

        # Empty markers are skipped: "" is a substring of every string.
        has_multiple_conditions = any(
            marker in query for marker in self._MULTI_CONDITION_MARKERS if marker
        )
        if has_multiple_conditions:
            logger.info(
                "[AC-AISVC-RES-13] Auto route to react: multiple conditions detected"
            )
            return "react"

        if (
            len(query) < short_query_threshold
            and confidence is not None
            and confidence > high_confidence_threshold
        ):
            logger.info(
                "[AC-AISVC-RES-12] Auto route to direct: short query, high confidence"
            )
            return "direct"

        return "direct"

    def get_switch_history(self, limit: int = 10) -> list[dict[str, Any]]:
        """
        Get recent switch history.

        Args:
            limit: Maximum number of records to return. Zero or a negative
                value yields an empty list.

        Returns:
            List of the most recent switch records, oldest first.
        """
        # history[-0:] would return everything, so guard non-positive limits.
        if limit <= 0:
            return []
        return self._state.switch_history[-limit:]
# Process-wide service instance, created lazily on first access.
_strategy_service: RetrievalStrategyService | None = None


def get_strategy_service() -> RetrievalStrategyService:
    """Return the shared RetrievalStrategyService, creating it on first call."""
    global _strategy_service
    if _strategy_service is not None:
        return _strategy_service
    _strategy_service = RetrievalStrategyService()
    return _strategy_service