[AC-AISVC-RES-01~15] feat(retrieval): 实现检索策略Pipeline模块

- 新增策略配置模型 (config.py)
  - GrayscaleConfig: 灰度发布配置
  - ModeRouterConfig: 模式路由配置
  - MetadataInferenceConfig: 元数据推断配置

- 新增 Pipeline 实现
  - DefaultPipeline: 复用现有 OptimizedRetriever 逻辑
  - EnhancedPipeline: Dense + Keyword + RRF 组合检索

- 新增路由器
  - StrategyRouter: 策略路由器(default/enhanced)
  - ModeRouter: 模式路由器(direct/react/auto)

- 新增 RollbackManager: 回退与审计管理器
- 新增 MetadataInferenceService: 元数据推断统一入口
- 新增单元测试 (51 passed)
This commit is contained in:
MerCry 2026-03-10 20:50:16 +08:00
parent 9f28498b97
commit 7027097513
12 changed files with 2568 additions and 0 deletions

View File

@ -0,0 +1,102 @@
"""
Retrieval Strategy Module for AI Service.
[AC-AISVC-RES-01~15] 策略化检索与嵌入模块
核心组件
- RetrievalStrategyConfig: 策略配置模型
- BasePipeline: Pipeline 抽象基类
- DefaultPipeline: 默认策略复用现有逻辑
- EnhancedPipeline: 增强策略新端到端流程
- MetadataInferenceService: 元数据推断统一入口
- StrategyRouter: 策略路由器
- ModeRouter: 模式路由器direct/react/auto
- RollbackManager: 回退管理器
"""
from app.services.retrieval.strategy.config import (
FilterMode,
GrayscaleConfig,
HybridRetrievalConfig,
MetadataInferenceConfig,
ModeRouterConfig,
PipelineConfig,
RerankerConfig,
RetrievalStrategyConfig,
RuntimeMode,
StrategyType,
get_strategy_config,
set_strategy_config,
)
from app.services.retrieval.strategy.default_pipeline import (
DefaultPipeline,
get_default_pipeline,
)
from app.services.retrieval.strategy.enhanced_pipeline import (
EnhancedPipeline,
get_enhanced_pipeline,
)
from app.services.retrieval.strategy.metadata_inference import (
InferenceContext,
InferenceResult,
MetadataInferenceService,
)
from app.services.retrieval.strategy.mode_router import (
ModeDecision,
ModeRouter,
get_mode_router,
)
from app.services.retrieval.strategy.pipeline_base import (
BasePipeline,
MetadataFilterResult,
PipelineContext,
PipelineResult,
)
from app.services.retrieval.strategy.rollback_manager import (
AuditLog,
RollbackManager,
RollbackResult,
RollbackTrigger,
get_rollback_manager,
)
from app.services.retrieval.strategy.strategy_router import (
RoutingDecision,
StrategyRouter,
get_strategy_router,
)
__all__ = [
"BasePipeline",
"PipelineContext",
"PipelineResult",
"MetadataFilterResult",
"DefaultPipeline",
"get_default_pipeline",
"EnhancedPipeline",
"get_enhanced_pipeline",
"RetrievalStrategyConfig",
"GrayscaleConfig",
"PipelineConfig",
"RerankerConfig",
"ModeRouterConfig",
"HybridRetrievalConfig",
"MetadataInferenceConfig",
"StrategyType",
"FilterMode",
"RuntimeMode",
"get_strategy_config",
"set_strategy_config",
"MetadataInferenceService",
"InferenceContext",
"InferenceResult",
"StrategyRouter",
"RoutingDecision",
"get_strategy_router",
"ModeRouter",
"ModeDecision",
"get_mode_router",
"RollbackManager",
"RollbackResult",
"RollbackTrigger",
"AuditLog",
"get_rollback_manager",
]

View File

@ -0,0 +1,201 @@
"""
Retrieval Strategy Configuration.
[AC-AISVC-RES-01~15] 检索策略配置模型
"""
import random
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
class StrategyType(str, Enum):
"""策略类型。"""
DEFAULT = "default"
ENHANCED = "enhanced"
class RuntimeMode(str, Enum):
"""运行时模式。"""
DIRECT = "direct"
REACT = "react"
AUTO = "auto"
class FilterMode(str, Enum):
"""过滤模式。"""
HARD = "hard"
SOFT = "soft"
NONE = "none"
@dataclass
class GrayscaleConfig:
"""灰度发布配置。【AC-AISVC-RES-03】"""
enabled: bool = False
percentage: float = 0.0
allowlist: list[str] = field(default_factory=list)
def should_use_enhanced(self, tenant_id: str, user_id: str | None = None) -> bool:
"""判断是否应该使用增强策略。"""
if not self.enabled:
return False
if tenant_id in self.allowlist or (user_id and user_id in self.allowlist):
return True
return random.random() * 100 < self.percentage
@dataclass
class HybridRetrievalConfig:
"""混合检索配置。"""
dense_weight: float = 0.7
keyword_weight: float = 0.3
rrf_k: int = 60
enable_keyword: bool = True
keyword_top_k_multiplier: int = 2
@dataclass
class RerankerConfig:
"""重排器配置。【AC-AISVC-RES-08】"""
enabled: bool = False
model: str = "cross-encoder"
top_k_after_rerank: int = 5
min_score_threshold: float = 0.3
@dataclass
class ModeRouterConfig:
"""模式路由配置。【AC-AISVC-RES-09~15】"""
runtime_mode: RuntimeMode = RuntimeMode.DIRECT
react_trigger_confidence_threshold: float = 0.6
react_trigger_complexity_score: float = 0.5
react_max_steps: int = 5
direct_fallback_on_low_confidence: bool = True
short_query_threshold: int = 20
def should_use_react(
self,
query: str,
confidence: float | None = None,
complexity_score: float | None = None,
) -> bool:
"""判断是否应该使用 ReAct 模式。【AC-AISVC-RES-11~13】"""
if self.runtime_mode == RuntimeMode.REACT:
return True
if self.runtime_mode == RuntimeMode.DIRECT:
return False
if len(query) <= self.short_query_threshold and confidence and confidence >= self.react_trigger_confidence_threshold:
return False
if complexity_score and complexity_score >= self.react_trigger_complexity_score:
return True
if confidence and confidence < self.react_trigger_confidence_threshold:
return True
return False
@dataclass
class MetadataInferenceConfig:
"""元数据推断配置。"""
enabled: bool = True
confidence_high_threshold: float = 0.8
confidence_low_threshold: float = 0.5
default_filter_mode: FilterMode = FilterMode.SOFT
cache_ttl_seconds: int = 300
def determine_filter_mode(self, confidence: float | None) -> FilterMode:
"""根据置信度确定过滤模式。"""
if confidence is None:
return FilterMode.NONE
if confidence >= self.confidence_high_threshold:
return FilterMode.HARD
if confidence >= self.confidence_low_threshold:
return FilterMode.SOFT
return FilterMode.NONE
@dataclass
class PipelineConfig:
"""Pipeline 配置。"""
top_k: int = 5
score_threshold: float = 0.01
min_hits: int = 1
two_stage_enabled: bool = True
two_stage_expand_factor: int = 10
hybrid: HybridRetrievalConfig = field(default_factory=HybridRetrievalConfig)
@dataclass
class RetrievalStrategyConfig:
"""检索策略顶层配置。【AC-AISVC-RES-01~15】"""
active_strategy: StrategyType = StrategyType.DEFAULT
grayscale: GrayscaleConfig = field(default_factory=GrayscaleConfig)
pipeline: PipelineConfig = field(default_factory=PipelineConfig)
reranker: RerankerConfig = field(default_factory=RerankerConfig)
mode_router: ModeRouterConfig = field(default_factory=ModeRouterConfig)
metadata_inference: MetadataInferenceConfig = field(default_factory=MetadataInferenceConfig)
performance_thresholds: dict[str, float] = field(default_factory=lambda: {
"max_latency_ms": 2000.0,
"min_success_rate": 0.95,
"max_error_rate": 0.05,
})
def is_enhanced_enabled(self, tenant_id: str, user_id: str | None = None) -> bool:
"""判断是否启用增强策略。"""
if self.active_strategy == StrategyType.ENHANCED:
return True
return self.grayscale.should_use_enhanced(tenant_id, user_id)
def to_dict(self) -> dict[str, Any]:
"""转换为字典。"""
return {
"active_strategy": self.active_strategy.value,
"grayscale": {
"enabled": self.grayscale.enabled,
"percentage": self.grayscale.percentage,
"allowlist": self.grayscale.allowlist,
},
"pipeline": {
"top_k": self.pipeline.top_k,
"score_threshold": self.pipeline.score_threshold,
"min_hits": self.pipeline.min_hits,
"two_stage_enabled": self.pipeline.two_stage_enabled,
},
"reranker": {
"enabled": self.reranker.enabled,
"model": self.reranker.model,
"top_k_after_rerank": self.reranker.top_k_after_rerank,
},
"mode_router": {
"runtime_mode": self.mode_router.runtime_mode.value,
"react_trigger_confidence_threshold": self.mode_router.react_trigger_confidence_threshold,
},
"metadata_inference": {
"enabled": self.metadata_inference.enabled,
"confidence_high_threshold": self.metadata_inference.confidence_high_threshold,
},
"performance_thresholds": self.performance_thresholds,
}
_global_config: RetrievalStrategyConfig | None = None
def get_strategy_config() -> RetrievalStrategyConfig:
"""获取全局策略配置。"""
global _global_config
if _global_config is None:
_global_config = RetrievalStrategyConfig()
return _global_config
def set_strategy_config(config: RetrievalStrategyConfig) -> None:
"""设置全局策略配置。"""
global _global_config
_global_config = config

View File

