# File: ai-robot-core/ai-service/app/services/retrieval/strategy/config.py
# (source listing metadata: 202 lines, 6.4 KiB, Python)
"""
Retrieval Strategy Configuration.
[AC-AISVC-RES-01~15] 检索策略配置模型。
"""
import random
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
class StrategyType(str, Enum):
    """Top-level retrieval strategy selector.

    Because the enum mixes in ``str``, each member compares equal to its
    string value (e.g. ``StrategyType.DEFAULT == "default"``).
    """

    # Baseline strategy.
    DEFAULT = "default"
    # Enhanced strategy.
    ENHANCED = "enhanced"
class RuntimeMode(str, Enum):
    """Runtime execution mode for answering a query.

    Members are ``str`` subclasses and compare equal to their values.
    """

    # Always take the direct (non-ReAct) path.
    DIRECT = "direct"
    # Always take the ReAct path.
    REACT = "react"
    # Let routing heuristics pick between DIRECT and REACT per query.
    AUTO = "auto"
class FilterMode(str, Enum):
    """How strongly inferred metadata should constrain retrieval.

    Members are ``str`` subclasses and compare equal to their values.
    """

    # Strictest mode; presumably applied as a mandatory filter (confirm in caller).
    HARD = "hard"
    # Intermediate mode; presumably a preference rather than a hard constraint.
    SOFT = "soft"
    # No metadata filtering.
    NONE = "none"
@dataclass
class GrayscaleConfig:
"""灰度发布配置。【AC-AISVC-RES-03】"""
enabled: bool = False
percentage: float = 0.0
allowlist: list[str] = field(default_factory=list)
def should_use_enhanced(self, tenant_id: str, user_id: str | None = None) -> bool:
"""判断是否应该使用增强策略。"""
if not self.enabled:
return False
if tenant_id in self.allowlist or (user_id and user_id in self.allowlist):
return True
return random.random() * 100 < self.percentage
@dataclass
class HybridRetrievalConfig:
    """Settings for hybrid (dense + keyword) retrieval.

    Field order is part of the positional constructor signature; keep it.
    """

    # Weight for dense-retrieval results (presumably used when fusing scores).
    dense_weight: float = 0.7
    # Weight for keyword-retrieval results.
    keyword_weight: float = 0.3
    # Rank-fusion constant; presumably the ``k`` of Reciprocal Rank Fusion — confirm in fusion code.
    rrf_k: int = 60
    # Toggle for the keyword retrieval leg.
    enable_keyword: bool = True
    # Multiplier on top_k for keyword candidates (assumed; verify against caller).
    keyword_top_k_multiplier: int = 2
@dataclass
class RerankerConfig:
    """Reranker configuration. [AC-AISVC-RES-08]

    Field order is part of the positional constructor signature; keep it.
    """

    # Reranking is off by default.
    enabled: bool = False
    # Reranker model identifier.
    model: str = "cross-encoder"
    # Number of results to keep after reranking (assumed from the name).
    top_k_after_rerank: int = 5
    # Score cutoff for reranked results (assumed from the name; confirm in reranker).
    min_score_threshold: float = 0.3
@dataclass
class ModeRouterConfig:
    """Mode-routing configuration. [AC-AISVC-RES-09~15]

    Decides whether a query is answered on the direct path or the ReAct path.
    """

    # Forced mode (DIRECT / REACT) or AUTO for heuristic routing.
    runtime_mode: RuntimeMode = RuntimeMode.AUTO
    # Confidence below this triggers ReAct in AUTO mode.
    react_trigger_confidence_threshold: float = 0.6
    # Complexity at or above this triggers ReAct in AUTO mode.
    react_trigger_complexity_score: float = 0.5
    # Not read by this class; presumably caps ReAct iterations elsewhere.
    react_max_steps: int = 5
    # Not read by this class; consumer lives elsewhere — confirm.
    direct_fallback_on_low_confidence: bool = True
    # Queries with len(query) <= this count as "short".
    short_query_threshold: int = 20

    def should_use_react(
        self,
        query: str,
        confidence: float | None = None,
        complexity_score: float | None = None,
    ) -> bool:
        """Decide whether the ReAct mode should be used. [AC-AISVC-RES-11~13]

        Args:
            query: The user query; only its length is inspected.
            confidence: Optional confidence score. ``None`` means "unknown";
                ``0.0`` is a real (lowest) score and is treated as such.
            complexity_score: Optional complexity score; ``None`` means unknown.

        Returns:
            True to route to ReAct, False to stay on the direct path.
            A forced REACT/DIRECT runtime mode always wins. In AUTO mode, a
            short query with sufficient confidence stays direct; otherwise a
            high complexity score or a low confidence triggers ReAct.
        """
        if self.runtime_mode == RuntimeMode.REACT:
            return True
        if self.runtime_mode == RuntimeMode.DIRECT:
            return False
        # Compare against None explicitly: a score of 0.0 is falsy but is a
        # genuine measurement (the lowest confidence), not a missing value.
        if (
            len(query) <= self.short_query_threshold
            and confidence is not None
            and confidence >= self.react_trigger_confidence_threshold
        ):
            return False
        if (
            complexity_score is not None
            and complexity_score >= self.react_trigger_complexity_score
        ):
            return True
        if confidence is not None and confidence < self.react_trigger_confidence_threshold:
            return True
        return False
