202 lines
6.4 KiB
Python
202 lines
6.4 KiB
Python
"""
|
|
Retrieval Strategy Configuration.
|
|
[AC-AISVC-RES-01~15] 检索策略配置模型。
|
|
"""
|
|
|
|
import random
|
|
from dataclasses import dataclass, field
|
|
from enum import Enum
|
|
from typing import Any
|
|
|
|
|
|
class StrategyType(str, Enum):
    """Strategy type.

    Members mix in ``str``, so each one compares equal to its raw string
    value.
    """

    # Value "default".
    DEFAULT = "default"

    # Value "enhanced".
    ENHANCED = "enhanced"
|
|
|
|
|
|
class RuntimeMode(str, Enum):
    """Runtime mode.

    Members mix in ``str``, so each one compares equal to its raw string
    value.
    """

    # Value "direct".
    DIRECT = "direct"

    # Value "react".
    REACT = "react"

    # Value "auto".
    AUTO = "auto"
|
|
|
|
|
|
class FilterMode(str, Enum):
    """Filter mode.

    Members mix in ``str``, so each one compares equal to its raw string
    value.
    """

    # Value "hard".
    HARD = "hard"

    # Value "soft".
    SOFT = "soft"

    # Value "none".
    NONE = "none"
|
|
|
|
|
|
@dataclass
|
|
class GrayscaleConfig:
|
|
"""灰度发布配置。【AC-AISVC-RES-03】"""
|
|
enabled: bool = False
|
|
percentage: float = 0.0
|
|
allowlist: list[str] = field(default_factory=list)
|
|
|
|
def should_use_enhanced(self, tenant_id: str, user_id: str | None = None) -> bool:
|
|
"""判断是否应该使用增强策略。"""
|
|
if not self.enabled:
|
|
return False
|
|
|
|
if tenant_id in self.allowlist or (user_id and user_id in self.allowlist):
|
|
return True
|
|
|
|
return random.random() * 100 < self.percentage
|
|
|
|
|
|
@dataclass
class HybridRetrievalConfig:
    """Hybrid retrieval configuration."""

    # Presumably the weight given to dense (vector) results -- confirm with the retriever.
    dense_weight: float = 0.7

    # Presumably the weight given to keyword results -- confirm with the retriever.
    keyword_weight: float = 0.3

    # Presumably the `k` constant of reciprocal rank fusion -- confirm with the retriever.
    rrf_k: int = 60

    # Keyword retrieval toggle; on by default.
    enable_keyword: bool = True

    # Presumably scales top_k for the keyword leg -- confirm with the retriever.
    keyword_top_k_multiplier: int = 2
|
|
|
|
|
|
@dataclass
class RerankerConfig:
    """Reranker configuration. [AC-AISVC-RES-08]"""

    # Reranking toggle; off by default.
    enabled: bool = False

    # Reranker model identifier.
    model: str = "cross-encoder"

    # Presumably how many results are kept after reranking -- confirm with the reranker.
    top_k_after_rerank: int = 5

    # Presumably the minimum rerank score a result must reach -- confirm with the reranker.
    min_score_threshold: float = 0.3