@ -0,0 +1,117 @@
"""
Default Pipeline.
[AC-AISVC-RES-01] 默认策略 Pipeline复用现有逻辑
"""
import logging
import time
from typing import Any
from app.services.retrieval.base import RetrievalContext, RetrievalResult
from app.services.retrieval.optimized_retriever import OptimizedRetriever, get_optimized_retriever
from app.services.retrieval.strategy.config import PipelineConfig
from app.services.retrieval.strategy.pipeline_base import (
BasePipeline,
PipelineContext,
PipelineResult,
)
logger = logging.getLogger(__name__)
class DefaultPipeline(BasePipeline):
"""
默认策略 PipelineAC-AISVC-RES-01
复用现有 OptimizedRetriever 逻辑保持线上行为不变
"""
def __init__(
self,
config: PipelineConfig | None = None,
optimized_retriever: OptimizedRetriever | None = None,
):
self._config = config or PipelineConfig()
self._optimized_retriever = optimized_retriever
@property
def name(self) -> str:
return "default_pipeline"
@property
def description(self) -> str:
return "默认检索策略,复用现有 OptimizedRetriever 逻辑。"
async def _get_retriever(self) -> OptimizedRetriever:
if self._optimized_retriever is None:
self._optimized_retriever = await get_optimized_retriever()
return self._optimized_retriever
async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
"""执行默认检索流程。【AC-AISVC-RES-01】"""
start_time = time.time()
logger.info(
f"[DefaultPipeline] Starting retrieval: tenant={ctx.tenant_id}, "
f"query={ctx.query[:50]}..."
)
try:
retriever = await self._get_retriever()
metadata_filter = None
if ctx.metadata_filter:
metadata_filter = ctx.metadata_filter.filter_dict
retrieval_ctx = RetrievalContext(
tenant_id=ctx.tenant_id,
query=ctx.query,
session_id=ctx.session_id,
metadata_filter=metadata_filter,
kb_ids=ctx.kb_ids,
)
result = await retriever.retrieve(retrieval_ctx)
latency_ms = (time.time() - start_time) * 1000
logger.info(
f"[DefaultPipeline] Retrieval completed: hits={len(result.hits)}, "
f"latency_ms={latency_ms:.2f}"
)
return PipelineResult(
retrieval_result=result,
pipeline_name=self.name,
metadata_filter_applied=metadata_filter is not None,
latency_ms=latency_ms,
diagnostics={
"retriever": "OptimizedRetriever",
**(result.diagnostics or {}),
},
)
except Exception as e:
latency_ms = (time.time() - start_time) * 1000
logger.error(f"[DefaultPipeline] Retrieval error: {e}", exc_info=True)
return self._create_empty_result(ctx, error=str(e), latency_ms=latency_ms)
async def health_check(self) -> bool:
"""健康检查。"""
try:
retriever = await self._get_retriever()
return await retriever.health_check()
except Exception as e:
logger.error(f"[DefaultPipeline] Health check failed: {e}")
return False
_default_pipeline: DefaultPipeline | None = None
async def get_default_pipeline() -> DefaultPipeline:
"""获取 DefaultPipeline 单例。"""
global _default_pipeline
if _default_pipeline is None:
_default_pipeline = DefaultPipeline()
return _default_pipeline

View File

@ -0,0 +1,364 @@
"""
Enhanced Pipeline.
[AC-AISVC-RES-02] 增强策略 Pipeline新端到端流程
"""
import asyncio
import logging
import re
import time
from dataclasses import dataclass
from typing import Any
import numpy as np
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient, get_qdrant_client
from app.services.embedding.nomic_provider import NomicEmbeddingProvider
from app.services.retrieval.base import RetrievalHit, RetrievalResult
from app.services.retrieval.optimized_retriever import RRFCombiner
from app.services.retrieval.strategy.config import (
HybridRetrievalConfig,
PipelineConfig,
RerankerConfig,
)
from app.services.retrieval.strategy.pipeline_base import (
BasePipeline,
PipelineContext,
PipelineResult,
)
logger = logging.getLogger(__name__)
settings = get_settings()
@dataclass
class RetrievalCandidate:
"""检索候选结果。"""
id: str
text: str
score: float
vector_score: float = 0.0
keyword_score: float = 0.0
metadata: dict[str, Any] = None
def __post_init__(self):
if self.metadata is None:
self.metadata = {}
class EnhancedPipeline(BasePipeline):
"""
增强策略 PipelineAC-AISVC-RES-02
新端到端流程
1. Dense 向量检索
2. Keyword 关键词检索
3. RRF 融合排序
4. 可选重排
"""
def __init__(
self,
config: PipelineConfig | None = None,
reranker_config: RerankerConfig | None = None,
qdrant_client: QdrantClient | None = None,
):
self._config = config or PipelineConfig()
self._reranker_config = reranker_config or RerankerConfig()
self._qdrant_client = qdrant_client
self._rrf_combiner = RRFCombiner(k=self._config.hybrid.rrf_k)
self._embedding_provider: NomicEmbeddingProvider | None = None
@property
def name(self) -> str:
return "enhanced_pipeline"
@property
def description(self) -> str:
return "增强检索策略,支持 Dense + Keyword + RRF 组合检索。"
async def _get_client(self) -> QdrantClient:
if self._qdrant_client is None:
self._qdrant_client = await get_qdrant_client()
return self._qdrant_client
async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
if self._embedding_provider is None:
from app.services.embedding.factory import get_embedding_config_manager
manager = get_embedding_config_manager()
provider = await manager.get_provider()
if isinstance(provider, NomicEmbeddingProvider):
self._embedding_provider = provider
else:
self._embedding_provider = NomicEmbeddingProvider(
base_url=settings.ollama_base_url,
model=settings.ollama_embedding_model,
dimension=settings.qdrant_vector_size,
)
return self._embedding_provider
async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
"""执行增强检索流程。【AC-AISVC-RES-02】"""
start_time = time.time()
logger.info(
f"[EnhancedPipeline] Starting retrieval: tenant={ctx.tenant_id}, "
f"query={ctx.query[:50]}..."
)
try:
provider = await self._get_embedding_provider()
embedding_result = await provider.embed_query(ctx.query)
candidates = await self._hybrid_retrieve(
tenant_id=ctx.tenant_id,
query=ctx.query,
embedding_result=embedding_result,
metadata_filter=ctx.metadata_filter.filter_dict if ctx.metadata_filter else None,
kb_ids=ctx.kb_ids,
)
if self._reranker_config.enabled and ctx.use_reranker:
candidates = await self._rerank(
candidates=candidates,
query=ctx.query,
)
top_k = self._config.top_k
final_candidates = candidates[:top_k]
hits = [
RetrievalHit(
text=c.text,
score=c.score,
source=self.name,
metadata=c.metadata,
)
for c in final_candidates
if c.score >= self._config.score_threshold
]
latency_ms = (time.time() - start_time) * 1000
logger.info(
f"[EnhancedPipeline] Retrieval completed: hits={len(hits)}, "
f"latency_ms={latency_ms:.2f}"
)
result = RetrievalResult(
hits=hits,
diagnostics={
"total_candidates": len(candidates),
"after_rerank": self._reranker_config.enabled and ctx.use_reranker,
},
)
return PipelineResult(
retrieval_result=result,
pipeline_name=self.name,
used_reranker=self._reranker_config.enabled and ctx.use_reranker,
metadata_filter_applied=ctx.metadata_filter is not None,
latency_ms=latency_ms,
diagnostics={
"dense_weight": self._config.hybrid.dense_weight,
"keyword_weight": self._config.hybrid.keyword_weight,
},
)
except Exception as e:
latency_ms = (time.time() - start_time) * 1000
logger.error(f"[EnhancedPipeline] Retrieval error: {e}", exc_info=True)
return self._create_empty_result(ctx, error=str(e), latency_ms=latency_ms)
async def _hybrid_retrieve(
self,
tenant_id: str,
query: str,
embedding_result: Any,
metadata_filter: dict[str, Any] | None = None,
kb_ids: list[str] | None = None,
) -> list[RetrievalCandidate]:
"""混合检索Dense + Keyword + RRF。"""
client = await self._get_client()
top_k = self._config.top_k
expand_factor = self._config.hybrid.keyword_top_k_multiplier
vector_task = self._dense_search(
client=client,
tenant_id=tenant_id,
embedding=embedding_result.embedding_full,
top_k=top_k * expand_factor,
metadata_filter=metadata_filter,
kb_ids=kb_ids,
)
keyword_task = self._keyword_search(
client=client,
tenant_id=tenant_id,
query=query,
top_k=top_k * expand_factor,
metadata_filter=metadata_filter,
kb_ids=kb_ids,
) if self._config.hybrid.enable_keyword else asyncio.sleep(0, result=[])
vector_results, keyword_results = await asyncio.gather(
vector_task, keyword_task, return_exceptions=True
)
if isinstance(vector_results, Exception):
logger.warning(f"[EnhancedPipeline] Dense search failed: {vector_results}")
vector_results = []
if isinstance(keyword_results, Exception):
logger.warning(f"[EnhancedPipeline] Keyword search failed: {keyword_results}")
keyword_results = []
combined = self._rrf_combiner.combine(
vector_results=vector_results,
bm25_results=keyword_results,
vector_weight=self._config.hybrid.dense_weight,
bm25_weight=self._config.hybrid.keyword_weight,
)
candidates = []
for item in combined:
candidates.append(RetrievalCandidate(
id=item.get("id", ""),
text=item.get("payload", {}).get("text", ""),
score=item.get("score", 0.0),
vector_score=item.get("vector_score", 0.0),
keyword_score=item.get("bm25_score", 0.0),
metadata=item.get("payload", {}),
))
return candidates
async def _dense_search(
self,
client: QdrantClient,
tenant_id: str,
embedding: list[float],
top_k: int,
metadata_filter: dict[str, Any] | None = None,
kb_ids: list[str] | None = None,
) -> list[dict[str, Any]]:
"""Dense 向量检索。"""
try:
results = await client.search(
tenant_id=tenant_id,
query_vector=embedding,
limit=top_k,
vector_name="full",
metadata_filter=metadata_filter,
kb_ids=kb_ids,
)
return results
except Exception as e:
logger.error(f"[EnhancedPipeline] Dense search error: {e}")
return []
async def _keyword_search(
self,
client: QdrantClient,
tenant_id: str,
query: str,
top_k: int,
metadata_filter: dict[str, Any] | None = None,
kb_ids: list[str] | None = None,
) -> list[dict[str, Any]]:
"""Keyword 关键词检索。"""
try:
qdrant = await client.get_client()
collection_name = client.get_collection_name(tenant_id)
query_terms = set(re.findall(r'\w+', query.lower()))
results = await qdrant.scroll(
collection_name=collection_name,
limit=top_k * 3,
with_payload=True,
)
scored_results = []
for point in results[0]:
text = point.payload.get("text", "").lower()
text_terms = set(re.findall(r'\w+', text))
overlap = len(query_terms & text_terms)
if overlap > 0:
score = overlap / (len(query_terms) + len(text_terms) - overlap)
scored_results.append({
"id": str(point.id),
"score": score,
"payload": point.payload or {},
})
scored_results.sort(key=lambda x: x["score"], reverse=True)
return scored_results[:top_k]
except Exception as e:
logger.debug(f"[EnhancedPipeline] Keyword search failed: {e}")
return []
async def _rerank(
self,
candidates: list[RetrievalCandidate],
query: str,
) -> list[RetrievalCandidate]:
"""可选重排。"""
if not candidates:
return candidates
try:
provider = await self._get_embedding_provider()
query_embedding = await provider.embed_query(query)
reranked = []
for candidate in candidates:
candidate_text = candidate.text[:500]
if candidate_text:
candidate_embedding = await provider.embed(candidate_text)
similarity = self._cosine_similarity(
query_embedding.embedding_full,
candidate_embedding,
)
candidate.score = similarity
if candidate.score >= self._reranker_config.min_score_threshold:
reranked.append(candidate)
reranked.sort(key=lambda x: x.score, reverse=True)
return reranked[:self._reranker_config.top_k_after_rerank]
except Exception as e:
logger.warning(f"[EnhancedPipeline] Rerank failed: {e}")
return candidates
def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
"""计算余弦相似度。"""
a = np.array(vec1)
b = np.array(vec2)
return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))
async def health_check(self) -> bool:
"""健康检查。"""
try:
client = await self._get_client()
qdrant = await client.get_client()
await qdrant.get_collections()
return True
except Exception as e:
logger.error(f"[EnhancedPipeline] Health check failed: {e}")
return False
_enhanced_pipeline: EnhancedPipeline | None = None
async def get_enhanced_pipeline() -> EnhancedPipeline:
"""获取 EnhancedPipeline 单例。"""
global _enhanced_pipeline
if _enhanced_pipeline is None:
_enhanced_pipeline = EnhancedPipeline()
return _enhanced_pipeline