@dataclass
class MetadataInferenceConfig:
    """Configuration for metadata inference and filter-mode selection."""

    # Inference toggle.
    enabled: bool = True
    # Confidence at or above this selects FilterMode.HARD.
    confidence_high_threshold: float = 0.8
    # Confidence at or above this (but below the high threshold) selects FilterMode.SOFT.
    confidence_low_threshold: float = 0.5
    # Not read by this class; presumably used by callers as a fallback mode.
    default_filter_mode: FilterMode = FilterMode.SOFT
    # Not read by this class; presumably a cache TTL in seconds for inference results.
    cache_ttl_seconds: int = 300

    def determine_filter_mode(self, confidence: float | None) -> FilterMode:
        """Map an inference confidence to a filter mode.

        Args:
            confidence: Inference confidence, or ``None`` when unavailable.

        Returns:
            ``FilterMode.HARD`` if confidence >= the high threshold,
            ``FilterMode.SOFT`` if confidence >= the low threshold,
            otherwise (including ``None``) ``FilterMode.NONE``.
        """
        if confidence is None:
            return FilterMode.NONE
        # Checked in order: the first threshold that is met wins.
        tiers = (
            (self.confidence_high_threshold, FilterMode.HARD),
            (self.confidence_low_threshold, FilterMode.SOFT),
        )
        for floor, mode in tiers:
            if confidence >= floor:
                return mode
        return FilterMode.NONE
@dataclass
class PipelineConfig:
    """Retrieval pipeline configuration.

    Field order is part of the positional constructor signature; keep it.
    """

    # Number of results requested (assumed from the name).
    top_k: int = 5
    # Minimum score for a hit (assumed from the name; confirm in pipeline).
    score_threshold: float = 0.01
    # Minimum number of hits (assumed from the name).
    min_hits: int = 1
    # Two-stage retrieval toggle.
    two_stage_enabled: bool = True
    # Expansion factor for the first stage (assumed; verify against pipeline code).
    two_stage_expand_factor: int = 10
    # Hybrid retrieval settings; a fresh instance per PipelineConfig.
    hybrid: HybridRetrievalConfig = field(default_factory=HybridRetrievalConfig)
@dataclass
class RetrievalStrategyConfig:
    """Top-level retrieval strategy configuration. [AC-AISVC-RES-01~15]"""

    active_strategy: StrategyType = StrategyType.DEFAULT
    grayscale: GrayscaleConfig = field(default_factory=GrayscaleConfig)
    pipeline: PipelineConfig = field(default_factory=PipelineConfig)
    reranker: RerankerConfig = field(default_factory=RerankerConfig)
    mode_router: ModeRouterConfig = field(default_factory=ModeRouterConfig)
    metadata_inference: MetadataInferenceConfig = field(default_factory=MetadataInferenceConfig)
    # Fresh dict per instance via the factory.
    performance_thresholds: dict[str, float] = field(default_factory=lambda: {
        "max_latency_ms": 2000.0,
        "min_success_rate": 0.95,
        "max_error_rate": 0.05,
    })

    def is_enhanced_enabled(self, tenant_id: str, user_id: str | None = None) -> bool:
        """Return whether the enhanced strategy applies to this request.

        The enhanced strategy applies when it is the active strategy.
        Otherwise the grayscale rollout decides.
        """
        if self.active_strategy == StrategyType.ENHANCED:
            return True
        return self.grayscale.should_use_enhanced(tenant_id, user_id)

    def to_dict(self) -> dict[str, Any]:
        """Serialize a summary of this configuration into a plain dict.

        Only a subset of the nested settings is included; for example,
        ``pipeline.hybrid`` and ``reranker.min_score_threshold`` are omitted.
        Mutable containers (the allowlist and the performance thresholds) are
        copied, so editing the returned dict cannot modify this config.
        """
        return {
            "active_strategy": self.active_strategy.value,
            "grayscale": {
                "enabled": self.grayscale.enabled,
                "percentage": self.grayscale.percentage,
                # Copy: returning the live list would let callers mutate config.
                "allowlist": list(self.grayscale.allowlist),
            },
            "pipeline": {
                "top_k": self.pipeline.top_k,
                "score_threshold": self.pipeline.score_threshold,
                "min_hits": self.pipeline.min_hits,
                "two_stage_enabled": self.pipeline.two_stage_enabled,
            },
            "reranker": {
                "enabled": self.reranker.enabled,
                "model": self.reranker.model,
                "top_k_after_rerank": self.reranker.top_k_after_rerank,
            },
            "mode_router": {
                "runtime_mode": self.mode_router.runtime_mode.value,
                "react_trigger_confidence_threshold": self.mode_router.react_trigger_confidence_threshold,
            },
            "metadata_inference": {
                "enabled": self.metadata_inference.enabled,
                "confidence_high_threshold": self.metadata_inference.confidence_high_threshold,
            },
            # Copy for the same aliasing reason as the allowlist.
            "performance_thresholds": dict(self.performance_thresholds),
        }
# Module-level singleton; created on first access by get_strategy_config().
_global_config: RetrievalStrategyConfig | None = None


def get_strategy_config() -> RetrievalStrategyConfig:
    """Return the module-wide strategy config.

    A default ``RetrievalStrategyConfig`` is created on first access.
    """
    global _global_config
    current = _global_config
    if current is None:
        current = RetrievalStrategyConfig()
        _global_config = current
    return current
def set_strategy_config(config: RetrievalStrategyConfig) -> None:
    """Replace the module-wide strategy config.

    Args:
        config: The instance that subsequent ``get_strategy_config()`` calls
            will return; it is stored as-is, not copied.
    """
    global _global_config
    _global_config = config