|
|
|
|
|
|
@dataclass
class ModeRouterConfig:
    """Mode router configuration. [AC-AISVC-RES-09~15]"""

    # Routing mode; REACT and DIRECT force the decision, AUTO uses the heuristics below.
    runtime_mode: RuntimeMode = RuntimeMode.DIRECT

    # Confidence below this value triggers ReAct in AUTO mode.
    react_trigger_confidence_threshold: float = 0.6

    # Complexity score at or above this value triggers ReAct in AUTO mode.
    react_trigger_complexity_score: float = 0.5

    # Not consulted by should_use_react(); presumably read by the ReAct runner -- confirm.
    react_max_steps: int = 5

    # Not consulted by should_use_react(); presumably read by the caller -- confirm.
    direct_fallback_on_low_confidence: bool = True

    # Queries of at most this many characters count as "short".
    short_query_threshold: int = 20

    def should_use_react(
        self,
        query: str,
        confidence: float | None = None,
        complexity_score: float | None = None,
    ) -> bool:
        """Decide whether the ReAct mode should be used. [AC-AISVC-RES-11~13]

        Args:
            query: The user query; only its length is inspected.
            confidence: Optional confidence score; ``None`` means "unknown".
            complexity_score: Optional complexity score; ``None`` means
                "unknown".

        Returns:
            ``True`` if ``runtime_mode`` is REACT, ``False`` if it is DIRECT.
            Otherwise (AUTO), checked in this order:

            1. short query with confidence at or above the threshold -> ``False``
            2. complexity score at or above its threshold -> ``True``
            3. confidence below the threshold -> ``True``
            4. anything else -> ``False``
        """
        if self.runtime_mode == RuntimeMode.REACT:
            return True
        if self.runtime_mode == RuntimeMode.DIRECT:
            return False

        # Compare against None explicitly: 0.0 is a valid (lowest) score and
        # must not be mistaken for "no score supplied".
        has_confidence = confidence is not None
        has_complexity = complexity_score is not None

        is_short = len(query) <= self.short_query_threshold
        if (
            is_short
            and has_confidence
            and confidence >= self.react_trigger_confidence_threshold
        ):
            return False

        if has_complexity and complexity_score >= self.react_trigger_complexity_score:
            return True

        if has_confidence and confidence < self.react_trigger_confidence_threshold:
            return True

        return False
|
|
|
|
|
|
@dataclass
class MetadataInferenceConfig:
    """Metadata inference configuration."""

    # Metadata inference toggle; not consulted by determine_filter_mode().
    enabled: bool = True

    # Confidence at or above this value selects FilterMode.HARD.
    confidence_high_threshold: float = 0.8

    # Confidence at or above this value (but below the high one) selects FilterMode.SOFT.
    confidence_low_threshold: float = 0.5

    # Not consulted by determine_filter_mode(); presumably used by callers -- confirm.
    default_filter_mode: FilterMode = FilterMode.SOFT

    # Presumably the lifetime of cached inference results, in seconds -- confirm.
    cache_ttl_seconds: int = 300

    def determine_filter_mode(self, confidence: float | None) -> FilterMode:
        """Map a confidence value to a filter mode.

        Args:
            confidence: Inference confidence, or ``None`` when unavailable.

        Returns:
            ``FilterMode.HARD`` at or above ``confidence_high_threshold``,
            ``FilterMode.SOFT`` at or above ``confidence_low_threshold``,
            and ``FilterMode.NONE`` otherwise (including ``None``).
        """
        if confidence is not None:
            if confidence >= self.confidence_high_threshold:
                return FilterMode.HARD
            if confidence >= self.confidence_low_threshold:
                return FilterMode.SOFT

        # Missing or too-low confidence: apply no filter.
        return FilterMode.NONE
|
|
|
|
|
|
@dataclass
class PipelineConfig:
    """Pipeline configuration."""

    # Presumably the number of results to return -- confirm with the pipeline.
    top_k: int = 5

    # Presumably the minimum score a hit must reach -- confirm with the pipeline.
    score_threshold: float = 0.01

    # Presumably the minimum number of hits required -- confirm with the pipeline.
    min_hits: int = 1

    # Two-stage retrieval toggle; on by default.
    two_stage_enabled: bool = True

    # Presumably the candidate expansion factor for stage one -- confirm with the pipeline.
    two_stage_expand_factor: int = 10

    # Nested hybrid retrieval settings; each instance gets its own object.
    hybrid: HybridRetrievalConfig = field(default_factory=HybridRetrievalConfig)