View File

@ -0,0 +1,136 @@
"""
Metadata Inference Service.
[AC-AISVC-RES-04] 元数据推断统一入口
"""
import logging
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.mid.metadata_filter_builder import (
FilterBuildResult,
MetadataFilterBuilder,
)
from app.services.retrieval.strategy.config import (
FilterMode,
MetadataInferenceConfig,
)
from app.services.retrieval.strategy.pipeline_base import MetadataFilterResult
logger = logging.getLogger(__name__)
@dataclass
class InferenceContext:
"""元数据推断上下文。"""
tenant_id: str
query: str
session_id: str | None = None
user_id: str | None = None
channel_type: str | None = None
existing_context: dict[str, Any] = field(default_factory=dict)
slot_state: Any = None
@dataclass
class InferenceResult:
"""元数据推断结果。"""
filter_result: MetadataFilterResult
inferred_fields: dict[str, Any] = field(default_factory=dict)
confidence_scores: dict[str, float] = field(default_factory=dict)
overall_confidence: float | None = None
inference_source: str = "unknown"
class MetadataInferenceService:
"""
元数据推断统一入口AC-AISVC-RES-04
职责
1. 统一的元数据推断入口策略无关
2. 根据置信度决定 hard/soft filter 模式
3. 与现有 MetadataFilterBuilder 保持一致
"""
def __init__(
self,
session: AsyncSession,
config: MetadataInferenceConfig | None = None,
):
self._session = session
self._config = config or MetadataInferenceConfig()
self._filter_builder: MetadataFilterBuilder | None = None
async def infer(self, ctx: InferenceContext) -> InferenceResult:
"""执行元数据推断。【AC-AISVC-RES-04】"""
logger.info(
f"[MetadataInference] Starting inference: tenant={ctx.tenant_id}, "
f"query={ctx.query[:50]}..."
)
if self._filter_builder is None:
self._filter_builder = MetadataFilterBuilder(self._session)
effective_context = dict(ctx.existing_context)
if ctx.slot_state:
effective_context = await self._merge_slot_state(
effective_context, ctx.slot_state
)
build_result = await self._filter_builder.build_filter(
tenant_id=ctx.tenant_id,
context=effective_context,
)
confidence = self._calculate_confidence(build_result, effective_context)
filter_mode = self._config.determine_filter_mode(confidence)
filter_result = MetadataFilterResult(
filter_dict=build_result.applied_filter,
filter_mode=filter_mode,
confidence=confidence,
missing_required_slots=build_result.missing_required_slots,
debug_info=build_result.debug_info,
)
logger.info(
f"[MetadataInference] Inference completed: filter_mode={filter_mode.value}, "
f"confidence={confidence}"
)
return InferenceResult(
filter_result=filter_result,
inferred_fields=build_result.applied_filter,
overall_confidence=confidence,
inference_source="metadata_filter_builder",
)
async def _merge_slot_state(
self, context: dict[str, Any], slot_state: Any
) -> dict[str, Any]:
"""合并槽位状态到上下文。"""
if hasattr(slot_state, 'filled_slots'):
for slot_key, slot_value in slot_state.filled_slots.items():
if slot_key not in context:
context[slot_key] = slot_value
return context
def _calculate_confidence(
self, build_result: FilterBuildResult, context: dict[str, Any]
) -> float | None:
"""计算推断置信度。"""
if build_result.missing_required_slots:
return 0.3
if not build_result.applied_filter:
return None
if not context:
return 0.5
applied_ratio = len(build_result.applied_filter) / max(len(context), 1)
if applied_ratio >= 0.8:
return 0.9
elif applied_ratio >= 0.5:
return 0.7
return 0.5

View File

@ -0,0 +1,118 @@
"""
Mode Router.
[AC-AISVC-RES-09~15] 模式路由器
"""
import logging
from dataclasses import dataclass
from app.services.retrieval.strategy.config import ModeRouterConfig, RuntimeMode
from app.services.retrieval.strategy.pipeline_base import PipelineResult
logger = logging.getLogger(__name__)
@dataclass
class ModeDecision:
"""模式决策结果。"""
mode: RuntimeMode
reason: str
confidence: float | None = None
complexity_score: float | None = None
class ModeRouter:
"""
模式路由器AC-AISVC-RES-09~15
职责
1. 根据 rag_runtime_mode 选择 direct/react/auto 模式
2. auto 模式下根据复杂度与置信度自动选择路由
3. direct 低置信度时触发 react 回退
"""
def __init__(self, config: ModeRouterConfig | None = None):
self._config = config or ModeRouterConfig()
def decide(
self,
query: str,
confidence: float | None = None,
complexity_score: float | None = None,
) -> ModeDecision:
"""决定使用哪种模式。【AC-AISVC-RES-09~13】"""
if self._config.runtime_mode == RuntimeMode.REACT:
return ModeDecision(mode=RuntimeMode.REACT, reason="runtime_mode=react")
if self._config.runtime_mode == RuntimeMode.DIRECT:
return ModeDecision(mode=RuntimeMode.DIRECT, reason="runtime_mode=direct")
calculated_complexity = complexity_score or self._calculate_complexity(query)
if self._should_use_direct(query, confidence, calculated_complexity):
return ModeDecision(
mode=RuntimeMode.DIRECT,
reason="auto: short_query_high_confidence",
confidence=confidence,
complexity_score=calculated_complexity,
)
return ModeDecision(
mode=RuntimeMode.REACT,
reason="auto: complex_or_low_confidence",
confidence=confidence,
complexity_score=calculated_complexity,
)
def should_fallback_to_react(self, direct_result: PipelineResult) -> bool:
"""判断是否应该从 direct 回退到 react。【AC-AISVC-RES-14】"""
if not self._config.direct_fallback_on_low_confidence:
return False
if direct_result.is_empty:
return True
max_score = direct_result.retrieval_result.max_score
if max_score < 0.3:
return True
if direct_result.retrieval_result.hit_count < 2:
return True
return False
def _should_use_direct(
self, query: str, confidence: float | None, complexity_score: float
) -> bool:
if len(query) <= self._config.short_query_threshold:
if confidence and confidence >= self._config.react_trigger_confidence_threshold:
return True
if confidence and confidence < self._config.react_trigger_confidence_threshold:
return False
if complexity_score >= self._config.react_trigger_complexity_score:
return False
return True
def _calculate_complexity(self, query: str) -> float:
score = 0.0
if len(query) > 50:
score += 0.2
if len(query) > 100:
score += 0.2
condition_words = ["", "", "但是", "如果", "同时", "并且", "或者", "以及"]
for word in condition_words:
if word in query:
score += 0.1
return min(score, 1.0)
def get_config(self) -> ModeRouterConfig:
return self._config
def update_config(self, config: ModeRouterConfig) -> None:
self._config = config
_mode_router: ModeRouter | None = None
def get_mode_router() -> ModeRouter:
global _mode_router
if _mode_router is None:
_mode_router = ModeRouter()
return _mode_router

