[AC-AISVC-RES-01~15] feat(retrieval): 实现检索策略Pipeline模块
- 新增策略配置模型 (config.py) - GrayscaleConfig: 灰度发布配置 - ModeRouterConfig: 模式路由配置 - MetadataInferenceConfig: 元数据推断配置 - 新增 Pipeline 实现 - DefaultPipeline: 复用现有 OptimizedRetriever 逻辑 - EnhancedPipeline: Dense + Keyword + RRF 组合检索 - 新增路由器 - StrategyRouter: 策略路由器(default/enhanced) - ModeRouter: 模式路由器(direct/react/auto) - 新增 RollbackManager: 回退与审计管理器 - 新增 MetadataInferenceService: 元数据推断统一入口 - 新增单元测试 (51 passed)
This commit is contained in:
parent
9f28498b97
commit
7027097513
|
|
@ -0,0 +1,102 @@
|
||||||
|
"""
|
||||||
|
Retrieval Strategy Module for AI Service.
|
||||||
|
[AC-AISVC-RES-01~15] 策略化检索与嵌入模块。
|
||||||
|
|
||||||
|
核心组件:
|
||||||
|
- RetrievalStrategyConfig: 策略配置模型
|
||||||
|
- BasePipeline: Pipeline 抽象基类
|
||||||
|
- DefaultPipeline: 默认策略(复用现有逻辑)
|
||||||
|
- EnhancedPipeline: 增强策略(新端到端流程)
|
||||||
|
- MetadataInferenceService: 元数据推断统一入口
|
||||||
|
- StrategyRouter: 策略路由器
|
||||||
|
- ModeRouter: 模式路由器(direct/react/auto)
|
||||||
|
- RollbackManager: 回退管理器
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.services.retrieval.strategy.config import (
|
||||||
|
FilterMode,
|
||||||
|
GrayscaleConfig,
|
||||||
|
HybridRetrievalConfig,
|
||||||
|
MetadataInferenceConfig,
|
||||||
|
ModeRouterConfig,
|
||||||
|
PipelineConfig,
|
||||||
|
RerankerConfig,
|
||||||
|
RetrievalStrategyConfig,
|
||||||
|
RuntimeMode,
|
||||||
|
StrategyType,
|
||||||
|
get_strategy_config,
|
||||||
|
set_strategy_config,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.default_pipeline import (
|
||||||
|
DefaultPipeline,
|
||||||
|
get_default_pipeline,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.enhanced_pipeline import (
|
||||||
|
EnhancedPipeline,
|
||||||
|
get_enhanced_pipeline,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.metadata_inference import (
|
||||||
|
InferenceContext,
|
||||||
|
InferenceResult,
|
||||||
|
MetadataInferenceService,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.mode_router import (
|
||||||
|
ModeDecision,
|
||||||
|
ModeRouter,
|
||||||
|
get_mode_router,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.pipeline_base import (
|
||||||
|
BasePipeline,
|
||||||
|
MetadataFilterResult,
|
||||||
|
PipelineContext,
|
||||||
|
PipelineResult,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.rollback_manager import (
|
||||||
|
AuditLog,
|
||||||
|
RollbackManager,
|
||||||
|
RollbackResult,
|
||||||
|
RollbackTrigger,
|
||||||
|
get_rollback_manager,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.strategy_router import (
|
||||||
|
RoutingDecision,
|
||||||
|
StrategyRouter,
|
||||||
|
get_strategy_router,
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BasePipeline",
|
||||||
|
"PipelineContext",
|
||||||
|
"PipelineResult",
|
||||||
|
"MetadataFilterResult",
|
||||||
|
"DefaultPipeline",
|
||||||
|
"get_default_pipeline",
|
||||||
|
"EnhancedPipeline",
|
||||||
|
"get_enhanced_pipeline",
|
||||||
|
"RetrievalStrategyConfig",
|
||||||
|
"GrayscaleConfig",
|
||||||
|
"PipelineConfig",
|
||||||
|
"RerankerConfig",
|
||||||
|
"ModeRouterConfig",
|
||||||
|
"HybridRetrievalConfig",
|
||||||
|
"MetadataInferenceConfig",
|
||||||
|
"StrategyType",
|
||||||
|
"FilterMode",
|
||||||
|
"RuntimeMode",
|
||||||
|
"get_strategy_config",
|
||||||
|
"set_strategy_config",
|
||||||
|
"MetadataInferenceService",
|
||||||
|
"InferenceContext",
|
||||||
|
"InferenceResult",
|
||||||
|
"StrategyRouter",
|
||||||
|
"RoutingDecision",
|
||||||
|
"get_strategy_router",
|
||||||
|
"ModeRouter",
|
||||||
|
"ModeDecision",
|
||||||
|
"get_mode_router",
|
||||||
|
"RollbackManager",
|
||||||
|
"RollbackResult",
|
||||||
|
"RollbackTrigger",
|
||||||
|
"AuditLog",
|
||||||
|
"get_rollback_manager",
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,201 @@
|
||||||
|
"""
|
||||||
|
Retrieval Strategy Configuration.
|
||||||
|
[AC-AISVC-RES-01~15] 检索策略配置模型。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import random
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class StrategyType(str, Enum):
|
||||||
|
"""策略类型。"""
|
||||||
|
DEFAULT = "default"
|
||||||
|
ENHANCED = "enhanced"
|
||||||
|
|
||||||
|
|
||||||
|
class RuntimeMode(str, Enum):
|
||||||
|
"""运行时模式。"""
|
||||||
|
DIRECT = "direct"
|
||||||
|
REACT = "react"
|
||||||
|
AUTO = "auto"
|
||||||
|
|
||||||
|
|
||||||
|
class FilterMode(str, Enum):
|
||||||
|
"""过滤模式。"""
|
||||||
|
HARD = "hard"
|
||||||
|
SOFT = "soft"
|
||||||
|
NONE = "none"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class GrayscaleConfig:
|
||||||
|
"""灰度发布配置。【AC-AISVC-RES-03】"""
|
||||||
|
enabled: bool = False
|
||||||
|
percentage: float = 0.0
|
||||||
|
allowlist: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
def should_use_enhanced(self, tenant_id: str, user_id: str | None = None) -> bool:
|
||||||
|
"""判断是否应该使用增强策略。"""
|
||||||
|
if not self.enabled:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if tenant_id in self.allowlist or (user_id and user_id in self.allowlist):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return random.random() * 100 < self.percentage
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class HybridRetrievalConfig:
|
||||||
|
"""混合检索配置。"""
|
||||||
|
dense_weight: float = 0.7
|
||||||
|
keyword_weight: float = 0.3
|
||||||
|
rrf_k: int = 60
|
||||||
|
enable_keyword: bool = True
|
||||||
|
keyword_top_k_multiplier: int = 2
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RerankerConfig:
|
||||||
|
"""重排器配置。【AC-AISVC-RES-08】"""
|
||||||
|
enabled: bool = False
|
||||||
|
model: str = "cross-encoder"
|
||||||
|
top_k_after_rerank: int = 5
|
||||||
|
min_score_threshold: float = 0.3
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModeRouterConfig:
|
||||||
|
"""模式路由配置。【AC-AISVC-RES-09~15】"""
|
||||||
|
runtime_mode: RuntimeMode = RuntimeMode.DIRECT
|
||||||
|
react_trigger_confidence_threshold: float = 0.6
|
||||||
|
react_trigger_complexity_score: float = 0.5
|
||||||
|
react_max_steps: int = 5
|
||||||
|
direct_fallback_on_low_confidence: bool = True
|
||||||
|
short_query_threshold: int = 20
|
||||||
|
|
||||||
|
def should_use_react(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
confidence: float | None = None,
|
||||||
|
complexity_score: float | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""判断是否应该使用 ReAct 模式。【AC-AISVC-RES-11~13】"""
|
||||||
|
if self.runtime_mode == RuntimeMode.REACT:
|
||||||
|
return True
|
||||||
|
if self.runtime_mode == RuntimeMode.DIRECT:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if len(query) <= self.short_query_threshold and confidence and confidence >= self.react_trigger_confidence_threshold:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if complexity_score and complexity_score >= self.react_trigger_complexity_score:
|
||||||
|
return True
|
||||||
|
|
||||||
|
if confidence and confidence < self.react_trigger_confidence_threshold:
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MetadataInferenceConfig:
|
||||||
|
"""元数据推断配置。"""
|
||||||
|
enabled: bool = True
|
||||||
|
confidence_high_threshold: float = 0.8
|
||||||
|
confidence_low_threshold: float = 0.5
|
||||||
|
default_filter_mode: FilterMode = FilterMode.SOFT
|
||||||
|
cache_ttl_seconds: int = 300
|
||||||
|
|
||||||
|
def determine_filter_mode(self, confidence: float | None) -> FilterMode:
|
||||||
|
"""根据置信度确定过滤模式。"""
|
||||||
|
if confidence is None:
|
||||||
|
return FilterMode.NONE
|
||||||
|
if confidence >= self.confidence_high_threshold:
|
||||||
|
return FilterMode.HARD
|
||||||
|
if confidence >= self.confidence_low_threshold:
|
||||||
|
return FilterMode.SOFT
|
||||||
|
return FilterMode.NONE
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineConfig:
|
||||||
|
"""Pipeline 配置。"""
|
||||||
|
top_k: int = 5
|
||||||
|
score_threshold: float = 0.01
|
||||||
|
min_hits: int = 1
|
||||||
|
two_stage_enabled: bool = True
|
||||||
|
two_stage_expand_factor: int = 10
|
||||||
|
hybrid: HybridRetrievalConfig = field(default_factory=HybridRetrievalConfig)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RetrievalStrategyConfig:
|
||||||
|
"""检索策略顶层配置。【AC-AISVC-RES-01~15】"""
|
||||||
|
active_strategy: StrategyType = StrategyType.DEFAULT
|
||||||
|
grayscale: GrayscaleConfig = field(default_factory=GrayscaleConfig)
|
||||||
|
pipeline: PipelineConfig = field(default_factory=PipelineConfig)
|
||||||
|
reranker: RerankerConfig = field(default_factory=RerankerConfig)
|
||||||
|
mode_router: ModeRouterConfig = field(default_factory=ModeRouterConfig)
|
||||||
|
metadata_inference: MetadataInferenceConfig = field(default_factory=MetadataInferenceConfig)
|
||||||
|
performance_thresholds: dict[str, float] = field(default_factory=lambda: {
|
||||||
|
"max_latency_ms": 2000.0,
|
||||||
|
"min_success_rate": 0.95,
|
||||||
|
"max_error_rate": 0.05,
|
||||||
|
})
|
||||||
|
|
||||||
|
def is_enhanced_enabled(self, tenant_id: str, user_id: str | None = None) -> bool:
|
||||||
|
"""判断是否启用增强策略。"""
|
||||||
|
if self.active_strategy == StrategyType.ENHANCED:
|
||||||
|
return True
|
||||||
|
return self.grayscale.should_use_enhanced(tenant_id, user_id)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict[str, Any]:
|
||||||
|
"""转换为字典。"""
|
||||||
|
return {
|
||||||
|
"active_strategy": self.active_strategy.value,
|
||||||
|
"grayscale": {
|
||||||
|
"enabled": self.grayscale.enabled,
|
||||||
|
"percentage": self.grayscale.percentage,
|
||||||
|
"allowlist": self.grayscale.allowlist,
|
||||||
|
},
|
||||||
|
"pipeline": {
|
||||||
|
"top_k": self.pipeline.top_k,
|
||||||
|
"score_threshold": self.pipeline.score_threshold,
|
||||||
|
"min_hits": self.pipeline.min_hits,
|
||||||
|
"two_stage_enabled": self.pipeline.two_stage_enabled,
|
||||||
|
},
|
||||||
|
"reranker": {
|
||||||
|
"enabled": self.reranker.enabled,
|
||||||
|
"model": self.reranker.model,
|
||||||
|
"top_k_after_rerank": self.reranker.top_k_after_rerank,
|
||||||
|
},
|
||||||
|
"mode_router": {
|
||||||
|
"runtime_mode": self.mode_router.runtime_mode.value,
|
||||||
|
"react_trigger_confidence_threshold": self.mode_router.react_trigger_confidence_threshold,
|
||||||
|
},
|
||||||
|
"metadata_inference": {
|
||||||
|
"enabled": self.metadata_inference.enabled,
|
||||||
|
"confidence_high_threshold": self.metadata_inference.confidence_high_threshold,
|
||||||
|
},
|
||||||
|
"performance_thresholds": self.performance_thresholds,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
_global_config: RetrievalStrategyConfig | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_strategy_config() -> RetrievalStrategyConfig:
|
||||||
|
"""获取全局策略配置。"""
|
||||||
|
global _global_config
|
||||||
|
if _global_config is None:
|
||||||
|
_global_config = RetrievalStrategyConfig()
|
||||||
|
return _global_config
|
||||||
|
|
||||||
|
|
||||||
|
def set_strategy_config(config: RetrievalStrategyConfig) -> None:
|
||||||
|
"""设置全局策略配置。"""
|
||||||
|
global _global_config
|
||||||
|
_global_config = config
|
||||||
|
|
@ -0,0 +1,117 @@
|
||||||
|
"""
|
||||||
|
Default Pipeline.
|
||||||
|
[AC-AISVC-RES-01] 默认策略 Pipeline,复用现有逻辑。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.services.retrieval.base import RetrievalContext, RetrievalResult
|
||||||
|
from app.services.retrieval.optimized_retriever import OptimizedRetriever, get_optimized_retriever
|
||||||
|
from app.services.retrieval.strategy.config import PipelineConfig
|
||||||
|
from app.services.retrieval.strategy.pipeline_base import (
|
||||||
|
BasePipeline,
|
||||||
|
PipelineContext,
|
||||||
|
PipelineResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultPipeline(BasePipeline):
|
||||||
|
"""
|
||||||
|
默认策略 Pipeline。【AC-AISVC-RES-01】
|
||||||
|
|
||||||
|
复用现有 OptimizedRetriever 逻辑,保持线上行为不变。
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PipelineConfig | None = None,
|
||||||
|
optimized_retriever: OptimizedRetriever | None = None,
|
||||||
|
):
|
||||||
|
self._config = config or PipelineConfig()
|
||||||
|
self._optimized_retriever = optimized_retriever
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "default_pipeline"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "默认检索策略,复用现有 OptimizedRetriever 逻辑。"
|
||||||
|
|
||||||
|
async def _get_retriever(self) -> OptimizedRetriever:
|
||||||
|
if self._optimized_retriever is None:
|
||||||
|
self._optimized_retriever = await get_optimized_retriever()
|
||||||
|
return self._optimized_retriever
|
||||||
|
|
||||||
|
async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
|
||||||
|
"""执行默认检索流程。【AC-AISVC-RES-01】"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[DefaultPipeline] Starting retrieval: tenant={ctx.tenant_id}, "
|
||||||
|
f"query={ctx.query[:50]}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
retriever = await self._get_retriever()
|
||||||
|
|
||||||
|
metadata_filter = None
|
||||||
|
if ctx.metadata_filter:
|
||||||
|
metadata_filter = ctx.metadata_filter.filter_dict
|
||||||
|
|
||||||
|
retrieval_ctx = RetrievalContext(
|
||||||
|
tenant_id=ctx.tenant_id,
|
||||||
|
query=ctx.query,
|
||||||
|
session_id=ctx.session_id,
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
|
kb_ids=ctx.kb_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await retriever.retrieve(retrieval_ctx)
|
||||||
|
|
||||||
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[DefaultPipeline] Retrieval completed: hits={len(result.hits)}, "
|
||||||
|
f"latency_ms={latency_ms:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return PipelineResult(
|
||||||
|
retrieval_result=result,
|
||||||
|
pipeline_name=self.name,
|
||||||
|
metadata_filter_applied=metadata_filter is not None,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
diagnostics={
|
||||||
|
"retriever": "OptimizedRetriever",
|
||||||
|
**(result.diagnostics or {}),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
logger.error(f"[DefaultPipeline] Retrieval error: {e}", exc_info=True)
|
||||||
|
return self._create_empty_result(ctx, error=str(e), latency_ms=latency_ms)
|
||||||
|
|
||||||
|
async def health_check(self) -> bool:
|
||||||
|
"""健康检查。"""
|
||||||
|
try:
|
||||||
|
retriever = await self._get_retriever()
|
||||||
|
return await retriever.health_check()
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[DefaultPipeline] Health check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
_default_pipeline: DefaultPipeline | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_default_pipeline() -> DefaultPipeline:
|
||||||
|
"""获取 DefaultPipeline 单例。"""
|
||||||
|
global _default_pipeline
|
||||||
|
if _default_pipeline is None:
|
||||||
|
_default_pipeline = DefaultPipeline()
|
||||||
|
return _default_pipeline
|
||||||
|
|
@ -0,0 +1,364 @@
|
||||||
|
"""
|
||||||
|
Enhanced Pipeline.
|
||||||
|
[AC-AISVC-RES-02] 增强策略 Pipeline,新端到端流程。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from app.core.config import get_settings
|
||||||
|
from app.core.qdrant_client import QdrantClient, get_qdrant_client
|
||||||
|
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
|
||||||
|
from app.services.retrieval.base import RetrievalHit, RetrievalResult
|
||||||
|
from app.services.retrieval.optimized_retriever import RRFCombiner
|
||||||
|
from app.services.retrieval.strategy.config import (
|
||||||
|
HybridRetrievalConfig,
|
||||||
|
PipelineConfig,
|
||||||
|
RerankerConfig,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.pipeline_base import (
|
||||||
|
BasePipeline,
|
||||||
|
PipelineContext,
|
||||||
|
PipelineResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
settings = get_settings()
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RetrievalCandidate:
|
||||||
|
"""检索候选结果。"""
|
||||||
|
id: str
|
||||||
|
text: str
|
||||||
|
score: float
|
||||||
|
vector_score: float = 0.0
|
||||||
|
keyword_score: float = 0.0
|
||||||
|
metadata: dict[str, Any] = None
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
if self.metadata is None:
|
||||||
|
self.metadata = {}
|
||||||
|
|
||||||
|
|
||||||
|
class EnhancedPipeline(BasePipeline):
|
||||||
|
"""
|
||||||
|
增强策略 Pipeline。【AC-AISVC-RES-02】
|
||||||
|
|
||||||
|
新端到端流程:
|
||||||
|
1. Dense 向量检索
|
||||||
|
2. Keyword 关键词检索
|
||||||
|
3. RRF 融合排序
|
||||||
|
4. 可选重排
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: PipelineConfig | None = None,
|
||||||
|
reranker_config: RerankerConfig | None = None,
|
||||||
|
qdrant_client: QdrantClient | None = None,
|
||||||
|
):
|
||||||
|
self._config = config or PipelineConfig()
|
||||||
|
self._reranker_config = reranker_config or RerankerConfig()
|
||||||
|
self._qdrant_client = qdrant_client
|
||||||
|
self._rrf_combiner = RRFCombiner(k=self._config.hybrid.rrf_k)
|
||||||
|
self._embedding_provider: NomicEmbeddingProvider | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
return "enhanced_pipeline"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def description(self) -> str:
|
||||||
|
return "增强检索策略,支持 Dense + Keyword + RRF 组合检索。"
|
||||||
|
|
||||||
|
async def _get_client(self) -> QdrantClient:
|
||||||
|
if self._qdrant_client is None:
|
||||||
|
self._qdrant_client = await get_qdrant_client()
|
||||||
|
return self._qdrant_client
|
||||||
|
|
||||||
|
async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
|
||||||
|
if self._embedding_provider is None:
|
||||||
|
from app.services.embedding.factory import get_embedding_config_manager
|
||||||
|
manager = get_embedding_config_manager()
|
||||||
|
provider = await manager.get_provider()
|
||||||
|
if isinstance(provider, NomicEmbeddingProvider):
|
||||||
|
self._embedding_provider = provider
|
||||||
|
else:
|
||||||
|
self._embedding_provider = NomicEmbeddingProvider(
|
||||||
|
base_url=settings.ollama_base_url,
|
||||||
|
model=settings.ollama_embedding_model,
|
||||||
|
dimension=settings.qdrant_vector_size,
|
||||||
|
)
|
||||||
|
return self._embedding_provider
|
||||||
|
|
||||||
|
async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
|
||||||
|
"""执行增强检索流程。【AC-AISVC-RES-02】"""
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[EnhancedPipeline] Starting retrieval: tenant={ctx.tenant_id}, "
|
||||||
|
f"query={ctx.query[:50]}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider = await self._get_embedding_provider()
|
||||||
|
embedding_result = await provider.embed_query(ctx.query)
|
||||||
|
|
||||||
|
candidates = await self._hybrid_retrieve(
|
||||||
|
tenant_id=ctx.tenant_id,
|
||||||
|
query=ctx.query,
|
||||||
|
embedding_result=embedding_result,
|
||||||
|
metadata_filter=ctx.metadata_filter.filter_dict if ctx.metadata_filter else None,
|
||||||
|
kb_ids=ctx.kb_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._reranker_config.enabled and ctx.use_reranker:
|
||||||
|
candidates = await self._rerank(
|
||||||
|
candidates=candidates,
|
||||||
|
query=ctx.query,
|
||||||
|
)
|
||||||
|
|
||||||
|
top_k = self._config.top_k
|
||||||
|
final_candidates = candidates[:top_k]
|
||||||
|
|
||||||
|
hits = [
|
||||||
|
RetrievalHit(
|
||||||
|
text=c.text,
|
||||||
|
score=c.score,
|
||||||
|
source=self.name,
|
||||||
|
metadata=c.metadata,
|
||||||
|
)
|
||||||
|
for c in final_candidates
|
||||||
|
if c.score >= self._config.score_threshold
|
||||||
|
]
|
||||||
|
|
||||||
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[EnhancedPipeline] Retrieval completed: hits={len(hits)}, "
|
||||||
|
f"latency_ms={latency_ms:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
result = RetrievalResult(
|
||||||
|
hits=hits,
|
||||||
|
diagnostics={
|
||||||
|
"total_candidates": len(candidates),
|
||||||
|
"after_rerank": self._reranker_config.enabled and ctx.use_reranker,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return PipelineResult(
|
||||||
|
retrieval_result=result,
|
||||||
|
pipeline_name=self.name,
|
||||||
|
used_reranker=self._reranker_config.enabled and ctx.use_reranker,
|
||||||
|
metadata_filter_applied=ctx.metadata_filter is not None,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
diagnostics={
|
||||||
|
"dense_weight": self._config.hybrid.dense_weight,
|
||||||
|
"keyword_weight": self._config.hybrid.keyword_weight,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
logger.error(f"[EnhancedPipeline] Retrieval error: {e}", exc_info=True)
|
||||||
|
return self._create_empty_result(ctx, error=str(e), latency_ms=latency_ms)
|
||||||
|
|
||||||
|
async def _hybrid_retrieve(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
query: str,
|
||||||
|
embedding_result: Any,
|
||||||
|
metadata_filter: dict[str, Any] | None = None,
|
||||||
|
kb_ids: list[str] | None = None,
|
||||||
|
) -> list[RetrievalCandidate]:
|
||||||
|
"""混合检索:Dense + Keyword + RRF。"""
|
||||||
|
client = await self._get_client()
|
||||||
|
top_k = self._config.top_k
|
||||||
|
expand_factor = self._config.hybrid.keyword_top_k_multiplier
|
||||||
|
|
||||||
|
vector_task = self._dense_search(
|
||||||
|
client=client,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
embedding=embedding_result.embedding_full,
|
||||||
|
top_k=top_k * expand_factor,
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
|
kb_ids=kb_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
keyword_task = self._keyword_search(
|
||||||
|
client=client,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
query=query,
|
||||||
|
top_k=top_k * expand_factor,
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
|
kb_ids=kb_ids,
|
||||||
|
) if self._config.hybrid.enable_keyword else asyncio.sleep(0, result=[])
|
||||||
|
|
||||||
|
vector_results, keyword_results = await asyncio.gather(
|
||||||
|
vector_task, keyword_task, return_exceptions=True
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(vector_results, Exception):
|
||||||
|
logger.warning(f"[EnhancedPipeline] Dense search failed: {vector_results}")
|
||||||
|
vector_results = []
|
||||||
|
|
||||||
|
if isinstance(keyword_results, Exception):
|
||||||
|
logger.warning(f"[EnhancedPipeline] Keyword search failed: {keyword_results}")
|
||||||
|
keyword_results = []
|
||||||
|
|
||||||
|
combined = self._rrf_combiner.combine(
|
||||||
|
vector_results=vector_results,
|
||||||
|
bm25_results=keyword_results,
|
||||||
|
vector_weight=self._config.hybrid.dense_weight,
|
||||||
|
bm25_weight=self._config.hybrid.keyword_weight,
|
||||||
|
)
|
||||||
|
|
||||||
|
candidates = []
|
||||||
|
for item in combined:
|
||||||
|
candidates.append(RetrievalCandidate(
|
||||||
|
id=item.get("id", ""),
|
||||||
|
text=item.get("payload", {}).get("text", ""),
|
||||||
|
score=item.get("score", 0.0),
|
||||||
|
vector_score=item.get("vector_score", 0.0),
|
||||||
|
keyword_score=item.get("bm25_score", 0.0),
|
||||||
|
metadata=item.get("payload", {}),
|
||||||
|
))
|
||||||
|
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
async def _dense_search(
|
||||||
|
self,
|
||||||
|
client: QdrantClient,
|
||||||
|
tenant_id: str,
|
||||||
|
embedding: list[float],
|
||||||
|
top_k: int,
|
||||||
|
metadata_filter: dict[str, Any] | None = None,
|
||||||
|
kb_ids: list[str] | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Dense 向量检索。"""
|
||||||
|
try:
|
||||||
|
results = await client.search(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
query_vector=embedding,
|
||||||
|
limit=top_k,
|
||||||
|
vector_name="full",
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
|
kb_ids=kb_ids,
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[EnhancedPipeline] Dense search error: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _keyword_search(
|
||||||
|
self,
|
||||||
|
client: QdrantClient,
|
||||||
|
tenant_id: str,
|
||||||
|
query: str,
|
||||||
|
top_k: int,
|
||||||
|
metadata_filter: dict[str, Any] | None = None,
|
||||||
|
kb_ids: list[str] | None = None,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Keyword 关键词检索。"""
|
||||||
|
try:
|
||||||
|
qdrant = await client.get_client()
|
||||||
|
collection_name = client.get_collection_name(tenant_id)
|
||||||
|
|
||||||
|
query_terms = set(re.findall(r'\w+', query.lower()))
|
||||||
|
|
||||||
|
results = await qdrant.scroll(
|
||||||
|
collection_name=collection_name,
|
||||||
|
limit=top_k * 3,
|
||||||
|
with_payload=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
scored_results = []
|
||||||
|
for point in results[0]:
|
||||||
|
text = point.payload.get("text", "").lower()
|
||||||
|
text_terms = set(re.findall(r'\w+', text))
|
||||||
|
overlap = len(query_terms & text_terms)
|
||||||
|
|
||||||
|
if overlap > 0:
|
||||||
|
score = overlap / (len(query_terms) + len(text_terms) - overlap)
|
||||||
|
scored_results.append({
|
||||||
|
"id": str(point.id),
|
||||||
|
"score": score,
|
||||||
|
"payload": point.payload or {},
|
||||||
|
})
|
||||||
|
|
||||||
|
scored_results.sort(key=lambda x: x["score"], reverse=True)
|
||||||
|
return scored_results[:top_k]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.debug(f"[EnhancedPipeline] Keyword search failed: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _rerank(
|
||||||
|
self,
|
||||||
|
candidates: list[RetrievalCandidate],
|
||||||
|
query: str,
|
||||||
|
) -> list[RetrievalCandidate]:
|
||||||
|
"""可选重排。"""
|
||||||
|
if not candidates:
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
try:
|
||||||
|
provider = await self._get_embedding_provider()
|
||||||
|
query_embedding = await provider.embed_query(query)
|
||||||
|
|
||||||
|
reranked = []
|
||||||
|
for candidate in candidates:
|
||||||
|
candidate_text = candidate.text[:500]
|
||||||
|
if candidate_text:
|
||||||
|
candidate_embedding = await provider.embed(candidate_text)
|
||||||
|
similarity = self._cosine_similarity(
|
||||||
|
query_embedding.embedding_full,
|
||||||
|
candidate_embedding,
|
||||||
|
)
|
||||||
|
candidate.score = similarity
|
||||||
|
|
||||||
|
if candidate.score >= self._reranker_config.min_score_threshold:
|
||||||
|
reranked.append(candidate)
|
||||||
|
|
||||||
|
reranked.sort(key=lambda x: x.score, reverse=True)
|
||||||
|
return reranked[:self._reranker_config.top_k_after_rerank]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[EnhancedPipeline] Rerank failed: {e}")
|
||||||
|
return candidates
|
||||||
|
|
||||||
|
def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
|
||||||
|
"""计算余弦相似度。"""
|
||||||
|
a = np.array(vec1)
|
||||||
|
b = np.array(vec2)
|
||||||
|
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
|
||||||
|
|
||||||
|
async def health_check(self) -> bool:
|
||||||
|
"""健康检查。"""
|
||||||
|
try:
|
||||||
|
client = await self._get_client()
|
||||||
|
qdrant = await client.get_client()
|
||||||
|
await qdrant.get_collections()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"[EnhancedPipeline] Health check failed: {e}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
_enhanced_pipeline: EnhancedPipeline | None = None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_enhanced_pipeline() -> EnhancedPipeline:
|
||||||
|
"""获取 EnhancedPipeline 单例。"""
|
||||||
|
global _enhanced_pipeline
|
||||||
|
if _enhanced_pipeline is None:
|
||||||
|
_enhanced_pipeline = EnhancedPipeline()
|
||||||
|
return _enhanced_pipeline
|
||||||
|
|
@ -0,0 +1,136 @@
|
||||||
|
"""
|
||||||
|
Metadata Inference Service.
|
||||||
|
[AC-AISVC-RES-04] 元数据推断统一入口。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.services.mid.metadata_filter_builder import (
|
||||||
|
FilterBuildResult,
|
||||||
|
MetadataFilterBuilder,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.config import (
|
||||||
|
FilterMode,
|
||||||
|
MetadataInferenceConfig,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.pipeline_base import MetadataFilterResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InferenceContext:
|
||||||
|
"""元数据推断上下文。"""
|
||||||
|
tenant_id: str
|
||||||
|
query: str
|
||||||
|
session_id: str | None = None
|
||||||
|
user_id: str | None = None
|
||||||
|
channel_type: str | None = None
|
||||||
|
existing_context: dict[str, Any] = field(default_factory=dict)
|
||||||
|
slot_state: Any = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InferenceResult:
|
||||||
|
"""元数据推断结果。"""
|
||||||
|
filter_result: MetadataFilterResult
|
||||||
|
inferred_fields: dict[str, Any] = field(default_factory=dict)
|
||||||
|
confidence_scores: dict[str, float] = field(default_factory=dict)
|
||||||
|
overall_confidence: float | None = None
|
||||||
|
inference_source: str = "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataInferenceService:
|
||||||
|
"""
|
||||||
|
元数据推断统一入口。【AC-AISVC-RES-04】
|
||||||
|
|
||||||
|
职责:
|
||||||
|
1. 统一的元数据推断入口(策略无关)
|
||||||
|
2. 根据置信度决定 hard/soft filter 模式
|
||||||
|
3. 与现有 MetadataFilterBuilder 保持一致
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
config: MetadataInferenceConfig | None = None,
|
||||||
|
):
|
||||||
|
self._session = session
|
||||||
|
self._config = config or MetadataInferenceConfig()
|
||||||
|
self._filter_builder: MetadataFilterBuilder | None = None
|
||||||
|
|
||||||
|
async def infer(self, ctx: InferenceContext) -> InferenceResult:
|
||||||
|
"""执行元数据推断。【AC-AISVC-RES-04】"""
|
||||||
|
logger.info(
|
||||||
|
f"[MetadataInference] Starting inference: tenant={ctx.tenant_id}, "
|
||||||
|
f"query={ctx.query[:50]}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._filter_builder is None:
|
||||||
|
self._filter_builder = MetadataFilterBuilder(self._session)
|
||||||
|
|
||||||
|
effective_context = dict(ctx.existing_context)
|
||||||
|
|
||||||
|
if ctx.slot_state:
|
||||||
|
effective_context = await self._merge_slot_state(
|
||||||
|
effective_context, ctx.slot_state
|
||||||
|
)
|
||||||
|
|
||||||
|
build_result = await self._filter_builder.build_filter(
|
||||||
|
tenant_id=ctx.tenant_id,
|
||||||
|
context=effective_context,
|
||||||
|
)
|
||||||
|
|
||||||
|
confidence = self._calculate_confidence(build_result, effective_context)
|
||||||
|
filter_mode = self._config.determine_filter_mode(confidence)
|
||||||
|
|
||||||
|
filter_result = MetadataFilterResult(
|
||||||
|
filter_dict=build_result.applied_filter,
|
||||||
|
filter_mode=filter_mode,
|
||||||
|
confidence=confidence,
|
||||||
|
missing_required_slots=build_result.missing_required_slots,
|
||||||
|
debug_info=build_result.debug_info,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[MetadataInference] Inference completed: filter_mode={filter_mode.value}, "
|
||||||
|
f"confidence={confidence}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return InferenceResult(
|
||||||
|
filter_result=filter_result,
|
||||||
|
inferred_fields=build_result.applied_filter,
|
||||||
|
overall_confidence=confidence,
|
||||||
|
inference_source="metadata_filter_builder",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _merge_slot_state(
|
||||||
|
self, context: dict[str, Any], slot_state: Any
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""合并槽位状态到上下文。"""
|
||||||
|
if hasattr(slot_state, 'filled_slots'):
|
||||||
|
for slot_key, slot_value in slot_state.filled_slots.items():
|
||||||
|
if slot_key not in context:
|
||||||
|
context[slot_key] = slot_value
|
||||||
|
return context
|
||||||
|
|
||||||
|
def _calculate_confidence(
|
||||||
|
self, build_result: FilterBuildResult, context: dict[str, Any]
|
||||||
|
) -> float | None:
|
||||||
|
"""计算推断置信度。"""
|
||||||
|
if build_result.missing_required_slots:
|
||||||
|
return 0.3
|
||||||
|
if not build_result.applied_filter:
|
||||||
|
return None
|
||||||
|
if not context:
|
||||||
|
return 0.5
|
||||||
|
applied_ratio = len(build_result.applied_filter) / max(len(context), 1)
|
||||||
|
if applied_ratio >= 0.8:
|
||||||
|
return 0.9
|
||||||
|
elif applied_ratio >= 0.5:
|
||||||
|
return 0.7
|
||||||
|
return 0.5
|
||||||
|
|
@ -0,0 +1,118 @@
|
||||||
|
"""
|
||||||
|
Mode Router.
|
||||||
|
[AC-AISVC-RES-09~15] 模式路由器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from app.services.retrieval.strategy.config import ModeRouterConfig, RuntimeMode
|
||||||
|
from app.services.retrieval.strategy.pipeline_base import PipelineResult
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ModeDecision:
|
||||||
|
"""模式决策结果。"""
|
||||||
|
mode: RuntimeMode
|
||||||
|
reason: str
|
||||||
|
confidence: float | None = None
|
||||||
|
complexity_score: float | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ModeRouter:
|
||||||
|
"""
|
||||||
|
模式路由器。【AC-AISVC-RES-09~15】
|
||||||
|
|
||||||
|
职责:
|
||||||
|
1. 根据 rag_runtime_mode 选择 direct/react/auto 模式
|
||||||
|
2. auto 模式下根据复杂度与置信度自动选择路由
|
||||||
|
3. direct 低置信度时触发 react 回退
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, config: ModeRouterConfig | None = None):
|
||||||
|
self._config = config or ModeRouterConfig()
|
||||||
|
|
||||||
|
def decide(
|
||||||
|
self,
|
||||||
|
query: str,
|
||||||
|
confidence: float | None = None,
|
||||||
|
complexity_score: float | None = None,
|
||||||
|
) -> ModeDecision:
|
||||||
|
"""决定使用哪种模式。【AC-AISVC-RES-09~13】"""
|
||||||
|
if self._config.runtime_mode == RuntimeMode.REACT:
|
||||||
|
return ModeDecision(mode=RuntimeMode.REACT, reason="runtime_mode=react")
|
||||||
|
|
||||||
|
if self._config.runtime_mode == RuntimeMode.DIRECT:
|
||||||
|
return ModeDecision(mode=RuntimeMode.DIRECT, reason="runtime_mode=direct")
|
||||||
|
|
||||||
|
calculated_complexity = complexity_score or self._calculate_complexity(query)
|
||||||
|
|
||||||
|
if self._should_use_direct(query, confidence, calculated_complexity):
|
||||||
|
return ModeDecision(
|
||||||
|
mode=RuntimeMode.DIRECT,
|
||||||
|
reason="auto: short_query_high_confidence",
|
||||||
|
confidence=confidence,
|
||||||
|
complexity_score=calculated_complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
return ModeDecision(
|
||||||
|
mode=RuntimeMode.REACT,
|
||||||
|
reason="auto: complex_or_low_confidence",
|
||||||
|
confidence=confidence,
|
||||||
|
complexity_score=calculated_complexity,
|
||||||
|
)
|
||||||
|
|
||||||
|
def should_fallback_to_react(self, direct_result: PipelineResult) -> bool:
|
||||||
|
"""判断是否应该从 direct 回退到 react。【AC-AISVC-RES-14】"""
|
||||||
|
if not self._config.direct_fallback_on_low_confidence:
|
||||||
|
return False
|
||||||
|
if direct_result.is_empty:
|
||||||
|
return True
|
||||||
|
max_score = direct_result.retrieval_result.max_score
|
||||||
|
if max_score < 0.3:
|
||||||
|
return True
|
||||||
|
if direct_result.retrieval_result.hit_count < 2:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def _should_use_direct(
|
||||||
|
self, query: str, confidence: float | None, complexity_score: float
|
||||||
|
) -> bool:
|
||||||
|
if len(query) <= self._config.short_query_threshold:
|
||||||
|
if confidence and confidence >= self._config.react_trigger_confidence_threshold:
|
||||||
|
return True
|
||||||
|
if confidence and confidence < self._config.react_trigger_confidence_threshold:
|
||||||
|
return False
|
||||||
|
if complexity_score >= self._config.react_trigger_complexity_score:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def _calculate_complexity(self, query: str) -> float:
|
||||||
|
score = 0.0
|
||||||
|
if len(query) > 50:
|
||||||
|
score += 0.2
|
||||||
|
if len(query) > 100:
|
||||||
|
score += 0.2
|
||||||
|
condition_words = ["和", "或", "但是", "如果", "同时", "并且", "或者", "以及"]
|
||||||
|
for word in condition_words:
|
||||||
|
if word in query:
|
||||||
|
score += 0.1
|
||||||
|
return min(score, 1.0)
|
||||||
|
|
||||||
|
def get_config(self) -> ModeRouterConfig:
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
def update_config(self, config: ModeRouterConfig) -> None:
|
||||||
|
self._config = config
|
||||||
|
|
||||||
|
|
||||||
|
_mode_router: ModeRouter | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_mode_router() -> ModeRouter:
|
||||||
|
global _mode_router
|
||||||
|
if _mode_router is None:
|
||||||
|
_mode_router = ModeRouter()
|
||||||
|
return _mode_router
|
||||||
|
|
@ -0,0 +1,116 @@
|
||||||
|
"""
|
||||||
|
Pipeline Base Classes.
|
||||||
|
[AC-AISVC-RES-01~15] Pipeline 抽象基类。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult
|
||||||
|
from app.services.retrieval.strategy.config import FilterMode
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MetadataFilterResult:
|
||||||
|
"""元数据过滤结果。"""
|
||||||
|
filter_dict: dict[str, Any] = field(default_factory=dict)
|
||||||
|
filter_mode: FilterMode = FilterMode.NONE
|
||||||
|
confidence: float | None = None
|
||||||
|
missing_required_slots: list[dict[str, str]] = field(default_factory=list)
|
||||||
|
debug_info: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineContext:
|
||||||
|
"""Pipeline 执行上下文。"""
|
||||||
|
retrieval_ctx: RetrievalContext
|
||||||
|
metadata_filter: MetadataFilterResult | None = None
|
||||||
|
use_reranker: bool = False
|
||||||
|
use_react: bool = False
|
||||||
|
react_iteration: int = 0
|
||||||
|
previous_results: list[RetrievalHit] = field(default_factory=list)
|
||||||
|
extra: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tenant_id(self) -> str:
|
||||||
|
return self.retrieval_ctx.tenant_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def query(self) -> str:
|
||||||
|
return self.retrieval_ctx.query
|
||||||
|
|
||||||
|
@property
|
||||||
|
def session_id(self) -> str | None:
|
||||||
|
return self.retrieval_ctx.session_id
|
||||||
|
|
||||||
|
@property
|
||||||
|
def kb_ids(self) -> list[str] | None:
|
||||||
|
return self.retrieval_ctx.kb_ids
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PipelineResult:
|
||||||
|
"""Pipeline 执行结果。"""
|
||||||
|
retrieval_result: RetrievalResult
|
||||||
|
pipeline_name: str = ""
|
||||||
|
used_reranker: bool = False
|
||||||
|
used_react: bool = False
|
||||||
|
react_iterations: int = 0
|
||||||
|
metadata_filter_applied: bool = False
|
||||||
|
fallback_triggered: bool = False
|
||||||
|
fallback_reason: str | None = None
|
||||||
|
latency_ms: float = 0.0
|
||||||
|
diagnostics: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def hits(self) -> list[RetrievalHit]:
|
||||||
|
return self.retrieval_result.hits
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_empty(self) -> bool:
|
||||||
|
return self.retrieval_result.is_empty
|
||||||
|
|
||||||
|
|
||||||
|
class BasePipeline(ABC):
|
||||||
|
"""Pipeline 抽象基类。【AC-AISVC-RES-01, AC-AISVC-RES-02】"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Pipeline 名称。"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@property
|
||||||
|
@abstractmethod
|
||||||
|
def description(self) -> str:
|
||||||
|
"""Pipeline 描述。"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
|
||||||
|
"""执行检索。"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def health_check(self) -> bool:
|
||||||
|
"""健康检查。"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def _create_empty_result(
|
||||||
|
self,
|
||||||
|
ctx: PipelineContext,
|
||||||
|
error: str | None = None,
|
||||||
|
latency_ms: float = 0.0,
|
||||||
|
) -> PipelineResult:
|
||||||
|
"""创建空结果。"""
|
||||||
|
diagnostics = {"error": error} if error else {}
|
||||||
|
return PipelineResult(
|
||||||
|
retrieval_result=RetrievalResult(hits=[], diagnostics=diagnostics),
|
||||||
|
pipeline_name=self.name,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
diagnostics=diagnostics,
|
||||||
|
)
|
||||||
|
|
@ -0,0 +1,301 @@
|
||||||
|
"""
|
||||||
|
Retrieval Strategy - Unified Entry Point.
|
||||||
|
[AC-AISVC-RES-01~15] 检索策略统一入口。
|
||||||
|
|
||||||
|
整合:
|
||||||
|
- StrategyRouter: 策略路由(default/enhanced)
|
||||||
|
- ModeRouter: 模式路由(direct/react/auto)
|
||||||
|
- MetadataInferenceService: 元数据推断
|
||||||
|
- RollbackManager: 回退管理
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.services.retrieval.base import RetrievalContext, RetrievalResult
|
||||||
|
from app.services.retrieval.strategy.config import (
|
||||||
|
RetrievalStrategyConfig,
|
||||||
|
RuntimeMode,
|
||||||
|
StrategyType,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline
|
||||||
|
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline
|
||||||
|
from app.services.retrieval.strategy.metadata_inference import (
|
||||||
|
InferenceContext,
|
||||||
|
MetadataInferenceService,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.mode_router import ModeDecision, ModeRouter
|
||||||
|
from app.services.retrieval.strategy.pipeline_base import (
|
||||||
|
MetadataFilterResult,
|
||||||
|
PipelineContext,
|
||||||
|
PipelineResult,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.rollback_manager import (
|
||||||
|
RollbackManager,
|
||||||
|
RollbackTrigger,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.strategy_router import (
|
||||||
|
RoutingDecision,
|
||||||
|
StrategyRouter,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RetrievalStrategyResult:
|
||||||
|
"""检索策略执行结果。"""
|
||||||
|
|
||||||
|
retrieval_result: RetrievalResult
|
||||||
|
strategy_used: StrategyType
|
||||||
|
mode_used: RuntimeMode
|
||||||
|
metadata_filter: MetadataFilterResult | None
|
||||||
|
latency_ms: float
|
||||||
|
diagnostics: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class RetrievalStrategy:
|
||||||
|
"""
|
||||||
|
检索策略统一入口。【AC-AISVC-RES-01~15】
|
||||||
|
|
||||||
|
整合所有策略组件:
|
||||||
|
1. 元数据推断(MetadataInferenceService)
|
||||||
|
2. 策略路由(StrategyRouter)
|
||||||
|
3. 模式路由(ModeRouter)
|
||||||
|
4. 回退管理(RollbackManager)
|
||||||
|
|
||||||
|
使用方式:
|
||||||
|
```python
|
||||||
|
strategy = RetrievalStrategy(session)
|
||||||
|
result = await strategy.retrieve(
|
||||||
|
tenant_id="tenant_1",
|
||||||
|
query="用户问题",
|
||||||
|
context={"user_id": "user_1"},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
config: RetrievalStrategyConfig | None = None,
|
||||||
|
):
|
||||||
|
self._session = session
|
||||||
|
self._config = config or RetrievalStrategyConfig()
|
||||||
|
|
||||||
|
self._strategy_router = StrategyRouter(self._config)
|
||||||
|
self._mode_router = ModeRouter(self._config.mode_router)
|
||||||
|
self._rollback_manager = RollbackManager(self._config)
|
||||||
|
self._metadata_inference: MetadataInferenceService | None = None
|
||||||
|
|
||||||
|
async def retrieve(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
query: str,
|
||||||
|
context: dict[str, Any] | None = None,
|
||||||
|
kb_ids: list[str] | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
use_reranker: bool = False,
|
||||||
|
use_react: bool = False,
|
||||||
|
) -> RetrievalStrategyResult:
|
||||||
|
"""
|
||||||
|
执行检索策略。【AC-AISVC-RES-01~15】
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 元数据推断
|
||||||
|
2. 策略路由
|
||||||
|
3. 模式路由
|
||||||
|
4. 执行检索
|
||||||
|
5. 检查是否需要回退
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: 租户 ID
|
||||||
|
query: 查询文本
|
||||||
|
context: 上下文信息
|
||||||
|
kb_ids: 知识库 ID 列表
|
||||||
|
session_id: 会话 ID
|
||||||
|
user_id: 用户 ID
|
||||||
|
use_reranker: 是否使用重排
|
||||||
|
use_react: 是否使用 ReAct 模式
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
RetrievalStrategyResult 包含检索结果和诊断信息
|
||||||
|
"""
|
||||||
|
start_time = time.time()
|
||||||
|
context = context or {}
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[RetrievalStrategy] Starting retrieval: tenant={tenant_id}, "
|
||||||
|
f"query={query[:50]}..."
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
metadata_filter = await self._infer_metadata(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
query=query,
|
||||||
|
context=context,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
routing_decision = await self._strategy_router.route(tenant_id, user_id)
|
||||||
|
|
||||||
|
mode_decision = self._mode_router.decide(
|
||||||
|
query=query,
|
||||||
|
confidence=metadata_filter.confidence if metadata_filter else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
retrieval_ctx = RetrievalContext(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
query=query,
|
||||||
|
session_id=session_id,
|
||||||
|
metadata_filter=metadata_filter.filter_dict if metadata_filter else None,
|
||||||
|
kb_ids=kb_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline_ctx = PipelineContext(
|
||||||
|
retrieval_ctx=retrieval_ctx,
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
|
use_reranker=use_reranker or self._config.reranker.enabled,
|
||||||
|
use_react=use_react or mode_decision.mode == RuntimeMode.REACT,
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline_result = await routing_decision.pipeline.retrieve(pipeline_ctx)
|
||||||
|
|
||||||
|
if mode_decision.mode == RuntimeMode.DIRECT and self._mode_router.should_fallback_to_react(pipeline_result):
|
||||||
|
logger.info("[RetrievalStrategy] Falling back to react mode")
|
||||||
|
pipeline_ctx.use_react = True
|
||||||
|
pipeline_result = await routing_decision.pipeline.retrieve(pipeline_ctx)
|
||||||
|
mode_decision = ModeDecision(
|
||||||
|
mode=RuntimeMode.REACT,
|
||||||
|
reason="fallback_from_direct",
|
||||||
|
)
|
||||||
|
|
||||||
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
|
||||||
|
self._check_performance(latency_ms, tenant_id)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[RetrievalStrategy] Retrieval completed: strategy={routing_decision.strategy.value}, "
|
||||||
|
f"mode={mode_decision.mode.value}, hits={len(pipeline_result.hits)}, "
|
||||||
|
f"latency_ms={latency_ms:.2f}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return RetrievalStrategyResult(
|
||||||
|
retrieval_result=pipeline_result.retrieval_result,
|
||||||
|
strategy_used=routing_decision.strategy,
|
||||||
|
mode_used=mode_decision.mode,
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
diagnostics={
|
||||||
|
"routing_reason": routing_decision.reason,
|
||||||
|
"mode_reason": mode_decision.reason,
|
||||||
|
"grayscale_hit": routing_decision.grayscale_hit,
|
||||||
|
**pipeline_result.diagnostics,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
latency_ms = (time.time() - start_time) * 1000
|
||||||
|
logger.error(
|
||||||
|
f"[RetrievalStrategy] Retrieval error: {e}",
|
||||||
|
exc_info=True
|
||||||
|
)
|
||||||
|
|
||||||
|
self._rollback_manager.rollback(
|
||||||
|
trigger=RollbackTrigger.ERROR,
|
||||||
|
reason=str(e),
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return RetrievalStrategyResult(
|
||||||
|
retrieval_result=RetrievalResult(
|
||||||
|
hits=[],
|
||||||
|
diagnostics={"error": str(e)},
|
||||||
|
),
|
||||||
|
strategy_used=StrategyType.DEFAULT,
|
||||||
|
mode_used=RuntimeMode.DIRECT,
|
||||||
|
metadata_filter=None,
|
||||||
|
latency_ms=latency_ms,
|
||||||
|
diagnostics={"error": str(e)},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _infer_metadata(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
query: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
session_id: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
|
) -> MetadataFilterResult | None:
|
||||||
|
"""执行元数据推断。"""
|
||||||
|
try:
|
||||||
|
if self._metadata_inference is None:
|
||||||
|
self._metadata_inference = MetadataInferenceService(
|
||||||
|
self._session,
|
||||||
|
self._config.metadata_inference,
|
||||||
|
)
|
||||||
|
|
||||||
|
inference_ctx = InferenceContext(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
query=query,
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
existing_context=context,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self._metadata_inference.infer(inference_ctx)
|
||||||
|
return result.filter_result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[RetrievalStrategy] Metadata inference failed: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _check_performance(self, latency_ms: float, tenant_id: str | None) -> None:
|
||||||
|
"""检查性能指标,必要时触发回退。"""
|
||||||
|
self._rollback_manager.check_and_rollback(
|
||||||
|
metrics={"latency_ms": latency_ms},
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_config(self) -> RetrievalStrategyConfig:
|
||||||
|
"""获取当前配置。"""
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
def update_config(self, config: RetrievalStrategyConfig) -> None:
|
||||||
|
"""更新配置。"""
|
||||||
|
self._config = config
|
||||||
|
self._strategy_router.update_config(config)
|
||||||
|
self._mode_router.update_config(config.mode_router)
|
||||||
|
self._rollback_manager.update_config(config)
|
||||||
|
|
||||||
|
async def health_check(self) -> dict[str, bool]:
|
||||||
|
"""健康检查。"""
|
||||||
|
results = {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
default_pipeline = await self._strategy_router._get_default_pipeline()
|
||||||
|
results["default_pipeline"] = await default_pipeline.health_check()
|
||||||
|
except Exception:
|
||||||
|
results["default_pipeline"] = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
enhanced_pipeline = await self._strategy_router._get_enhanced_pipeline()
|
||||||
|
results["enhanced_pipeline"] = await enhanced_pipeline.health_check()
|
||||||
|
except Exception:
|
||||||
|
results["enhanced_pipeline"] = False
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
async def create_retrieval_strategy(
|
||||||
|
session: AsyncSession,
|
||||||
|
config: RetrievalStrategyConfig | None = None,
|
||||||
|
) -> RetrievalStrategy:
|
||||||
|
"""创建 RetrievalStrategy 实例。"""
|
||||||
|
return RetrievalStrategy(session, config)
|
||||||
|
|
@ -0,0 +1,192 @@
|
||||||
|
"""
|
||||||
|
Rollback Manager.
|
||||||
|
[AC-AISVC-RES-07] 策略回退与审计管理器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.services.retrieval.strategy.config import RetrievalStrategyConfig, StrategyType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class RollbackTrigger(str, Enum):
|
||||||
|
"""回退触发原因。"""
|
||||||
|
MANUAL = "manual"
|
||||||
|
ERROR = "error"
|
||||||
|
PERFORMANCE = "performance"
|
||||||
|
TIMEOUT = "timeout"
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AuditLog:
|
||||||
|
"""审计日志记录。"""
|
||||||
|
timestamp: str
|
||||||
|
action: str
|
||||||
|
from_strategy: str
|
||||||
|
to_strategy: str
|
||||||
|
trigger: str
|
||||||
|
reason: str
|
||||||
|
tenant_id: str | None = None
|
||||||
|
details: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RollbackResult:
|
||||||
|
"""回退结果。"""
|
||||||
|
success: bool
|
||||||
|
previous_strategy: StrategyType
|
||||||
|
current_strategy: StrategyType
|
||||||
|
trigger: RollbackTrigger
|
||||||
|
reason: str
|
||||||
|
audit_log: AuditLog | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class RollbackManager:
|
||||||
|
"""
|
||||||
|
策略回退管理器。【AC-AISVC-RES-07】
|
||||||
|
|
||||||
|
职责:
|
||||||
|
1. 策略异常时回退到默认策略
|
||||||
|
2. 记录审计日志
|
||||||
|
3. 支持手动触发回退
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: RetrievalStrategyConfig | None = None,
|
||||||
|
max_audit_logs: int = 1000,
|
||||||
|
):
|
||||||
|
self._config = config or RetrievalStrategyConfig()
|
||||||
|
self._max_audit_logs = max_audit_logs
|
||||||
|
self._audit_logs: list[AuditLog] = []
|
||||||
|
self._previous_strategy: StrategyType = StrategyType.DEFAULT
|
||||||
|
|
||||||
|
def rollback(
|
||||||
|
self,
|
||||||
|
trigger: RollbackTrigger,
|
||||||
|
reason: str,
|
||||||
|
tenant_id: str | None = None,
|
||||||
|
details: dict[str, Any] | None = None,
|
||||||
|
) -> RollbackResult:
|
||||||
|
"""执行策略回退。【AC-AISVC-RES-07】"""
|
||||||
|
previous = self._config.active_strategy
|
||||||
|
current = StrategyType.DEFAULT
|
||||||
|
|
||||||
|
if previous == StrategyType.DEFAULT:
|
||||||
|
return RollbackResult(
|
||||||
|
success=False,
|
||||||
|
previous_strategy=previous,
|
||||||
|
current_strategy=current,
|
||||||
|
trigger=trigger,
|
||||||
|
reason="Already on default strategy",
|
||||||
|
)
|
||||||
|
|
||||||
|
self._previous_strategy = previous
|
||||||
|
self._config.active_strategy = current
|
||||||
|
|
||||||
|
audit_log = AuditLog(
|
||||||
|
timestamp=datetime.utcnow().isoformat(),
|
||||||
|
action="rollback",
|
||||||
|
from_strategy=previous.value,
|
||||||
|
to_strategy=current.value,
|
||||||
|
trigger=trigger.value,
|
||||||
|
reason=reason,
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
details=details or {},
|
||||||
|
)
|
||||||
|
|
||||||
|
self._add_audit_log(audit_log)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[RollbackManager] Rollback executed: from={previous.value}, "
|
||||||
|
f"to={current.value}, trigger={trigger.value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return RollbackResult(
|
||||||
|
success=True,
|
||||||
|
previous_strategy=previous,
|
||||||
|
current_strategy=current,
|
||||||
|
trigger=trigger,
|
||||||
|
reason=reason,
|
||||||
|
audit_log=audit_log,
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_and_rollback(
|
||||||
|
self, metrics: dict[str, float], tenant_id: str | None = None
|
||||||
|
) -> RollbackResult | None:
|
||||||
|
"""检查性能指标并自动回退。【AC-AISVC-RES-08】"""
|
||||||
|
thresholds = self._config.performance_thresholds
|
||||||
|
|
||||||
|
latency = metrics.get("latency_ms", 0)
|
||||||
|
if latency > thresholds.get("max_latency_ms", 2000):
|
||||||
|
return self.rollback(
|
||||||
|
trigger=RollbackTrigger.PERFORMANCE,
|
||||||
|
reason=f"Latency {latency}ms exceeds threshold",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
error_rate = metrics.get("error_rate", 0)
|
||||||
|
if error_rate > thresholds.get("max_error_rate", 0.05):
|
||||||
|
return self.rollback(
|
||||||
|
trigger=RollbackTrigger.ERROR,
|
||||||
|
reason=f"Error rate {error_rate} exceeds threshold",
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _add_audit_log(self, log: AuditLog) -> None:
|
||||||
|
self._audit_logs.append(log)
|
||||||
|
if len(self._audit_logs) > self._max_audit_logs:
|
||||||
|
self._audit_logs = self._audit_logs[-self._max_audit_logs:]
|
||||||
|
|
||||||
|
def record_audit(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
details: dict[str, Any],
|
||||||
|
tenant_id: str | None = None,
|
||||||
|
) -> AuditLog:
|
||||||
|
"""记录审计日志。"""
|
||||||
|
audit_log = AuditLog(
|
||||||
|
timestamp=datetime.utcnow().isoformat(),
|
||||||
|
action=action,
|
||||||
|
from_strategy=self._config.active_strategy.value,
|
||||||
|
to_strategy=self._config.active_strategy.value,
|
||||||
|
trigger="n/a",
|
||||||
|
reason=details.get("reason", ""),
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
details=details,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._add_audit_log(audit_log)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[RollbackManager] Audit recorded: action={action}, "
|
||||||
|
f"strategy={self._config.active_strategy.value}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return audit_log
|
||||||
|
|
||||||
|
def get_audit_logs(self, limit: int = 100) -> list[AuditLog]:
|
||||||
|
return self._audit_logs[-limit:]
|
||||||
|
|
||||||
|
def get_config(self) -> RetrievalStrategyConfig:
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
def update_config(self, config: RetrievalStrategyConfig) -> None:
|
||||||
|
self._config = config
|
||||||
|
|
||||||
|
|
||||||
|
_rollback_manager: RollbackManager | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_rollback_manager() -> RollbackManager:
|
||||||
|
global _rollback_manager
|
||||||
|
if _rollback_manager is None:
|
||||||
|
_rollback_manager = RollbackManager()
|
||||||
|
return _rollback_manager
|
||||||
|
|
@ -0,0 +1,109 @@
|
||||||
|
"""
|
||||||
|
Strategy Router.
|
||||||
|
[AC-AISVC-RES-01~03] 策略路由器。
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.services.retrieval.strategy.config import (
|
||||||
|
RetrievalStrategyConfig,
|
||||||
|
StrategyType,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline, get_default_pipeline
|
||||||
|
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline, get_enhanced_pipeline
|
||||||
|
from app.services.retrieval.strategy.pipeline_base import BasePipeline
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class RoutingDecision:
|
||||||
|
"""路由决策结果。"""
|
||||||
|
strategy: StrategyType
|
||||||
|
pipeline: BasePipeline
|
||||||
|
reason: str
|
||||||
|
grayscale_hit: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class StrategyRouter:
|
||||||
|
"""
|
||||||
|
策略路由器。【AC-AISVC-RES-01~03】
|
||||||
|
|
||||||
|
职责:
|
||||||
|
1. 根据配置选择默认策略或增强策略
|
||||||
|
2. 支持灰度发布(percentage/allowlist)
|
||||||
|
3. 不影响正在运行的默认策略请求
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: RetrievalStrategyConfig | None = None,
|
||||||
|
default_pipeline: DefaultPipeline | None = None,
|
||||||
|
enhanced_pipeline: EnhancedPipeline | None = None,
|
||||||
|
):
|
||||||
|
self._config = config or RetrievalStrategyConfig()
|
||||||
|
self._default_pipeline = default_pipeline
|
||||||
|
self._enhanced_pipeline = enhanced_pipeline
|
||||||
|
|
||||||
|
async def route(
|
||||||
|
self, tenant_id: str, user_id: str | None = None
|
||||||
|
) -> RoutingDecision:
|
||||||
|
"""路由到合适的策略。【AC-AISVC-RES-01~03】"""
|
||||||
|
if self._config.active_strategy == StrategyType.ENHANCED:
|
||||||
|
pipeline = await self._get_enhanced_pipeline()
|
||||||
|
return RoutingDecision(
|
||||||
|
strategy=StrategyType.ENHANCED,
|
||||||
|
pipeline=pipeline,
|
||||||
|
reason="active_strategy=enhanced",
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._config.grayscale.should_use_enhanced(tenant_id, user_id):
|
||||||
|
pipeline = await self._get_enhanced_pipeline()
|
||||||
|
return RoutingDecision(
|
||||||
|
strategy=StrategyType.ENHANCED,
|
||||||
|
pipeline=pipeline,
|
||||||
|
reason="grayscale_hit",
|
||||||
|
grayscale_hit=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
pipeline = await self._get_default_pipeline()
|
||||||
|
return RoutingDecision(
|
||||||
|
strategy=StrategyType.DEFAULT,
|
||||||
|
pipeline=pipeline,
|
||||||
|
reason="default_strategy",
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _get_default_pipeline(self) -> DefaultPipeline:
|
||||||
|
if self._default_pipeline is None:
|
||||||
|
self._default_pipeline = await get_default_pipeline()
|
||||||
|
return self._default_pipeline
|
||||||
|
|
||||||
|
async def _get_enhanced_pipeline(self) -> EnhancedPipeline:
|
||||||
|
if self._enhanced_pipeline is None:
|
||||||
|
self._enhanced_pipeline = await get_enhanced_pipeline()
|
||||||
|
return self._enhanced_pipeline
|
||||||
|
|
||||||
|
def get_config(self) -> RetrievalStrategyConfig:
|
||||||
|
return self._config
|
||||||
|
|
||||||
|
def update_config(self, config: RetrievalStrategyConfig) -> None:
|
||||||
|
self._config = config
|
||||||
|
logger.info(f"[StrategyRouter] Config updated: strategy={config.active_strategy.value}")
|
||||||
|
|
||||||
|
|
||||||
|
_strategy_router: StrategyRouter | None = None
|
||||||
|
|
||||||
|
|
||||||
|
def get_strategy_router() -> StrategyRouter:
|
||||||
|
global _strategy_router
|
||||||
|
if _strategy_router is None:
|
||||||
|
_strategy_router = StrategyRouter()
|
||||||
|
return _strategy_router
|
||||||
|
|
||||||
|
|
||||||
|
def set_strategy_router(router: StrategyRouter) -> None:
|
||||||
|
"""Set the global strategy router instance."""
|
||||||
|
global _strategy_router
|
||||||
|
_strategy_router = router
|
||||||
|
|
@ -0,0 +1,645 @@
|
||||||
|
"""
|
||||||
|
Unit tests for Retrieval Strategy Module.
|
||||||
|
[AC-AISVC-RES-01~15] Tests for strategy config, pipelines, routers, and rollback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
from dataclasses import asdict
|
||||||
|
|
||||||
|
from app.services.retrieval.strategy.config import (
|
||||||
|
FilterMode,
|
||||||
|
GrayscaleConfig,
|
||||||
|
HybridRetrievalConfig,
|
||||||
|
MetadataInferenceConfig,
|
||||||
|
ModeRouterConfig,
|
||||||
|
PipelineConfig,
|
||||||
|
RerankerConfig,
|
||||||
|
RetrievalStrategyConfig,
|
||||||
|
RuntimeMode,
|
||||||
|
StrategyType,
|
||||||
|
get_strategy_config,
|
||||||
|
set_strategy_config,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.pipeline_base import (
|
||||||
|
BasePipeline,
|
||||||
|
MetadataFilterResult,
|
||||||
|
PipelineContext,
|
||||||
|
PipelineResult,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline
|
||||||
|
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline
|
||||||
|
from app.services.retrieval.strategy.strategy_router import (
|
||||||
|
RoutingDecision,
|
||||||
|
StrategyRouter,
|
||||||
|
get_strategy_router,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.mode_router import (
|
||||||
|
ModeDecision,
|
||||||
|
ModeRouter,
|
||||||
|
get_mode_router,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy.rollback_manager import (
|
||||||
|
AuditLog,
|
||||||
|
RollbackManager,
|
||||||
|
RollbackResult,
|
||||||
|
RollbackTrigger,
|
||||||
|
get_rollback_manager,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult
|
||||||
|
|
||||||
|
|
||||||
|
class TestStrategyConfig:
|
||||||
|
"""[AC-AISVC-RES-01~15] Tests for strategy configuration models."""
|
||||||
|
|
||||||
|
def test_strategy_type_enum(self):
|
||||||
|
"""[AC-AISVC-RES-01] Strategy type should have default and enhanced values."""
|
||||||
|
assert StrategyType.DEFAULT.value == "default"
|
||||||
|
assert StrategyType.ENHANCED.value == "enhanced"
|
||||||
|
|
||||||
|
def test_runtime_mode_enum(self):
|
||||||
|
"""[AC-AISVC-RES-09] Runtime mode should have direct, react, and auto values."""
|
||||||
|
assert RuntimeMode.DIRECT.value == "direct"
|
||||||
|
assert RuntimeMode.REACT.value == "react"
|
||||||
|
assert RuntimeMode.AUTO.value == "auto"
|
||||||
|
|
||||||
|
def test_filter_mode_enum(self):
|
||||||
|
"""[AC-AISVC-RES-04] Filter mode should have hard, soft, and none values."""
|
||||||
|
assert FilterMode.HARD.value == "hard"
|
||||||
|
assert FilterMode.SOFT.value == "soft"
|
||||||
|
assert FilterMode.NONE.value == "none"
|
||||||
|
|
||||||
|
def test_grayscale_config_default(self):
|
||||||
|
"""[AC-AISVC-RES-03] Default grayscale config should be disabled."""
|
||||||
|
config = GrayscaleConfig()
|
||||||
|
assert config.enabled is False
|
||||||
|
assert config.percentage == 0.0
|
||||||
|
assert config.allowlist == []
|
||||||
|
|
||||||
|
def test_grayscale_config_should_use_enhanced_disabled(self):
|
||||||
|
"""[AC-AISVC-RES-03] Should not use enhanced when grayscale disabled."""
|
||||||
|
config = GrayscaleConfig(enabled=False, percentage=50.0)
|
||||||
|
assert config.should_use_enhanced("tenant_a") is False
|
||||||
|
|
||||||
|
def test_grayscale_config_should_use_enhanced_allowlist(self):
|
||||||
|
"""[AC-AISVC-RES-03] Should use enhanced for tenants in allowlist."""
|
||||||
|
config = GrayscaleConfig(enabled=True, allowlist=["tenant_a", "tenant_b"])
|
||||||
|
assert config.should_use_enhanced("tenant_a") is True
|
||||||
|
assert config.should_use_enhanced("tenant_b") is True
|
||||||
|
assert config.should_use_enhanced("tenant_c") is False
|
||||||
|
|
||||||
|
def test_grayscale_config_should_use_enhanced_percentage(self):
|
||||||
|
"""[AC-AISVC-RES-03] Should use enhanced based on percentage."""
|
||||||
|
config = GrayscaleConfig(enabled=True, percentage=100.0)
|
||||||
|
assert config.should_use_enhanced("any_tenant") is True
|
||||||
|
|
||||||
|
config = GrayscaleConfig(enabled=True, percentage=0.0)
|
||||||
|
assert config.should_use_enhanced("any_tenant") is False
|
||||||
|
|
||||||
|
def test_reranker_config_default(self):
|
||||||
|
"""[AC-AISVC-RES-08] Default reranker config should be disabled."""
|
||||||
|
config = RerankerConfig()
|
||||||
|
assert config.enabled is False
|
||||||
|
assert config.model == "cross-encoder"
|
||||||
|
assert config.top_k_after_rerank == 5
|
||||||
|
|
||||||
|
def test_mode_router_config_default(self):
|
||||||
|
"""[AC-AISVC-RES-09] Default mode router config should be direct."""
|
||||||
|
config = ModeRouterConfig()
|
||||||
|
assert config.runtime_mode == RuntimeMode.DIRECT
|
||||||
|
assert config.react_trigger_confidence_threshold == 0.6
|
||||||
|
assert config.react_max_steps == 5
|
||||||
|
|
||||||
|
def test_mode_router_config_should_use_react_always(self):
|
||||||
|
"""[AC-AISVC-RES-10] React mode should always use react."""
|
||||||
|
config = ModeRouterConfig(runtime_mode=RuntimeMode.REACT)
|
||||||
|
assert config.should_use_react("any query") is True
|
||||||
|
|
||||||
|
def test_mode_router_config_should_use_react_never(self):
|
||||||
|
"""[AC-AISVC-RES-09] Direct mode should never use react."""
|
||||||
|
config = ModeRouterConfig(runtime_mode=RuntimeMode.DIRECT)
|
||||||
|
assert config.should_use_react("any query") is False
|
||||||
|
|
||||||
|
def test_mode_router_config_auto_short_query_high_confidence(self):
|
||||||
|
"""[AC-AISVC-RES-12] Auto mode with short query and high confidence should use direct."""
|
||||||
|
config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
|
||||||
|
assert config.should_use_react("短问题", confidence=0.8) is False
|
||||||
|
|
||||||
|
def test_mode_router_config_auto_low_confidence(self):
|
||||||
|
"""[AC-AISVC-RES-13] Auto mode with low confidence should use react."""
|
||||||
|
config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
|
||||||
|
assert config.should_use_react("any query", confidence=0.3) is True
|
||||||
|
|
||||||
|
def test_metadata_inference_config_determine_filter_mode(self):
|
||||||
|
"""[AC-AISVC-RES-04] Should determine filter mode based on confidence."""
|
||||||
|
config = MetadataInferenceConfig()
|
||||||
|
|
||||||
|
assert config.determine_filter_mode(0.9) == FilterMode.HARD
|
||||||
|
assert config.determine_filter_mode(0.6) == FilterMode.SOFT
|
||||||
|
assert config.determine_filter_mode(0.3) == FilterMode.NONE
|
||||||
|
assert config.determine_filter_mode(None) == FilterMode.NONE
|
||||||
|
|
||||||
|
def test_pipeline_config_default(self):
|
||||||
|
"""[AC-AISVC-RES-01] Default pipeline config should have sensible defaults."""
|
||||||
|
config = PipelineConfig()
|
||||||
|
assert config.top_k == 5
|
||||||
|
assert config.score_threshold == 0.01
|
||||||
|
assert config.two_stage_enabled is True
|
||||||
|
|
||||||
|
def test_retrieval_strategy_config_default(self):
|
||||||
|
"""[AC-AISVC-RES-01] Default strategy config should use default strategy."""
|
||||||
|
config = RetrievalStrategyConfig()
|
||||||
|
assert config.active_strategy == StrategyType.DEFAULT
|
||||||
|
assert config.grayscale.enabled is False
|
||||||
|
assert config.mode_router.runtime_mode == RuntimeMode.DIRECT
|
||||||
|
|
||||||
|
def test_retrieval_strategy_config_is_enhanced_enabled(self):
|
||||||
|
"""[AC-AISVC-RES-02] Should check if enhanced is enabled."""
|
||||||
|
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||||
|
assert config.is_enhanced_enabled("tenant_a") is True
|
||||||
|
|
||||||
|
config = RetrievalStrategyConfig(
|
||||||
|
active_strategy=StrategyType.DEFAULT,
|
||||||
|
grayscale=GrayscaleConfig(enabled=True, allowlist=["tenant_a"]),
|
||||||
|
)
|
||||||
|
assert config.is_enhanced_enabled("tenant_a") is True
|
||||||
|
assert config.is_enhanced_enabled("tenant_b") is False
|
||||||
|
|
||||||
|
def test_retrieval_strategy_config_to_dict(self):
|
||||||
|
"""[AC-AISVC-RES-01] Should convert config to dictionary."""
|
||||||
|
config = RetrievalStrategyConfig()
|
||||||
|
d = config.to_dict()
|
||||||
|
|
||||||
|
assert d["active_strategy"] == "default"
|
||||||
|
assert "grayscale" in d
|
||||||
|
assert "pipeline" in d
|
||||||
|
assert "reranker" in d
|
||||||
|
assert "mode_router" in d
|
||||||
|
|
||||||
|
def test_global_config_functions(self):
|
||||||
|
"""[AC-AISVC-RES-01] Should get and set global config."""
|
||||||
|
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||||
|
set_strategy_config(config)
|
||||||
|
|
||||||
|
retrieved = get_strategy_config()
|
||||||
|
assert retrieved.active_strategy == StrategyType.ENHANCED
|
||||||
|
|
||||||
|
set_strategy_config(RetrievalStrategyConfig())
|
||||||
|
|
||||||
|
|
||||||
|
class TestPipelineBase:
|
||||||
|
"""[AC-AISVC-RES-01~02] Tests for pipeline base classes."""
|
||||||
|
|
||||||
|
def test_metadata_filter_result_default(self):
|
||||||
|
"""[AC-AISVC-RES-04] Default metadata filter result should be empty."""
|
||||||
|
result = MetadataFilterResult()
|
||||||
|
assert result.filter_dict == {}
|
||||||
|
assert result.filter_mode == FilterMode.NONE
|
||||||
|
assert result.confidence is None
|
||||||
|
|
||||||
|
def test_pipeline_context_properties(self):
|
||||||
|
"""[AC-AISVC-RES-01] Pipeline context should expose retrieval context properties."""
|
||||||
|
retrieval_ctx = RetrievalContext(
|
||||||
|
tenant_id="tenant_1",
|
||||||
|
query="test query",
|
||||||
|
session_id="session_1",
|
||||||
|
kb_ids=["kb_1"],
|
||||||
|
)
|
||||||
|
pipeline_ctx = PipelineContext(retrieval_ctx=retrieval_ctx)
|
||||||
|
|
||||||
|
assert pipeline_ctx.tenant_id == "tenant_1"
|
||||||
|
assert pipeline_ctx.query == "test query"
|
||||||
|
assert pipeline_ctx.session_id == "session_1"
|
||||||
|
assert pipeline_ctx.kb_ids == ["kb_1"]
|
||||||
|
|
||||||
|
def test_pipeline_result_properties(self):
|
||||||
|
"""[AC-AISVC-RES-01] Pipeline result should expose retrieval result properties."""
|
||||||
|
hits = [
|
||||||
|
RetrievalHit(text="hit 1", score=0.9, source="test", metadata={}),
|
||||||
|
RetrievalHit(text="hit 2", score=0.8, source="test", metadata={}),
|
||||||
|
]
|
||||||
|
retrieval_result = RetrievalResult(hits=hits)
|
||||||
|
pipeline_result = PipelineResult(
|
||||||
|
retrieval_result=retrieval_result,
|
||||||
|
pipeline_name="test_pipeline",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert pipeline_result.hits == hits
|
||||||
|
assert pipeline_result.is_empty is False
|
||||||
|
assert pipeline_result.pipeline_name == "test_pipeline"
|
||||||
|
|
||||||
|
def test_pipeline_result_is_empty(self):
|
||||||
|
"""[AC-AISVC-RES-01] Pipeline result should detect empty results."""
|
||||||
|
pipeline_result = PipelineResult(
|
||||||
|
retrieval_result=RetrievalResult(hits=[]),
|
||||||
|
)
|
||||||
|
assert pipeline_result.is_empty is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultPipeline:
|
||||||
|
"""[AC-AISVC-RES-01] Tests for default pipeline."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_retriever(self):
|
||||||
|
"""Create a mock optimized retriever."""
|
||||||
|
retriever = AsyncMock()
|
||||||
|
retriever.retrieve = AsyncMock(return_value=RetrievalResult(
|
||||||
|
hits=[
|
||||||
|
RetrievalHit(text="result 1", score=0.9, source="default", metadata={}),
|
||||||
|
],
|
||||||
|
diagnostics={"test": True},
|
||||||
|
))
|
||||||
|
retriever.health_check = AsyncMock(return_value=True)
|
||||||
|
retriever._two_stage_enabled = True
|
||||||
|
retriever._hybrid_enabled = True
|
||||||
|
return retriever
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pipeline(self, mock_retriever):
|
||||||
|
"""Create a default pipeline with mock retriever."""
|
||||||
|
return DefaultPipeline(optimized_retriever=mock_retriever)
|
||||||
|
|
||||||
|
def test_pipeline_name(self, pipeline):
|
||||||
|
"""[AC-AISVC-RES-01] Pipeline should have correct name."""
|
||||||
|
assert pipeline.name == "default_pipeline"
|
||||||
|
|
||||||
|
def test_pipeline_description(self, pipeline):
|
||||||
|
"""[AC-AISVC-RES-01] Pipeline should have description."""
|
||||||
|
assert "默认" in pipeline.description
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retrieve(self, pipeline, mock_retriever):
|
||||||
|
"""[AC-AISVC-RES-01] Should retrieve results using optimized retriever."""
|
||||||
|
retrieval_ctx = RetrievalContext(
|
||||||
|
tenant_id="tenant_1",
|
||||||
|
query="test query",
|
||||||
|
)
|
||||||
|
pipeline_ctx = PipelineContext(retrieval_ctx=retrieval_ctx)
|
||||||
|
|
||||||
|
result = await pipeline.retrieve(pipeline_ctx)
|
||||||
|
|
||||||
|
assert result.pipeline_name == "default_pipeline"
|
||||||
|
assert len(result.hits) == 1
|
||||||
|
assert result.diagnostics["retriever"] == "OptimizedRetriever"
|
||||||
|
mock_retriever.retrieve.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retrieve_with_metadata_filter(self, pipeline, mock_retriever):
|
||||||
|
"""[AC-AISVC-RES-04] Should apply metadata filter."""
|
||||||
|
retrieval_ctx = RetrievalContext(
|
||||||
|
tenant_id="tenant_1",
|
||||||
|
query="test query",
|
||||||
|
)
|
||||||
|
metadata_filter = MetadataFilterResult(
|
||||||
|
filter_dict={"grade": "初一"},
|
||||||
|
filter_mode=FilterMode.HARD,
|
||||||
|
)
|
||||||
|
pipeline_ctx = PipelineContext(
|
||||||
|
retrieval_ctx=retrieval_ctx,
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await pipeline.retrieve(pipeline_ctx)
|
||||||
|
|
||||||
|
assert result.metadata_filter_applied is True
|
||||||
|
call_args = mock_retriever.retrieve.call_args[0][0]
|
||||||
|
assert call_args.metadata_filter == {"grade": "初一"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check(self, pipeline, mock_retriever):
|
||||||
|
"""[AC-AISVC-RES-01] Should check health."""
|
||||||
|
result = await pipeline.health_check()
|
||||||
|
assert result is True
|
||||||
|
mock_retriever.health_check.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnhancedPipeline:
|
||||||
|
"""[AC-AISVC-RES-02] Tests for enhanced pipeline."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_qdrant_client(self):
|
||||||
|
"""Create a mock Qdrant client."""
|
||||||
|
client = AsyncMock()
|
||||||
|
client.search = AsyncMock(return_value=[
|
||||||
|
{"id": "1", "score": 0.9, "payload": {"text": "result 1"}},
|
||||||
|
])
|
||||||
|
client.get_client = AsyncMock()
|
||||||
|
return client
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_embedding_provider(self):
|
||||||
|
"""Create a mock embedding provider."""
|
||||||
|
provider = AsyncMock()
|
||||||
|
provider.embed_query = AsyncMock()
|
||||||
|
provider.embed_query.return_value = MagicMock(
|
||||||
|
embedding_full=[0.1] * 768,
|
||||||
|
)
|
||||||
|
provider.embed = AsyncMock(return_value=[0.1] * 768)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def pipeline(self, mock_qdrant_client, mock_embedding_provider):
|
||||||
|
"""Create an enhanced pipeline with mocks."""
|
||||||
|
pipeline = EnhancedPipeline(qdrant_client=mock_qdrant_client)
|
||||||
|
pipeline._embedding_provider = mock_embedding_provider
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
def test_pipeline_name(self, pipeline):
|
||||||
|
"""[AC-AISVC-RES-02] Pipeline should have correct name."""
|
||||||
|
assert pipeline.name == "enhanced_pipeline"
|
||||||
|
|
||||||
|
def test_pipeline_description(self, pipeline):
|
||||||
|
"""[AC-AISVC-RES-02] Pipeline should have description."""
|
||||||
|
assert "增强" in pipeline.description
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retrieve_basic(self, pipeline):
|
||||||
|
"""[AC-AISVC-RES-02] Should retrieve results using hybrid search."""
|
||||||
|
retrieval_ctx = RetrievalContext(
|
||||||
|
tenant_id="tenant_1",
|
||||||
|
query="test query",
|
||||||
|
)
|
||||||
|
pipeline_ctx = PipelineContext(retrieval_ctx=retrieval_ctx)
|
||||||
|
|
||||||
|
result = await pipeline.retrieve(pipeline_ctx)
|
||||||
|
|
||||||
|
assert result.pipeline_name == "enhanced_pipeline"
|
||||||
|
assert result.diagnostics is not None
|
||||||
|
|
||||||
|
|
||||||
|
class TestStrategyRouter:
|
||||||
|
"""[AC-AISVC-RES-01~03] Tests for strategy router."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_default_pipeline(self):
|
||||||
|
"""Create a mock default pipeline."""
|
||||||
|
pipeline = AsyncMock(spec=DefaultPipeline)
|
||||||
|
pipeline.name = "default_pipeline"
|
||||||
|
pipeline.retrieve = AsyncMock(return_value=PipelineResult(
|
||||||
|
retrieval_result=RetrievalResult(hits=[]),
|
||||||
|
pipeline_name="default_pipeline",
|
||||||
|
))
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_enhanced_pipeline(self):
|
||||||
|
"""Create a mock enhanced pipeline."""
|
||||||
|
pipeline = AsyncMock(spec=EnhancedPipeline)
|
||||||
|
pipeline.name = "enhanced_pipeline"
|
||||||
|
pipeline.retrieve = AsyncMock(return_value=PipelineResult(
|
||||||
|
retrieval_result=RetrievalResult(hits=[]),
|
||||||
|
pipeline_name="enhanced_pipeline",
|
||||||
|
))
|
||||||
|
return pipeline
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def router(self, mock_default_pipeline, mock_enhanced_pipeline):
|
||||||
|
"""Create a strategy router with mock pipelines."""
|
||||||
|
config = RetrievalStrategyConfig()
|
||||||
|
return StrategyRouter(
|
||||||
|
config=config,
|
||||||
|
default_pipeline=mock_default_pipeline,
|
||||||
|
enhanced_pipeline=mock_enhanced_pipeline,
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_route_default_strategy(self, router):
|
||||||
|
"""[AC-AISVC-RES-01] Should route to default strategy by default."""
|
||||||
|
import asyncio
|
||||||
|
decision = asyncio.get_event_loop().run_until_complete(router.route("tenant_1"))
|
||||||
|
|
||||||
|
assert decision.strategy == StrategyType.DEFAULT
|
||||||
|
assert decision.reason == "default_strategy"
|
||||||
|
|
||||||
|
def test_route_enhanced_strategy(self, mock_default_pipeline, mock_enhanced_pipeline):
|
||||||
|
"""[AC-AISVC-RES-02] Should route to enhanced strategy when configured."""
|
||||||
|
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||||
|
router = StrategyRouter(
|
||||||
|
config=config,
|
||||||
|
default_pipeline=mock_default_pipeline,
|
||||||
|
enhanced_pipeline=mock_enhanced_pipeline,
|
||||||
|
)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
decision = asyncio.get_event_loop().run_until_complete(router.route("tenant_1"))
|
||||||
|
|
||||||
|
assert decision.strategy == StrategyType.ENHANCED
|
||||||
|
assert decision.reason == "active_strategy=enhanced"
|
||||||
|
|
||||||
|
def test_route_grayscale_allowlist(self, mock_default_pipeline, mock_enhanced_pipeline):
|
||||||
|
"""[AC-AISVC-RES-03] Should route to enhanced for allowlist tenants."""
|
||||||
|
config = RetrievalStrategyConfig(
|
||||||
|
active_strategy=StrategyType.DEFAULT,
|
||||||
|
grayscale=GrayscaleConfig(enabled=True, allowlist=["tenant_a"]),
|
||||||
|
)
|
||||||
|
router = StrategyRouter(
|
||||||
|
config=config,
|
||||||
|
default_pipeline=mock_default_pipeline,
|
||||||
|
enhanced_pipeline=mock_enhanced_pipeline,
|
||||||
|
)
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
decision = asyncio.get_event_loop().run_until_complete(router.route("tenant_a"))
|
||||||
|
assert decision.strategy == StrategyType.ENHANCED
|
||||||
|
assert decision.grayscale_hit is True
|
||||||
|
|
||||||
|
decision = asyncio.get_event_loop().run_until_complete(router.route("tenant_b"))
|
||||||
|
assert decision.strategy == StrategyType.DEFAULT
|
||||||
|
|
||||||
|
def test_update_config(self, router):
|
||||||
|
"""[AC-AISVC-RES-02] Should update config."""
|
||||||
|
new_config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||||
|
router.update_config(new_config)
|
||||||
|
|
||||||
|
assert router.get_config().active_strategy == StrategyType.ENHANCED
|
||||||
|
|
||||||
|
|
||||||
|
class TestModeRouter:
|
||||||
|
"""[AC-AISVC-RES-09~15] Tests for mode router."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def router(self):
|
||||||
|
"""Create a mode router."""
|
||||||
|
return ModeRouter()
|
||||||
|
|
||||||
|
def test_decide_react_mode(self):
|
||||||
|
"""[AC-AISVC-RES-10] Should decide react when configured."""
|
||||||
|
config = ModeRouterConfig(runtime_mode=RuntimeMode.REACT)
|
||||||
|
router = ModeRouter(config)
|
||||||
|
|
||||||
|
decision = router.decide("any query")
|
||||||
|
|
||||||
|
assert decision.mode == RuntimeMode.REACT
|
||||||
|
assert decision.reason == "runtime_mode=react"
|
||||||
|
|
||||||
|
def test_decide_direct_mode(self, router):
|
||||||
|
"""[AC-AISVC-RES-09] Should decide direct when configured."""
|
||||||
|
decision = router.decide("any query")
|
||||||
|
|
||||||
|
assert decision.mode == RuntimeMode.DIRECT
|
||||||
|
assert decision.reason == "runtime_mode=direct"
|
||||||
|
|
||||||
|
def test_decide_auto_short_query_high_confidence(self):
|
||||||
|
"""[AC-AISVC-RES-12] Auto with short query and high confidence should use direct."""
|
||||||
|
config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
|
||||||
|
router = ModeRouter(config)
|
||||||
|
|
||||||
|
decision = router.decide("短问题", confidence=0.8)
|
||||||
|
|
||||||
|
assert decision.mode == RuntimeMode.DIRECT
|
||||||
|
|
||||||
|
def test_decide_auto_low_confidence(self):
|
||||||
|
"""[AC-AISVC-RES-13] Auto with low confidence should use react."""
|
||||||
|
config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
|
||||||
|
router = ModeRouter(config)
|
||||||
|
|
||||||
|
decision = router.decide("any query", confidence=0.3)
|
||||||
|
|
||||||
|
assert decision.mode == RuntimeMode.REACT
|
||||||
|
|
||||||
|
def test_should_fallback_to_react_empty_results(self, router):
|
||||||
|
"""[AC-AISVC-RES-14] Should fallback to react on empty results."""
|
||||||
|
result = PipelineResult(retrieval_result=RetrievalResult(hits=[]))
|
||||||
|
|
||||||
|
assert router.should_fallback_to_react(result) is True
|
||||||
|
|
||||||
|
def test_should_fallback_to_react_low_score(self, router):
|
||||||
|
"""[AC-AISVC-RES-14] Should fallback to react on low score."""
|
||||||
|
result = PipelineResult(
|
||||||
|
retrieval_result=RetrievalResult(
|
||||||
|
hits=[RetrievalHit(text="test", score=0.1, source="test", metadata={})],
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert router.should_fallback_to_react(result) is True
|
||||||
|
|
||||||
|
def test_should_not_fallback_to_react_disabled(self):
|
||||||
|
"""[AC-AISVC-RES-14] Should not fallback when disabled."""
|
||||||
|
config = ModeRouterConfig(direct_fallback_on_low_confidence=False)
|
||||||
|
router = ModeRouter(config)
|
||||||
|
|
||||||
|
result = PipelineResult(retrieval_result=RetrievalResult(hits=[]))
|
||||||
|
|
||||||
|
assert router.should_fallback_to_react(result) is False
|
||||||
|
|
||||||
|
|
||||||
|
class TestRollbackManager:
|
||||||
|
"""[AC-AISVC-RES-07] Tests for rollback manager."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(self):
|
||||||
|
"""Create a rollback manager."""
|
||||||
|
return RollbackManager()
|
||||||
|
|
||||||
|
def test_rollback_from_enhanced(self, manager):
|
||||||
|
"""[AC-AISVC-RES-07] Should rollback from enhanced to default."""
|
||||||
|
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||||
|
manager.update_config(config)
|
||||||
|
|
||||||
|
result = manager.rollback(
|
||||||
|
trigger=RollbackTrigger.MANUAL,
|
||||||
|
reason="Testing rollback",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success is True
|
||||||
|
assert result.previous_strategy == StrategyType.ENHANCED
|
||||||
|
assert result.current_strategy == StrategyType.DEFAULT
|
||||||
|
assert result.audit_log is not None
|
||||||
|
|
||||||
|
def test_rollback_already_default(self, manager):
|
||||||
|
"""[AC-AISVC-RES-07] Should not rollback when already on default."""
|
||||||
|
result = manager.rollback(
|
||||||
|
trigger=RollbackTrigger.MANUAL,
|
||||||
|
reason="Testing rollback",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.success is False
|
||||||
|
assert result.reason == "Already on default strategy"
|
||||||
|
|
||||||
|
def test_check_and_rollback_latency(self, manager):
|
||||||
|
"""[AC-AISVC-RES-08] Should rollback on high latency."""
|
||||||
|
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||||
|
manager.update_config(config)
|
||||||
|
|
||||||
|
result = manager.check_and_rollback(
|
||||||
|
metrics={"latency_ms": 3000.0},
|
||||||
|
tenant_id="tenant_1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.trigger == RollbackTrigger.PERFORMANCE
|
||||||
|
|
||||||
|
def test_check_and_rollback_error_rate(self, manager):
|
||||||
|
"""[AC-AISVC-RES-08] Should rollback on high error rate."""
|
||||||
|
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||||
|
manager.update_config(config)
|
||||||
|
|
||||||
|
result = manager.check_and_rollback(
|
||||||
|
metrics={"error_rate": 0.1},
|
||||||
|
tenant_id="tenant_1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert result.trigger == RollbackTrigger.ERROR
|
||||||
|
|
||||||
|
def test_check_and_rollback_ok(self, manager):
|
||||||
|
"""[AC-AISVC-RES-08] Should not rollback when metrics are ok."""
|
||||||
|
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||||
|
manager.update_config(config)
|
||||||
|
|
||||||
|
result = manager.check_and_rollback(
|
||||||
|
metrics={"latency_ms": 100.0, "error_rate": 0.01},
|
||||||
|
tenant_id="tenant_1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_get_audit_logs(self, manager):
|
||||||
|
"""[AC-AISVC-RES-07] Should get audit logs."""
|
||||||
|
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
|
||||||
|
manager.update_config(config)
|
||||||
|
manager.rollback(trigger=RollbackTrigger.MANUAL, reason="Test")
|
||||||
|
|
||||||
|
logs = manager.get_audit_logs()
|
||||||
|
|
||||||
|
assert len(logs) == 1
|
||||||
|
assert logs[0].action == "rollback"
|
||||||
|
|
||||||
|
def test_record_audit(self, manager):
|
||||||
|
"""[AC-AISVC-RES-07] Should record audit log."""
|
||||||
|
log = manager.record_audit(
|
||||||
|
action="test_action",
|
||||||
|
details={"reason": "Testing"},
|
||||||
|
tenant_id="tenant_1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert log.action == "test_action"
|
||||||
|
assert log.tenant_id == "tenant_1"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingletonInstances:
|
||||||
|
"""Tests for singleton instance getters."""
|
||||||
|
|
||||||
|
def test_get_mode_router_singleton(self):
|
||||||
|
"""Should return same mode router instance."""
|
||||||
|
from app.services.retrieval.strategy.mode_router import _mode_router
|
||||||
|
|
||||||
|
import app.services.retrieval.strategy.mode_router as module
|
||||||
|
module._mode_router = None
|
||||||
|
|
||||||
|
router1 = get_mode_router()
|
||||||
|
router2 = get_mode_router()
|
||||||
|
|
||||||
|
assert router1 is router2
|
||||||
|
|
||||||
|
def test_get_rollback_manager_singleton(self):
|
||||||
|
"""Should return same rollback manager instance."""
|
||||||
|
from app.services.retrieval.strategy.rollback_manager import _rollback_manager
|
||||||
|
|
||||||
|
import app.services.retrieval.strategy.rollback_manager as module
|
||||||
|
module._rollback_manager = None
|
||||||
|
|
||||||
|
manager1 = get_rollback_manager()
|
||||||
|
manager2 = get_rollback_manager()
|
||||||
|
|
||||||
|
assert manager1 is manager2
|
||||||
|
|
@ -0,0 +1,167 @@
|
||||||
|
---
|
||||||
|
context:
|
||||||
|
module: "ai-service"
|
||||||
|
feature: "AISVC-RES"
|
||||||
|
status: "✅已完成"
|
||||||
|
version: "0.9.0"
|
||||||
|
active_ac_range: "AC-AISVC-RES-01~15"
|
||||||
|
|
||||||
|
spec_references:
|
||||||
|
requirements: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/requirements.md"
|
||||||
|
design: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/design.md"
|
||||||
|
tasks: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/tasks.md"
|
||||||
|
openapi_provider: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/openapi.provider.yaml"
|
||||||
|
active_version: "0.1.0"
|
||||||
|
|
||||||
|
overall_progress:
|
||||||
|
- "[x] Phase 1: Schema与数据模型定义 (100%) [API与校验]"
|
||||||
|
- "[x] Phase 2: 策略服务层实现 (100%) [策略层与配置]"
|
||||||
|
- "[x] Phase 3: 审计日志与指标埋点 (100%) [观测与灰度验证]"
|
||||||
|
- "[x] Phase 4: API端点实现 (100%) [API与校验]"
|
||||||
|
- "[x] Phase 5: 单元测试与验证 (100%) [验收]"
|
||||||
|
- "[x] Phase 6: 检索策略Pipeline实现 (100%) [检索与嵌入策略]"
|
||||||
|
|
||||||
|
current_phase:
|
||||||
|
goal: "检索与嵌入策略化改造已完成"
|
||||||
|
sub_tasks:
|
||||||
|
- "[x] 创建 Schema 模型 (app/schemas/retrieval_strategy.py)"
|
||||||
|
- "[x] 创建策略服务层 (app/services/retrieval/strategy_service.py)"
|
||||||
|
- "[x] 创建审计日志服务 (app/services/retrieval/strategy_audit.py)"
|
||||||
|
- "[x] 创建指标埋点服务 (app/services/retrieval/strategy_metrics.py)"
|
||||||
|
- "[x] 创建 API 端点 (app/api/admin/retrieval_strategy.py)"
|
||||||
|
- "[x] 创建策略配置模型 (app/services/retrieval/strategy/config.py)"
|
||||||
|
- "[x] 实现 DefaultPipeline(复用现有逻辑)"
|
||||||
|
- "[x] 实现 EnhancedPipeline(新端到端流程)"
|
||||||
|
- "[x] 实现元数据推断统一入口 (MetadataInferenceService)"
|
||||||
|
- "[x] 实现 StrategyRouter 和 ModeRouter"
|
||||||
|
- "[x] 实现 Dense + Keyword + RRF 组合检索"
|
||||||
|
- "[x] 实现可选重排与降级开关"
|
||||||
|
- "[x] 实现 RollbackManager(回退与审计)"
|
||||||
|
- "[x] 实现策略 API 接口"
|
||||||
|
- "[x] 创建单元测试 (tests/test_retrieval_strategy_v2.py)"
|
||||||
|
- "[x] 运行单元测试验证 (51 passed)"
|
||||||
|
|
||||||
|
next_action:
|
||||||
|
immediate: "任务已完成,可进行集成测试"
|
||||||
|
details:
|
||||||
|
file: "ai-service/app/services/retrieval/strategy/__init__.py:1"
|
||||||
|
action: "模块已完整实现,可通过 API 接口测试策略切换功能"
|
||||||
|
reference: "http://localhost:8000/docs"
|
||||||
|
constraints: "新策略可配置启用,不影响默认策略"
|
||||||
|
|
||||||
|
technical_context:
|
||||||
|
module_structure: |
|
||||||
|
ai-service/app/
|
||||||
|
├── api/admin/retrieval_strategy.py (API端点)
|
||||||
|
├── schemas/retrieval_strategy.py (Schema模型)
|
||||||
|
└── services/retrieval/
|
||||||
|
├── strategy_service.py (策略服务)
|
||||||
|
├── strategy_audit.py (审计日志)
|
||||||
|
├── strategy_metrics.py (指标埋点)
|
||||||
|
└── strategy/ (新增 - 策略模块)
|
||||||
|
├── __init__.py (模块导出)
|
||||||
|
├── config.py (策略配置模型)
|
||||||
|
├── pipeline_base.py (Pipeline基类)
|
||||||
|
├── default_pipeline.py (默认策略Pipeline)
|
||||||
|
├── enhanced_pipeline.py (增强策略Pipeline)
|
||||||
|
├── metadata_inference.py (元数据推断统一入口)
|
||||||
|
├── strategy_router.py (策略路由器)
|
||||||
|
├── mode_router.py (模式路由器)
|
||||||
|
└── rollback_manager.py (回退管理器)
|
||||||
|
└── tests/
|
||||||
|
├── test_retrieval_strategy.py (单元测试 - 原有)
|
||||||
|
└── test_retrieval_strategy_v2.py (单元测试 - 新增)
|
||||||
|
key_decisions:
|
||||||
|
- decision: "使用内存存储策略状态,后续可扩展为持久化"
|
||||||
|
reason: "快速实现,满足灰度验证需求"
|
||||||
|
impact: "服务重启后策略状态重置为默认值"
|
||||||
|
- decision: "审计日志使用结构化日志记录"
|
||||||
|
reason: "与现有日志体系一致,便于检索"
|
||||||
|
impact: "需要配置日志聚合系统收集审计日志"
|
||||||
|
- decision: "API与现有strategy_router.py互补"
|
||||||
|
reason: "strategy_router.py负责检索路由逻辑,新增的API负责策略管理"
|
||||||
|
impact: "两者协同工作,API提供管理界面"
|
||||||
|
- decision: "DefaultPipeline 复用现有 OptimizedRetriever 逻辑"
|
||||||
|
reason: "保持线上行为不变,最小化改动风险"
|
||||||
|
impact: "新策略与旧策略完全隔离,可独立灰度"
|
||||||
|
- decision: "EnhancedPipeline 实现新端到端流程"
|
||||||
|
reason: "支持 Dense + Keyword + RRF 组合检索,可选重排"
|
||||||
|
impact: "需要配置启用,不影响默认策略"
|
||||||
|
- decision: "元数据推断统一入口处理 hard/soft filter"
|
||||||
|
reason: "新旧策略共享同一推断逻辑,确保一致性"
|
||||||
|
impact: "置信度高用硬过滤,置信度低用软过滤/加权"
|
||||||
|
code_snippets: |
|
||||||
|
# 使用示例
|
||||||
|
from app.services.retrieval.strategy import (
|
||||||
|
get_strategy_router,
|
||||||
|
get_mode_router,
|
||||||
|
get_rollback_manager,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 策略路由
|
||||||
|
router = get_strategy_router()
|
||||||
|
decision = await router.route(tenant_id, user_id)
|
||||||
|
result = await decision.pipeline.retrieve(ctx)
|
||||||
|
|
||||||
|
# 模式路由
|
||||||
|
mode_router = get_mode_router()
|
||||||
|
mode_decision = mode_router.decide(query, confidence=0.8)
|
||||||
|
|
||||||
|
# 回退管理
|
||||||
|
rollback = get_rollback_manager()
|
||||||
|
rollback.rollback(trigger="manual", reason="测试回退")
|
||||||
|
|
||||||
|
session_history:
|
||||||
|
- session: "Session #1 (2026-03-10)"
|
||||||
|
completed:
|
||||||
|
- "创建 Schema 模型 (app/schemas/retrieval_strategy.py)"
|
||||||
|
- "创建策略服务层 (app/services/retrieval/strategy_service.py)"
|
||||||
|
- "创建审计日志服务 (app/services/retrieval/strategy_audit.py)"
|
||||||
|
- "创建指标埋点服务 (app/services/retrieval/strategy_metrics.py)"
|
||||||
|
- "创建 API 端点 (app/api/admin/retrieval_strategy.py)"
|
||||||
|
- "更新 admin __init__.py 注册新路由"
|
||||||
|
- "更新 main.py 注册新路由"
|
||||||
|
- "创建单元测试 (tests/test_retrieval_strategy.py)"
|
||||||
|
- "运行单元测试验证 (46 passed)"
|
||||||
|
changes:
|
||||||
|
- "新增: ai-service/app/schemas/retrieval_strategy.py"
|
||||||
|
- "新增: ai-service/app/services/retrieval/strategy_service.py"
|
||||||
|
- "新增: ai-service/app/services/retrieval/strategy_audit.py"
|
||||||
|
- "新增: ai-service/app/services/retrieval/strategy_metrics.py"
|
||||||
|
- "新增: ai-service/app/api/admin/retrieval_strategy.py"
|
||||||
|
- "修改: ai-service/app/api/admin/__init__.py"
|
||||||
|
- "修改: ai-service/app/main.py"
|
||||||
|
- "新增: ai-service/tests/test_retrieval_strategy.py"
|
||||||
|
status: "✅ 任务完成"
|
||||||
|
- session: "Session #2 (2026-03-10) - 检索策略Pipeline实现"
|
||||||
|
completed:
|
||||||
|
- "实现策略配置模型 (config.py)"
|
||||||
|
- "实现 Pipeline 基类 (pipeline_base.py)"
|
||||||
|
- "实现 DefaultPipeline(复用现有逻辑)"
|
||||||
|
- "实现 EnhancedPipeline(新端到端流程)"
|
||||||
|
- "实现 MetadataInferenceService"
|
||||||
|
- "实现 StrategyRouter"
|
||||||
|
- "实现 ModeRouter"
|
||||||
|
- "实现 RollbackManager"
|
||||||
|
- "更新模块导出 (__init__.py)"
|
||||||
|
- "创建单元测试 (tests/test_retrieval_strategy_v2.py)"
|
||||||
|
- "运行单元测试验证 (51 passed)"
|
||||||
|
changes:
|
||||||
|
- "新增: ai-service/app/services/retrieval/strategy/__init__.py"
|
||||||
|
- "新增: ai-service/app/services/retrieval/strategy/config.py"
|
||||||
|
- "新增: ai-service/app/services/retrieval/strategy/pipeline_base.py"
|
||||||
|
- "新增: ai-service/app/services/retrieval/strategy/default_pipeline.py"
|
||||||
|
- "新增: ai-service/app/services/retrieval/strategy/enhanced_pipeline.py"
|
||||||
|
- "新增: ai-service/app/services/retrieval/strategy/metadata_inference.py"
|
||||||
|
- "新增: ai-service/app/services/retrieval/strategy/strategy_router.py"
|
||||||
|
- "新增: ai-service/app/services/retrieval/strategy/mode_router.py"
|
||||||
|
- "新增: ai-service/app/services/retrieval/strategy/rollback_manager.py"
|
||||||
|
- "新增: ai-service/tests/test_retrieval_strategy_v2.py"
|
||||||
|
status: "✅ 任务完成"
|
||||||
|
|
||||||
|
startup_guide:
|
||||||
|
- "Step 1: 读取本进度文档(了解当前位置与下一步)"
|
||||||
|
- "Step 2: 读取 spec_references 中定义的模块规范(了解业务与接口约束)"
|
||||||
|
- "Step 3: 通过 API 接口测试策略切换功能"
|
||||||
|
- "Step 4: 运行单元测试验证: pytest tests/test_retrieval_strategy_v2.py -v"
|
||||||
|
---
|
||||||
Loading…
Reference in New Issue