|
|
|
|
|
|
@dataclass
class RetrievalStrategyConfig:
    """Top-level retrieval strategy configuration. [AC-AISVC-RES-01~15]"""

    # ENHANCED forces the enhanced strategy for everyone; DEFAULT defers to grayscale.
    active_strategy: StrategyType = StrategyType.DEFAULT

    # Nested sub-configurations; each instance gets its own objects.
    grayscale: GrayscaleConfig = field(default_factory=GrayscaleConfig)
    pipeline: PipelineConfig = field(default_factory=PipelineConfig)
    reranker: RerankerConfig = field(default_factory=RerankerConfig)
    mode_router: ModeRouterConfig = field(default_factory=ModeRouterConfig)
    metadata_inference: MetadataInferenceConfig = field(default_factory=MetadataInferenceConfig)

    # Named performance limits; a fresh dict is built per instance.
    performance_thresholds: dict[str, float] = field(default_factory=lambda: {
        "max_latency_ms": 2000.0,
        "min_success_rate": 0.95,
        "max_error_rate": 0.05,
    })

    def is_enhanced_enabled(self, tenant_id: str, user_id: str | None = None) -> bool:
        """Decide whether the enhanced strategy is enabled for a caller.

        Args:
            tenant_id: Tenant identifier forwarded to the grayscale check.
            user_id: Optional user identifier forwarded to the grayscale check.

        Returns:
            ``True`` when ``active_strategy`` is ENHANCED; otherwise whatever
            ``grayscale.should_use_enhanced`` decides.
        """
        if self.active_strategy == StrategyType.ENHANCED:
            return True
        return self.grayscale.should_use_enhanced(tenant_id, user_id)

    def to_dict(self) -> dict[str, Any]:
        """Convert the configuration to a plain dictionary.

        Only a subset of the fields is exported (for example, the hybrid
        settings and several router and inference fields are omitted).

        Returns:
            A new dictionary. The ``allowlist`` list and the
            ``performance_thresholds`` dict are shallow copies, so mutating
            the result does not alter this configuration.
        """
        grayscale = self.grayscale
        pipeline = self.pipeline
        reranker = self.reranker
        router = self.mode_router
        inference = self.metadata_inference

        return {
            "active_strategy": self.active_strategy.value,
            "grayscale": {
                "enabled": grayscale.enabled,
                "percentage": grayscale.percentage,
                # Copy so callers cannot edit the live allowlist through the export.
                "allowlist": list(grayscale.allowlist),
            },
            "pipeline": {
                "top_k": pipeline.top_k,
                "score_threshold": pipeline.score_threshold,
                "min_hits": pipeline.min_hits,
                "two_stage_enabled": pipeline.two_stage_enabled,
            },
            "reranker": {
                "enabled": reranker.enabled,
                "model": reranker.model,
                "top_k_after_rerank": reranker.top_k_after_rerank,
            },
            "mode_router": {
                "runtime_mode": router.runtime_mode.value,
                "react_trigger_confidence_threshold": router.react_trigger_confidence_threshold,
            },
            "metadata_inference": {
                "enabled": inference.enabled,
                "confidence_high_threshold": inference.confidence_high_threshold,
            },
            # Copy for the same reason as the allowlist above.
            "performance_thresholds": dict(self.performance_thresholds),
        }
|
|
|
|
|
|
# Module-level slot for the shared strategy configuration. Starts as None;
# filled on first get_strategy_config() call or by set_strategy_config().
_global_config: RetrievalStrategyConfig | None = None
|
|
|
|
|
|
def get_strategy_config() -> RetrievalStrategyConfig:
    """Return the global strategy configuration.

    If no configuration has been stored yet, a default
    ``RetrievalStrategyConfig()`` is created, stored in the module-level
    slot, and returned; later calls return that same object.
    """
    global _global_config

    if _global_config is not None:
        return _global_config

    # First access: build and remember the default configuration.
    _global_config = RetrievalStrategyConfig()
    return _global_config
|
|
|
|
|
|
def set_strategy_config(config: RetrievalStrategyConfig) -> None:
    """Replace the global strategy configuration.

    Args:
        config: The configuration object that subsequent
            ``get_strategy_config()`` calls will return.
    """
    global _global_config

    # Store the object itself; no copy is made.
    _global_config = config
|