View File

@ -0,0 +1,116 @@
"""
Pipeline Base Classes.
[AC-AISVC-RES-01~15] Pipeline 抽象基类
"""
import logging
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any
from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult
from app.services.retrieval.strategy.config import FilterMode
logger = logging.getLogger(__name__)
@dataclass
class MetadataFilterResult:
"""元数据过滤结果。"""
filter_dict: dict[str, Any] = field(default_factory=dict)
filter_mode: FilterMode = FilterMode.NONE
confidence: float | None = None
missing_required_slots: list[dict[str, str]] = field(default_factory=list)
debug_info: dict[str, Any] = field(default_factory=dict)
@dataclass
class PipelineContext:
"""Pipeline 执行上下文。"""
retrieval_ctx: RetrievalContext
metadata_filter: MetadataFilterResult | None = None
use_reranker: bool = False
use_react: bool = False
react_iteration: int = 0
previous_results: list[RetrievalHit] = field(default_factory=list)
extra: dict[str, Any] = field(default_factory=dict)
@property
def tenant_id(self) -> str:
return self.retrieval_ctx.tenant_id
@property
def query(self) -> str:
return self.retrieval_ctx.query
@property
def session_id(self) -> str | None:
return self.retrieval_ctx.session_id
@property
def kb_ids(self) -> list[str] | None:
return self.retrieval_ctx.kb_ids
@dataclass
class PipelineResult:
"""Pipeline 执行结果。"""
retrieval_result: RetrievalResult
pipeline_name: str = ""
used_reranker: bool = False
used_react: bool = False
react_iterations: int = 0
metadata_filter_applied: bool = False
fallback_triggered: bool = False
fallback_reason: str | None = None
latency_ms: float = 0.0
diagnostics: dict[str, Any] = field(default_factory=dict)
@property
def hits(self) -> list[RetrievalHit]:
return self.retrieval_result.hits
@property
def is_empty(self) -> bool:
return self.retrieval_result.is_empty
class BasePipeline(ABC):
"""Pipeline 抽象基类。【AC-AISVC-RES-01, AC-AISVC-RES-02】"""
@property
@abstractmethod
def name(self) -> str:
"""Pipeline 名称。"""
pass
@property
@abstractmethod
def description(self) -> str:
"""Pipeline 描述。"""
pass
@abstractmethod
async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
"""执行检索。"""
pass
@abstractmethod
async def health_check(self) -> bool:
"""健康检查。"""
pass
def _create_empty_result(
self,
ctx: PipelineContext,
error: str | None = None,
latency_ms: float = 0.0,
) -> PipelineResult:
"""创建空结果。"""
diagnostics = {"error": error} if error else {}
return PipelineResult(
retrieval_result=RetrievalResult(hits=[], diagnostics=diagnostics),
pipeline_name=self.name,
latency_ms=latency_ms,
diagnostics=diagnostics,
)

View File

@ -0,0 +1,301 @@
"""
Retrieval Strategy - Unified Entry Point.
[AC-AISVC-RES-01~15] 检索策略统一入口
整合
- StrategyRouter: 策略路由default/enhanced
- ModeRouter: 模式路由direct/react/auto
- MetadataInferenceService: 元数据推断
- RollbackManager: 回退管理
"""
import logging
import time
from dataclasses import dataclass
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.retrieval.base import RetrievalContext, RetrievalResult
from app.services.retrieval.strategy.config import (
RetrievalStrategyConfig,
RuntimeMode,
StrategyType,
)
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline
from app.services.retrieval.strategy.metadata_inference import (
InferenceContext,
MetadataInferenceService,
)
from app.services.retrieval.strategy.mode_router import ModeDecision, ModeRouter
from app.services.retrieval.strategy.pipeline_base import (
MetadataFilterResult,
PipelineContext,
PipelineResult,
)
from app.services.retrieval.strategy.rollback_manager import (
RollbackManager,
RollbackTrigger,
)
from app.services.retrieval.strategy.strategy_router import (
RoutingDecision,
StrategyRouter,
)
logger = logging.getLogger(__name__)
@dataclass
class RetrievalStrategyResult:
"""检索策略执行结果。"""
retrieval_result: RetrievalResult
strategy_used: StrategyType
mode_used: RuntimeMode
metadata_filter: MetadataFilterResult | None
latency_ms: float
diagnostics: dict[str, Any]
class RetrievalStrategy:
"""
检索策略统一入口AC-AISVC-RES-01~15
整合所有策略组件
1. 元数据推断MetadataInferenceService
2. 策略路由StrategyRouter
3. 模式路由ModeRouter
4. 回退管理RollbackManager
使用方式
```python
strategy = RetrievalStrategy(session)
result = await strategy.retrieve(
tenant_id="tenant_1",
query="用户问题",
context={"user_id": "user_1"},
)
```
"""
def __init__(
self,
session: AsyncSession,
config: RetrievalStrategyConfig | None = None,
):
self._session = session
self._config = config or RetrievalStrategyConfig()
self._strategy_router = StrategyRouter(self._config)
self._mode_router = ModeRouter(self._config.mode_router)
self._rollback_manager = RollbackManager(self._config)
self._metadata_inference: MetadataInferenceService | None = None
async def retrieve(
self,
tenant_id: str,
query: str,
context: dict[str, Any] | None = None,
kb_ids: list[str] | None = None,
session_id: str | None = None,
user_id: str | None = None,
use_reranker: bool = False,
use_react: bool = False,
) -> RetrievalStrategyResult:
"""
执行检索策略AC-AISVC-RES-01~15
流程
1. 元数据推断
2. 策略路由
3. 模式路由
4. 执行检索
5. 检查是否需要回退
Args:
tenant_id: 租户 ID
query: 查询文本
context: 上下文信息
kb_ids: 知识库 ID 列表
session_id: 会话 ID
user_id: 用户 ID
use_reranker: 是否使用重排
use_react: 是否使用 ReAct 模式
Returns:
RetrievalStrategyResult 包含检索结果和诊断信息
"""
start_time = time.time()
context = context or {}
logger.info(
f"[RetrievalStrategy] Starting retrieval: tenant={tenant_id}, "
f"query={query[:50]}..."
)
try:
metadata_filter = await self._infer_metadata(
tenant_id=tenant_id,
query=query,
context=context,
session_id=session_id,
user_id=user_id,
)
routing_decision = await self._strategy_router.route(tenant_id, user_id)
mode_decision = self._mode_router.decide(
query=query,
confidence=metadata_filter.confidence if metadata_filter else None,
)
retrieval_ctx = RetrievalContext(
tenant_id=tenant_id,
query=query,
session_id=session_id,
metadata_filter=metadata_filter.filter_dict if metadata_filter else None,
kb_ids=kb_ids,
)
pipeline_ctx = PipelineContext(
retrieval_ctx=retrieval_ctx,
metadata_filter=metadata_filter,
use_reranker=use_reranker or self._config.reranker.enabled,
use_react=use_react or mode_decision.mode == RuntimeMode.REACT,
)
pipeline_result = await routing_decision.pipeline.retrieve(pipeline_ctx)
if mode_decision.mode == RuntimeMode.DIRECT and self._mode_router.should_fallback_to_react(pipeline_result):
logger.info("[RetrievalStrategy] Falling back to react mode")
pipeline_ctx.use_react = True
pipeline_result = await routing_decision.pipeline.retrieve(pipeline_ctx)
mode_decision = ModeDecision(
mode=RuntimeMode.REACT,
reason="fallback_from_direct",
)
latency_ms = (time.time() - start_time) * 1000
self._check_performance(latency_ms, tenant_id)
logger.info(
f"[RetrievalStrategy] Retrieval completed: strategy={routing_decision.strategy.value}, "
f"mode={mode_decision.mode.value}, hits={len(pipeline_result.hits)}, "
f"latency_ms={latency_ms:.2f}"
)
return RetrievalStrategyResult(
retrieval_result=pipeline_result.retrieval_result,
strategy_used=routing_decision.strategy,
mode_used=mode_decision.mode,
metadata_filter=metadata_filter,
latency_ms=latency_ms,
diagnostics={
"routing_reason": routing_decision.reason,
"mode_reason": mode_decision.reason,
"grayscale_hit": routing_decision.grayscale_hit,
**pipeline_result.diagnostics,
},
)
except Exception as e:
latency_ms = (time.time() - start_time) * 1000
logger.error(
f"[RetrievalStrategy] Retrieval error: {e}",
exc_info=True
)
self._rollback_manager.rollback(
trigger=RollbackTrigger.ERROR,
reason=str(e),
tenant_id=tenant_id,
)
return RetrievalStrategyResult(
retrieval_result=RetrievalResult(
hits=[],
diagnostics={"error": str(e)},
),
strategy_used=StrategyType.DEFAULT,
mode_used=RuntimeMode.DIRECT,
metadata_filter=None,
latency_ms=latency_ms,
diagnostics={"error": str(e)},
)
async def _infer_metadata(
self,
tenant_id: str,
query: str,
context: dict[str, Any],
session_id: str | None = None,
user_id: str | None = None,
) -> MetadataFilterResult | None:
"""执行元数据推断。"""
try:
if self._metadata_inference is None:
self._metadata_inference = MetadataInferenceService(
self._session,
self._config.metadata_inference,
)
inference_ctx = InferenceContext(
tenant_id=tenant_id,
query=query,
session_id=session_id,
user_id=user_id,
existing_context=context,
)
result = await self._metadata_inference.infer(inference_ctx)
return result.filter_result
except Exception as e:
logger.warning(f"[RetrievalStrategy] Metadata inference failed: {e}")
return None
def _check_performance(self, latency_ms: float, tenant_id: str | None) -> None:
"""检查性能指标,必要时触发回退。"""
self._rollback_manager.check_and_rollback(
metrics={"latency_ms": latency_ms},
tenant_id=tenant_id,
)
def get_config(self) -> RetrievalStrategyConfig:
"""获取当前配置。"""
return self._config
def update_config(self, config: RetrievalStrategyConfig) -> None:
"""更新配置。"""
self._config = config
self._strategy_router.update_config(config)
self._mode_router.update_config(config.mode_router)
self._rollback_manager.update_config(config)
async def health_check(self) -> dict[str, bool]:
"""健康检查。"""
results = {}
try:
default_pipeline = await self._strategy_router._get_default_pipeline()
results["default_pipeline"] = await default_pipeline.health_check()
except Exception:
results["default_pipeline"] = False
try:
enhanced_pipeline = await self._strategy_router._get_enhanced_pipeline()
results["enhanced_pipeline"] = await enhanced_pipeline.health_check()
except Exception:
results["enhanced_pipeline"] = False
return results
async def create_retrieval_strategy(
session: AsyncSession,
config: RetrievalStrategyConfig | None = None,
) -> RetrievalStrategy:
"""创建 RetrievalStrategy 实例。"""
return RetrievalStrategy(session, config)

View File

@ -0,0 +1,192 @@
"""
Rollback Manager.
[AC-AISVC-RES-07] 策略回退与审计管理器
"""
import logging
from dataclasses import dataclass, field
from datetime import datetime
from enum import Enum
from typing import Any
from app.services.retrieval.strategy.config import RetrievalStrategyConfig, StrategyType
logger = logging.getLogger(__name__)
class RollbackTrigger(str, Enum):
"""回退触发原因。"""
MANUAL = "manual"
ERROR = "error"
PERFORMANCE = "performance"
TIMEOUT = "timeout"
@dataclass
class AuditLog:
"""审计日志记录。"""
timestamp: str
action: str
from_strategy: str
to_strategy: str
trigger: str
reason: str
tenant_id: str | None = None
details: dict[str, Any] = field(default_factory=dict)
@dataclass
class RollbackResult:
"""回退结果。"""
success: bool
previous_strategy: StrategyType
current_strategy: StrategyType
trigger: RollbackTrigger
reason: str
audit_log: AuditLog | None = None
class RollbackManager:
"""
策略回退管理器AC-AISVC-RES-07
职责
1. 策略异常时回退到默认策略
2. 记录审计日志
3. 支持手动触发回退
"""
def __init__(
self,
config: RetrievalStrategyConfig | None = None,
max_audit_logs: int = 1000,
):
self._config = config or RetrievalStrategyConfig()
self._max_audit_logs = max_audit_logs
self._audit_logs: list[AuditLog] = []
self._previous_strategy: StrategyType = StrategyType.DEFAULT
def rollback(
self,
trigger: RollbackTrigger,
reason: str,
tenant_id: str | None = None,
details: dict[str, Any] | None = None,
) -> RollbackResult:
"""执行策略回退。【AC-AISVC-RES-07】"""
previous = self._config.active_strategy
current = StrategyType.DEFAULT
if previous == StrategyType.DEFAULT:
return RollbackResult(
success=False,
previous_strategy=previous,
current_strategy=current,
trigger=trigger,
reason="Already on default strategy",
)
self._previous_strategy = previous
self._config.active_strategy = current
audit_log = AuditLog(
timestamp=datetime.utcnow().isoformat(),
action="rollback",
from_strategy=previous.value,
to_strategy=current.value,
trigger=trigger.value,
reason=reason,
tenant_id=tenant_id,
details=details or {},
)
self._add_audit_log(audit_log)
logger.info(
f"[RollbackManager] Rollback executed: from={previous.value}, "
f"to={current.value}, trigger={trigger.value}"
)
return RollbackResult(
success=True,
previous_strategy=previous,
current_strategy=current,
trigger=trigger,
reason=reason,
audit_log=audit_log,
)
def check_and_rollback(
self, metrics: dict[str, float], tenant_id: str | None = None
) -> RollbackResult | None:
"""检查性能指标并自动回退。【AC-AISVC-RES-08】"""
thresholds = self._config.performance_thresholds
latency = metrics.get("latency_ms", 0)
if latency > thresholds.get("max_latency_ms", 2000):
return self.rollback(
trigger=RollbackTrigger.PERFORMANCE,
reason=f"Latency {latency}ms exceeds threshold",
tenant_id=tenant_id,
)
error_rate = metrics.get("error_rate", 0)
if error_rate > thresholds.get("max_error_rate", 0.05):
return self.rollback(
trigger=RollbackTrigger.ERROR,
reason=f"Error rate {error_rate} exceeds threshold",
tenant_id=tenant_id,
)
return None
def _add_audit_log(self, log: AuditLog) -> None:
self._audit_logs.append(log)
if len(self._audit_logs) > self._max_audit_logs:
self._audit_logs = self._audit_logs[-self._max_audit_logs:]
def record_audit(
self,
action: str,
details: dict[str, Any],
tenant_id: str | None = None,
) -> AuditLog:
"""记录审计日志。"""
audit_log = AuditLog(
timestamp=datetime.utcnow().isoformat(),
action=action,
from_strategy=self._config.active_strategy.value,
to_strategy=self._config.active_strategy.value,
trigger="n/a",
reason=details.get("reason", ""),
tenant_id=tenant_id,
details=details,
)
self._add_audit_log(audit_log)
logger.info(
f"[RollbackManager] Audit recorded: action={action}, "
f"strategy={self._config.active_strategy.value}"
)
return audit_log
def get_audit_logs(self, limit: int = 100) -> list[AuditLog]:
return self._audit_logs[-limit:]
def get_config(self) -> RetrievalStrategyConfig:
return self._config
def update_config(self, config: RetrievalStrategyConfig) -> None:
self._config = config
_rollback_manager: RollbackManager | None = None
def get_rollback_manager() -> RollbackManager:
global _rollback_manager
if _rollback_manager is None:
_rollback_manager = RollbackManager()
return _rollback_manager

View File

@ -0,0 +1,109 @@
"""
Strategy Router.
[AC-AISVC-RES-01~03] 策略路由器
"""
import logging
from dataclasses import dataclass
from typing import Any
from app.services.retrieval.strategy.config import (
RetrievalStrategyConfig,
StrategyType,
)
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline, get_default_pipeline
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline, get_enhanced_pipeline
from app.services.retrieval.strategy.pipeline_base import BasePipeline
logger = logging.getLogger(__name__)
@dataclass
class RoutingDecision:
"""路由决策结果。"""
strategy: StrategyType
pipeline: BasePipeline
reason: str
grayscale_hit: bool = False
class StrategyRouter:
"""
策略路由器AC-AISVC-RES-01~03
职责
1. 根据配置选择默认策略或增强策略
2. 支持灰度发布percentage/allowlist
3. 不影响正在运行的默认策略请求
"""
def __init__(
self,
config: RetrievalStrategyConfig | None = None,
default_pipeline: DefaultPipeline | None = None,
enhanced_pipeline: EnhancedPipeline | None = None,
):
self._config = config or RetrievalStrategyConfig()
self._default_pipeline = default_pipeline
self._enhanced_pipeline = enhanced_pipeline
async def route(
self, tenant_id: str, user_id: str | None = None
) -> RoutingDecision:
"""路由到合适的策略。【AC-AISVC-RES-01~03】"""
if self._config.active_strategy == StrategyType.ENHANCED:
pipeline = await self._get_enhanced_pipeline()
return RoutingDecision(
strategy=StrategyType.ENHANCED,
pipeline=pipeline,
reason="active_strategy=enhanced",
)
if self._config.grayscale.should_use_enhanced(tenant_id, user_id):
pipeline = await self._get_enhanced_pipeline()
return RoutingDecision(
strategy=StrategyType.ENHANCED,
pipeline=pipeline,
reason="grayscale_hit",
grayscale_hit=True,
)
pipeline = await self._get_default_pipeline()
return RoutingDecision(
strategy=StrategyType.DEFAULT,
pipeline=pipeline,
reason="default_strategy",
)
async def _get_default_pipeline(self) -> DefaultPipeline:
if self._default_pipeline is None:
self._default_pipeline = await get_default_pipeline()
return self._default_pipeline
async def _get_enhanced_pipeline(self) -> EnhancedPipeline:
if self._enhanced_pipeline is None:
self._enhanced_pipeline = await get_enhanced_pipeline()
return self._enhanced_pipeline
def get_config(self) -> RetrievalStrategyConfig:
return self._config
def update_config(self, config: RetrievalStrategyConfig) -> None:
self._config = config
logger.info(f"[StrategyRouter] Config updated: strategy={config.active_strategy.value}")
_strategy_router: StrategyRouter | None = None
def get_strategy_router() -> StrategyRouter:
global _strategy_router
if _strategy_router is None:
_strategy_router = StrategyRouter()
return _strategy_router
def set_strategy_router(router: StrategyRouter) -> None:
"""Set the global strategy router instance."""
global _strategy_router
_strategy_router = router

View File

@ -0,0 +1,645 @@
"""
Unit tests for Retrieval Strategy Module.
[AC-AISVC-RES-01~15] Tests for strategy config, pipelines, routers, and rollback.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from dataclasses import asdict
from app.services.retrieval.strategy.config import (
FilterMode,
GrayscaleConfig,
HybridRetrievalConfig,
MetadataInferenceConfig,
ModeRouterConfig,
PipelineConfig,
RerankerConfig,
RetrievalStrategyConfig,
RuntimeMode,
StrategyType,
get_strategy_config,
set_strategy_config,
)
from app.services.retrieval.strategy.pipeline_base import (
BasePipeline,
MetadataFilterResult,
PipelineContext,
PipelineResult,
)
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline
from app.services.retrieval.strategy.strategy_router import (
RoutingDecision,
StrategyRouter,
get_strategy_router,
)
from app.services.retrieval.strategy.mode_router import (
ModeDecision,
ModeRouter,
get_mode_router,
)
from app.services.retrieval.strategy.rollback_manager import (
AuditLog,
RollbackManager,
RollbackResult,
RollbackTrigger,
get_rollback_manager,
)
from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult
class TestStrategyConfig:
"""[AC-AISVC-RES-01~15] Tests for strategy configuration models."""
def test_strategy_type_enum(self):
"""[AC-AISVC-RES-01] Strategy type should have default and enhanced values."""
assert StrategyType.DEFAULT.value == "default"
assert StrategyType.ENHANCED.value == "enhanced"
def test_runtime_mode_enum(self):
"""[AC-AISVC-RES-09] Runtime mode should have direct, react, and auto values."""
assert RuntimeMode.DIRECT.value == "direct"
assert RuntimeMode.REACT.value == "react"
assert RuntimeMode.AUTO.value == "auto"
def test_filter_mode_enum(self):
"""[AC-AISVC-RES-04] Filter mode should have hard, soft, and none values."""
assert FilterMode.HARD.value == "hard"
assert FilterMode.SOFT.value == "soft"
assert FilterMode.NONE.value == "none"
def test_grayscale_config_default(self):
"""[AC-AISVC-RES-03] Default grayscale config should be disabled."""
config = GrayscaleConfig()
assert config.enabled is False
assert config.percentage == 0.0
assert config.allowlist == []
def test_grayscale_config_should_use_enhanced_disabled(self):
"""[AC-AISVC-RES-03] Should not use enhanced when grayscale disabled."""
config = GrayscaleConfig(enabled=False, percentage=50.0)
assert config.should_use_enhanced("tenant_a") is False
def test_grayscale_config_should_use_enhanced_allowlist(self):
"""[AC-AISVC-RES-03] Should use enhanced for tenants in allowlist."""
config = GrayscaleConfig(enabled=True, allowlist=["tenant_a", "tenant_b"])
assert config.should_use_enhanced("tenant_a") is True
assert config.should_use_enhanced("tenant_b") is True
assert config.should_use_enhanced("tenant_c") is False
def test_grayscale_config_should_use_enhanced_percentage(self):
"""[AC-AISVC-RES-03] Should use enhanced based on percentage."""
config = GrayscaleConfig(enabled=True, percentage=100.0)
assert config.should_use_enhanced("any_tenant") is True
config = GrayscaleConfig(enabled=True, percentage=0.0)
assert config.should_use_enhanced("any_tenant") is False
def test_reranker_config_default(self):
"""[AC-AISVC-RES-08] Default reranker config should be disabled."""
config = RerankerConfig()
assert config.enabled is False
assert config.model == "cross-encoder"
assert config.top_k_after_rerank == 5
def test_mode_router_config_default(self):
"""[AC-AISVC-RES-09] Default mode router config should be direct."""
config = ModeRouterConfig()
assert config.runtime_mode == RuntimeMode.DIRECT
assert config.react_trigger_confidence_threshold == 0.6
assert config.react_max_steps == 5
def test_mode_router_config_should_use_react_always(self):
"""[AC-AISVC-RES-10] React mode should always use react."""
config = ModeRouterConfig(runtime_mode=RuntimeMode.REACT)
assert config.should_use_react("any query") is True
def test_mode_router_config_should_use_react_never(self):
"""[AC-AISVC-RES-09] Direct mode should never use react."""
config = ModeRouterConfig(runtime_mode=RuntimeMode.DIRECT)
assert config.should_use_react("any query") is False
def test_mode_router_config_auto_short_query_high_confidence(self):
"""[AC-AISVC-RES-12] Auto mode with short query and high confidence should use direct."""
config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
assert config.should_use_react("短问题", confidence=0.8) is False
def test_mode_router_config_auto_low_confidence(self):
"""[AC-AISVC-RES-13] Auto mode with low confidence should use react."""
config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
assert config.should_use_react("any query", confidence=0.3) is True
def test_metadata_inference_config_determine_filter_mode(self):
"""[AC-AISVC-RES-04] Should determine filter mode based on confidence."""
config = MetadataInferenceConfig()
assert config.determine_filter_mode(0.9) == FilterMode.HARD
assert config.determine_filter_mode(0.6) == FilterMode.SOFT
assert config.determine_filter_mode(0.3) == FilterMode.NONE
assert config.determine_filter_mode(None) == FilterMode.NONE
def test_pipeline_config_default(self):
"""[AC-AISVC-RES-01] Default pipeline config should have sensible defaults."""
config = PipelineConfig()
assert config.top_k == 5
assert config.score_threshold == 0.01
assert config.two_stage_enabled is True
def test_retrieval_strategy_config_default(self):
"""[AC-AISVC-RES-01] Default strategy config should use default strategy."""
config = RetrievalStrategyConfig()
assert config.active_strategy == StrategyType.DEFAULT
assert config.grayscale.enabled is False
assert config.mode_router.runtime_mode == RuntimeMode.DIRECT
def test_retrieval_strategy_config_is_enhanced_enabled(self):
"""[AC-AISVC-RES-02] Should check if enhanced is enabled."""
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
assert config.is_enhanced_enabled("tenant_a") is True
config = RetrievalStrategyConfig(
active_strategy=StrategyType.DEFAULT,
grayscale=GrayscaleConfig(enabled=True, allowlist=["tenant_a"]),
)
assert config.is_enhanced_enabled("tenant_a") is True
assert config.is_enhanced_enabled("tenant_b") is False
def test_retrieval_strategy_config_to_dict(self):
"""[AC-AISVC-RES-01] Should convert config to dictionary."""
config = RetrievalStrategyConfig()
d = config.to_dict()
assert d["active_strategy"] == "default"
assert "grayscale" in d
assert "pipeline" in d
assert "reranker" in d
assert "mode_router" in d
def test_global_config_functions(self):
"""[AC-AISVC-RES-01] Should get and set global config."""
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
set_strategy_config(config)
retrieved = get_strategy_config()
assert retrieved.active_strategy == StrategyType.ENHANCED
set_strategy_config(RetrievalStrategyConfig())
class TestPipelineBase:
"""[AC-AISVC-RES-01~02] Tests for pipeline base classes."""
def test_metadata_filter_result_default(self):
"""[AC-AISVC-RES-04] Default metadata filter result should be empty."""
result = MetadataFilterResult()
assert result.filter_dict == {}
assert result.filter_mode == FilterMode.NONE
assert result.confidence is None
def test_pipeline_context_properties(self):
"""[AC-AISVC-RES-01] Pipeline context should expose retrieval context properties."""
retrieval_ctx = RetrievalContext(
tenant_id="tenant_1",
query="test query",
session_id="session_1",
kb_ids=["kb_1"],
)
pipeline_ctx = PipelineContext(retrieval_ctx=retrieval_ctx)
assert pipeline_ctx.tenant_id == "tenant_1"
assert pipeline_ctx.query == "test query"
assert pipeline_ctx.session_id == "session_1"
assert pipeline_ctx.kb_ids == ["kb_1"]
def test_pipeline_result_properties(self):
"""[AC-AISVC-RES-01] Pipeline result should expose retrieval result properties."""
hits = [
RetrievalHit(text="hit 1", score=0.9, source="test", metadata={}),
RetrievalHit(text="hit 2", score=0.8, source="test", metadata={}),
]
retrieval_result = RetrievalResult(hits=hits)
pipeline_result = PipelineResult(
retrieval_result=retrieval_result,
pipeline_name="test_pipeline",
)
assert pipeline_result.hits == hits
assert pipeline_result.is_empty is False
assert pipeline_result.pipeline_name == "test_pipeline"
def test_pipeline_result_is_empty(self):
"""[AC-AISVC-RES-01] Pipeline result should detect empty results."""
pipeline_result = PipelineResult(
retrieval_result=RetrievalResult(hits=[]),
)
assert pipeline_result.is_empty is True
class TestDefaultPipeline:
"""[AC-AISVC-RES-01] Tests for default pipeline."""
@pytest.fixture
def mock_retriever(self):
"""Create a mock optimized retriever."""
retriever = AsyncMock()
retriever.retrieve = AsyncMock(return_value=RetrievalResult(
hits=[
RetrievalHit(text="result 1", score=0.9, source="default", metadata={}),
],
diagnostics={"test": True},
))
retriever.health_check = AsyncMock(return_value=True)
retriever._two_stage_enabled = True
retriever._hybrid_enabled = True
return retriever
@pytest.fixture
def pipeline(self, mock_retriever):
"""Create a default pipeline with mock retriever."""
return DefaultPipeline(optimized_retriever=mock_retriever)
def test_pipeline_name(self, pipeline):
"""[AC-AISVC-RES-01] Pipeline should have correct name."""
assert pipeline.name == "default_pipeline"
def test_pipeline_description(self, pipeline):
"""[AC-AISVC-RES-01] Pipeline should have description."""
assert "默认" in pipeline.description
@pytest.mark.asyncio
async def test_retrieve(self, pipeline, mock_retriever):
"""[AC-AISVC-RES-01] Should retrieve results using optimized retriever."""
retrieval_ctx = RetrievalContext(
tenant_id="tenant_1",
query="test query",
)
pipeline_ctx = PipelineContext(retrieval_ctx=retrieval_ctx)
result = await pipeline.retrieve(pipeline_ctx)
assert result.pipeline_name == "default_pipeline"
assert len(result.hits) == 1
assert result.diagnostics["retriever"] == "OptimizedRetriever"
mock_retriever.retrieve.assert_called_once()
@pytest.mark.asyncio
async def test_retrieve_with_metadata_filter(self, pipeline, mock_retriever):
"""[AC-AISVC-RES-04] Should apply metadata filter."""
retrieval_ctx = RetrievalContext(
tenant_id="tenant_1",
query="test query",
)
metadata_filter = MetadataFilterResult(
filter_dict={"grade": "初一"},
filter_mode=FilterMode.HARD,
)
pipeline_ctx = PipelineContext(
retrieval_ctx=retrieval_ctx,
metadata_filter=metadata_filter,
)
result = await pipeline.retrieve(pipeline_ctx)
assert result.metadata_filter_applied is True
call_args = mock_retriever.retrieve.call_args[0][0]
assert call_args.metadata_filter == {"grade": "初一"}
@pytest.mark.asyncio
async def test_health_check(self, pipeline, mock_retriever):
"""[AC-AISVC-RES-01] Should check health."""
result = await pipeline.health_check()
assert result is True
mock_retriever.health_check.assert_called_once()
class TestEnhancedPipeline:
"""[AC-AISVC-RES-02] Tests for enhanced pipeline."""
@pytest.fixture
def mock_qdrant_client(self):
"""Create a mock Qdrant client."""
client = AsyncMock()
client.search = AsyncMock(return_value=[
{"id": "1", "score": 0.9, "payload": {"text": "result 1"}},
])
client.get_client = AsyncMock()
return client
@pytest.fixture
def mock_embedding_provider(self):
"""Create a mock embedding provider."""
provider = AsyncMock()
provider.embed_query = AsyncMock()
provider.embed_query.return_value = MagicMock(
embedding_full=[0.1] * 768,
)
provider.embed = AsyncMock(return_value=[0.1] * 768)
return provider
@pytest.fixture
def pipeline(self, mock_qdrant_client, mock_embedding_provider):
"""Create an enhanced pipeline with mocks."""
pipeline = EnhancedPipeline(qdrant_client=mock_qdrant_client)
pipeline._embedding_provider = mock_embedding_provider
return pipeline
def test_pipeline_name(self, pipeline):
"""[AC-AISVC-RES-02] Pipeline should have correct name."""
assert pipeline.name == "enhanced_pipeline"
def test_pipeline_description(self, pipeline):
"""[AC-AISVC-RES-02] Pipeline should have description."""
assert "增强" in pipeline.description
@pytest.mark.asyncio
async def test_retrieve_basic(self, pipeline):
"""[AC-AISVC-RES-02] Should retrieve results using hybrid search."""
retrieval_ctx = RetrievalContext(
tenant_id="tenant_1",
query="test query",
)
pipeline_ctx = PipelineContext(retrieval_ctx=retrieval_ctx)
result = await pipeline.retrieve(pipeline_ctx)
assert result.pipeline_name == "enhanced_pipeline"
assert result.diagnostics is not None
class TestStrategyRouter:
"""[AC-AISVC-RES-01~03] Tests for strategy router."""
@pytest.fixture
def mock_default_pipeline(self):
"""Create a mock default pipeline."""
pipeline = AsyncMock(spec=DefaultPipeline)
pipeline.name = "default_pipeline"
pipeline.retrieve = AsyncMock(return_value=PipelineResult(
retrieval_result=RetrievalResult(hits=[]),
pipeline_name="default_pipeline",
))
return pipeline
@pytest.fixture
def mock_enhanced_pipeline(self):
"""Create a mock enhanced pipeline."""
pipeline = AsyncMock(spec=EnhancedPipeline)
pipeline.name = "enhanced_pipeline"
pipeline.retrieve = AsyncMock(return_value=PipelineResult(
retrieval_result=RetrievalResult(hits=[]),
pipeline_name="enhanced_pipeline",
))
return pipeline
@pytest.fixture
def router(self, mock_default_pipeline, mock_enhanced_pipeline):
"""Create a strategy router with mock pipelines."""
config = RetrievalStrategyConfig()
return StrategyRouter(
config=config,
default_pipeline=mock_default_pipeline,
enhanced_pipeline=mock_enhanced_pipeline,
)
def test_route_default_strategy(self, router):
"""[AC-AISVC-RES-01] Should route to default strategy by default."""
import asyncio
decision = asyncio.get_event_loop().run_until_complete(router.route("tenant_1"))
assert decision.strategy == StrategyType.DEFAULT
assert decision.reason == "default_strategy"
def test_route_enhanced_strategy(self, mock_default_pipeline, mock_enhanced_pipeline):
"""[AC-AISVC-RES-02] Should route to enhanced strategy when configured."""
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
router = StrategyRouter(
config=config,
default_pipeline=mock_default_pipeline,
enhanced_pipeline=mock_enhanced_pipeline,
)
import asyncio
decision = asyncio.get_event_loop().run_until_complete(router.route("tenant_1"))
assert decision.strategy == StrategyType.ENHANCED
assert decision.reason == "active_strategy=enhanced"
def test_route_grayscale_allowlist(self, mock_default_pipeline, mock_enhanced_pipeline):
"""[AC-AISVC-RES-03] Should route to enhanced for allowlist tenants."""
config = RetrievalStrategyConfig(
active_strategy=StrategyType.DEFAULT,
grayscale=GrayscaleConfig(enabled=True, allowlist=["tenant_a"]),
)
router = StrategyRouter(
config=config,
default_pipeline=mock_default_pipeline,
enhanced_pipeline=mock_enhanced_pipeline,
)
import asyncio
decision = asyncio.get_event_loop().run_until_complete(router.route("tenant_a"))
assert decision.strategy == StrategyType.ENHANCED
assert decision.grayscale_hit is True
decision = asyncio.get_event_loop().run_until_complete(router.route("tenant_b"))
assert decision.strategy == StrategyType.DEFAULT
def test_update_config(self, router):
"""[AC-AISVC-RES-02] Should update config."""
new_config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
router.update_config(new_config)
assert router.get_config().active_strategy == StrategyType.ENHANCED
class TestModeRouter:
"""[AC-AISVC-RES-09~15] Tests for mode router."""
@pytest.fixture
def router(self):
"""Create a mode router."""
return ModeRouter()
def test_decide_react_mode(self):
"""[AC-AISVC-RES-10] Should decide react when configured."""
config = ModeRouterConfig(runtime_mode=RuntimeMode.REACT)
router = ModeRouter(config)
decision = router.decide("any query")
assert decision.mode == RuntimeMode.REACT
assert decision.reason == "runtime_mode=react"
def test_decide_direct_mode(self, router):
"""[AC-AISVC-RES-09] Should decide direct when configured."""
decision = router.decide("any query")
assert decision.mode == RuntimeMode.DIRECT
assert decision.reason == "runtime_mode=direct"
def test_decide_auto_short_query_high_confidence(self):
"""[AC-AISVC-RES-12] Auto with short query and high confidence should use direct."""
config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
router = ModeRouter(config)
decision = router.decide("短问题", confidence=0.8)
assert decision.mode == RuntimeMode.DIRECT
def test_decide_auto_low_confidence(self):
"""[AC-AISVC-RES-13] Auto with low confidence should use react."""
config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
router = ModeRouter(config)
decision = router.decide("any query", confidence=0.3)
assert decision.mode == RuntimeMode.REACT
def test_should_fallback_to_react_empty_results(self, router):
"""[AC-AISVC-RES-14] Should fallback to react on empty results."""
result = PipelineResult(retrieval_result=RetrievalResult(hits=[]))
assert router.should_fallback_to_react(result) is True
def test_should_fallback_to_react_low_score(self, router):
"""[AC-AISVC-RES-14] Should fallback to react on low score."""
result = PipelineResult(
retrieval_result=RetrievalResult(
hits=[RetrievalHit(text="test", score=0.1, source="test", metadata={})],
),
)
assert router.should_fallback_to_react(result) is True
def test_should_not_fallback_to_react_disabled(self):
"""[AC-AISVC-RES-14] Should not fallback when disabled."""
config = ModeRouterConfig(direct_fallback_on_low_confidence=False)
router = ModeRouter(config)
result = PipelineResult(retrieval_result=RetrievalResult(hits=[]))
assert router.should_fallback_to_react(result) is False
class TestRollbackManager:
"""[AC-AISVC-RES-07] Tests for rollback manager."""
@pytest.fixture
def manager(self):
"""Create a rollback manager."""
return RollbackManager()
def test_rollback_from_enhanced(self, manager):
"""[AC-AISVC-RES-07] Should rollback from enhanced to default."""
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
manager.update_config(config)
result = manager.rollback(
trigger=RollbackTrigger.MANUAL,
reason="Testing rollback",
)
assert result.success is True
assert result.previous_strategy == StrategyType.ENHANCED
assert result.current_strategy == StrategyType.DEFAULT
assert result.audit_log is not None
def test_rollback_already_default(self, manager):
"""[AC-AISVC-RES-07] Should not rollback when already on default."""
result = manager.rollback(
trigger=RollbackTrigger.MANUAL,
reason="Testing rollback",
)
assert result.success is False
assert result.reason == "Already on default strategy"
def test_check_and_rollback_latency(self, manager):
"""[AC-AISVC-RES-08] Should rollback on high latency."""
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
manager.update_config(config)
result = manager.check_and_rollback(
metrics={"latency_ms": 3000.0},
tenant_id="tenant_1",
)
assert result is not None
assert result.trigger == RollbackTrigger.PERFORMANCE
def test_check_and_rollback_error_rate(self, manager):
"""[AC-AISVC-RES-08] Should rollback on high error rate."""
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
manager.update_config(config)
result = manager.check_and_rollback(
metrics={"error_rate": 0.1},
tenant_id="tenant_1",
)
assert result is not None
assert result.trigger == RollbackTrigger.ERROR
def test_check_and_rollback_ok(self, manager):
"""[AC-AISVC-RES-08] Should not rollback when metrics are ok."""
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
manager.update_config(config)
result = manager.check_and_rollback(
metrics={"latency_ms": 100.0, "error_rate": 0.01},
tenant_id="tenant_1",
)
assert result is None
def test_get_audit_logs(self, manager):
"""[AC-AISVC-RES-07] Should get audit logs."""
config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
manager.update_config(config)
manager.rollback(trigger=RollbackTrigger.MANUAL, reason="Test")
logs = manager.get_audit_logs()
assert len(logs) == 1
assert logs[0].action == "rollback"
def test_record_audit(self, manager):
"""[AC-AISVC-RES-07] Should record audit log."""
log = manager.record_audit(
action="test_action",
details={"reason": "Testing"},
tenant_id="tenant_1",
)
assert log.action == "test_action"
assert log.tenant_id == "tenant_1"
class TestSingletonInstances:
"""Tests for singleton instance getters."""
def test_get_mode_router_singleton(self):
"""Should return same mode router instance."""
from app.services.retrieval.strategy.mode_router import _mode_router
import app.services.retrieval.strategy.mode_router as module
module._mode_router = None
router1 = get_mode_router()
router2 = get_mode_router()
assert router1 is router2
def test_get_rollback_manager_singleton(self):
"""Should return same rollback manager instance."""
from app.services.retrieval.strategy.rollback_manager import _rollback_manager
import app.services.retrieval.strategy.rollback_manager as module
module._rollback_manager = None
manager1 = get_rollback_manager()
manager2 = get_rollback_manager()
assert manager1 is manager2

View File

@ -0,0 +1,167 @@
---
context:
module: "ai-service"
feature: "AISVC-RES"
status: "✅已完成"
version: "0.9.0"
active_ac_range: "AC-AISVC-RES-01~15"
spec_references:
requirements: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/requirements.md"
design: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/design.md"
tasks: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/tasks.md"
openapi_provider: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/openapi.provider.yaml"
active_version: "0.1.0"
overall_progress:
- "[x] Phase 1: Schema与数据模型定义 (100%) [API与校验]"
- "[x] Phase 2: 策略服务层实现 (100%) [策略层与配置]"
- "[x] Phase 3: 审计日志与指标埋点 (100%) [观测与灰度验证]"
- "[x] Phase 4: API端点实现 (100%) [API与校验]"
- "[x] Phase 5: 单元测试与验证 (100%) [验收]"
- "[x] Phase 6: 检索策略Pipeline实现 (100%) [检索与嵌入策略]"
current_phase:
goal: "检索与嵌入策略化改造已完成"
sub_tasks:
- "[x] 创建 Schema 模型 (app/schemas/retrieval_strategy.py)"
- "[x] 创建策略服务层 (app/services/retrieval/strategy_service.py)"
- "[x] 创建审计日志服务 (app/services/retrieval/strategy_audit.py)"
- "[x] 创建指标埋点服务 (app/services/retrieval/strategy_metrics.py)"
- "[x] 创建 API 端点 (app/api/admin/retrieval_strategy.py)"
- "[x] 创建策略配置模型 (app/services/retrieval/strategy/config.py)"
- "[x] 实现 DefaultPipeline复用现有逻辑"
- "[x] 实现 EnhancedPipeline新端到端流程"
- "[x] 实现元数据推断统一入口 (MetadataInferenceService)"
- "[x] 实现 StrategyRouter 和 ModeRouter"
- "[x] 实现 Dense + Keyword + RRF 组合检索"
- "[x] 实现可选重排与降级开关"
- "[x] 实现 RollbackManager回退与审计"
- "[x] 实现策略 API 接口"
- "[x] 创建单元测试 (tests/test_retrieval_strategy_v2.py)"
- "[x] 运行单元测试验证 (51 passed)"
next_action:
immediate: "任务已完成,可进行集成测试"
details:
file: "ai-service/app/services/retrieval/strategy/__init__.py:1"
action: "模块已完整实现,可通过 API 接口测试策略切换功能"
reference: "http://localhost:8000/docs"
constraints: "新策略可配置启用,不影响默认策略"
technical_context:
module_structure: |
ai-service/app/
├── api/admin/retrieval_strategy.py (API端点)
├── schemas/retrieval_strategy.py (Schema模型)
└── services/retrieval/
├── strategy_service.py (策略服务)
├── strategy_audit.py (审计日志)
├── strategy_metrics.py (指标埋点)
└── strategy/ (新增 - 策略模块)
├── __init__.py (模块导出)
├── config.py (策略配置模型)
├── pipeline_base.py (Pipeline基类)
├── default_pipeline.py (默认策略Pipeline)
├── enhanced_pipeline.py (增强策略Pipeline)
├── metadata_inference.py (元数据推断统一入口)
├── strategy_router.py (策略路由器)
├── mode_router.py (模式路由器)
└── rollback_manager.py (回退管理器)
└── tests/
├── test_retrieval_strategy.py (单元测试 - 原有)
└── test_retrieval_strategy_v2.py (单元测试 - 新增)
key_decisions:
- decision: "使用内存存储策略状态,后续可扩展为持久化"
reason: "快速实现,满足灰度验证需求"
impact: "服务重启后策略状态重置为默认值"
- decision: "审计日志使用结构化日志记录"
reason: "与现有日志体系一致,便于检索"
impact: "需要配置日志聚合系统收集审计日志"
- decision: "API与现有strategy_router.py互补"
reason: "strategy_router.py负责检索路由逻辑新增的API负责策略管理"
impact: "两者协同工作API提供管理界面"
- decision: "DefaultPipeline 复用现有 OptimizedRetriever 逻辑"
reason: "保持线上行为不变,最小化改动风险"
impact: "新策略与旧策略完全隔离,可独立灰度"
- decision: "EnhancedPipeline 实现新端到端流程"
reason: "支持 Dense + Keyword + RRF 组合检索,可选重排"
impact: "需要配置启用,不影响默认策略"
- decision: "元数据推断统一入口处理 hard/soft filter"
reason: "新旧策略共享同一推断逻辑,确保一致性"
impact: "置信度高用硬过滤,置信度低用软过滤/加权"
code_snippets: |
# 使用示例
from app.services.retrieval.strategy import (
get_strategy_router,
get_mode_router,
get_rollback_manager,
)
# 策略路由
router = get_strategy_router()
decision = await router.route(tenant_id, user_id)
result = await decision.pipeline.retrieve(ctx)
# 模式路由
mode_router = get_mode_router()
mode_decision = mode_router.decide(query, confidence=0.8)
# 回退管理
rollback = get_rollback_manager()
rollback.rollback(trigger="manual", reason="测试回退")
session_history:
- session: "Session #1 (2026-03-10)"
completed:
- "创建 Schema 模型 (app/schemas/retrieval_strategy.py)"
- "创建策略服务层 (app/services/retrieval/strategy_service.py)"
- "创建审计日志服务 (app/services/retrieval/strategy_audit.py)"
- "创建指标埋点服务 (app/services/retrieval/strategy_metrics.py)"
- "创建 API 端点 (app/api/admin/retrieval_strategy.py)"
- "更新 admin __init__.py 注册新路由"
- "更新 main.py 注册新路由"
- "创建单元测试 (tests/test_retrieval_strategy.py)"
- "运行单元测试验证 (46 passed)"
changes:
- "新增: ai-service/app/schemas/retrieval_strategy.py"
- "新增: ai-service/app/services/retrieval/strategy_service.py"
- "新增: ai-service/app/services/retrieval/strategy_audit.py"
- "新增: ai-service/app/services/retrieval/strategy_metrics.py"
- "新增: ai-service/app/api/admin/retrieval_strategy.py"
- "修改: ai-service/app/api/admin/__init__.py"
- "修改: ai-service/app/main.py"
- "新增: ai-service/tests/test_retrieval_strategy.py"
status: "✅ 任务完成"
- session: "Session #2 (2026-03-10) - 检索策略Pipeline实现"
completed:
- "实现策略配置模型 (config.py)"
- "实现 Pipeline 基类 (pipeline_base.py)"
- "实现 DefaultPipeline复用现有逻辑"
- "实现 EnhancedPipeline新端到端流程"
- "实现 MetadataInferenceService"
- "实现 StrategyRouter"
- "实现 ModeRouter"
- "实现 RollbackManager"
- "更新模块导出 (__init__.py)"
- "创建单元测试 (tests/test_retrieval_strategy_v2.py)"
- "运行单元测试验证 (51 passed)"
changes:
- "新增: ai-service/app/services/retrieval/strategy/__init__.py"
- "新增: ai-service/app/services/retrieval/strategy/config.py"
- "新增: ai-service/app/services/retrieval/strategy/pipeline_base.py"
- "新增: ai-service/app/services/retrieval/strategy/default_pipeline.py"
- "新增: ai-service/app/services/retrieval/strategy/enhanced_pipeline.py"
- "新增: ai-service/app/services/retrieval/strategy/metadata_inference.py"
- "新增: ai-service/app/services/retrieval/strategy/strategy_router.py"
- "新增: ai-service/app/services/retrieval/strategy/mode_router.py"
- "新增: ai-service/app/services/retrieval/strategy/rollback_manager.py"
- "新增: ai-service/tests/test_retrieval_strategy_v2.py"
status: "✅ 任务完成"
startup_guide:
- "Step 1: 读取本进度文档(了解当前位置与下一步)"
- "Step 2: 读取 spec_references 中定义的模块规范(了解业务与接口约束)"
- "Step 3: 通过 API 接口测试策略切换功能"
- "Step 4: 运行单元测试验证: pytest tests/test_retrieval_strategy_v2.py -v"
---