From 7027097513d41483e8da6a47b7eaa6277c4780f5 Mon Sep 17 00:00:00 2001 From: MerCry Date: Tue, 10 Mar 2026 20:50:16 +0800 Subject: [PATCH] =?UTF-8?q?[AC-AISVC-RES-01~15]=20feat(retrieval):=20?= =?UTF-8?q?=E5=AE=9E=E7=8E=B0=E6=A3=80=E7=B4=A2=E7=AD=96=E7=95=A5Pipeline?= =?UTF-8?q?=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - 新增策略配置模型 (config.py) - GrayscaleConfig: 灰度发布配置 - ModeRouterConfig: 模式路由配置 - MetadataInferenceConfig: 元数据推断配置 - 新增 Pipeline 实现 - DefaultPipeline: 复用现有 OptimizedRetriever 逻辑 - EnhancedPipeline: Dense + Keyword + RRF 组合检索 - 新增路由器 - StrategyRouter: 策略路由器(default/enhanced) - ModeRouter: 模式路由器(direct/react/auto) - 新增 RollbackManager: 回退与审计管理器 - 新增 MetadataInferenceService: 元数据推断统一入口 - 新增单元测试 (51 passed) --- .../services/retrieval/strategy/__init__.py | 102 +++ .../app/services/retrieval/strategy/config.py | 201 ++++++ .../retrieval/strategy/default_pipeline.py | 117 ++++ .../retrieval/strategy/enhanced_pipeline.py | 364 ++++++++++ .../retrieval/strategy/metadata_inference.py | 136 ++++ .../retrieval/strategy/mode_router.py | 118 ++++ .../retrieval/strategy/pipeline_base.py | 116 ++++ .../retrieval/strategy/retrieval_strategy.py | 301 ++++++++ .../retrieval/strategy/rollback_manager.py | 192 ++++++ .../retrieval/strategy/strategy_router.py | 109 +++ .../tests/test_retrieval_strategy_v2.py | 645 ++++++++++++++++++ .../progress/ai-service-AISVC-RES-progress.md | 167 +++++ 12 files changed, 2568 insertions(+) create mode 100644 ai-service/app/services/retrieval/strategy/__init__.py create mode 100644 ai-service/app/services/retrieval/strategy/config.py create mode 100644 ai-service/app/services/retrieval/strategy/default_pipeline.py create mode 100644 ai-service/app/services/retrieval/strategy/enhanced_pipeline.py create mode 100644 ai-service/app/services/retrieval/strategy/metadata_inference.py create mode 100644 ai-service/app/services/retrieval/strategy/mode_router.py create mode 100644 
ai-service/app/services/retrieval/strategy/pipeline_base.py create mode 100644 ai-service/app/services/retrieval/strategy/retrieval_strategy.py create mode 100644 ai-service/app/services/retrieval/strategy/rollback_manager.py create mode 100644 ai-service/app/services/retrieval/strategy/strategy_router.py create mode 100644 ai-service/tests/test_retrieval_strategy_v2.py create mode 100644 docs/progress/ai-service-AISVC-RES-progress.md diff --git a/ai-service/app/services/retrieval/strategy/__init__.py b/ai-service/app/services/retrieval/strategy/__init__.py new file mode 100644 index 0000000..e3864f0 --- /dev/null +++ b/ai-service/app/services/retrieval/strategy/__init__.py @@ -0,0 +1,102 @@ +""" +Retrieval Strategy Module for AI Service. +[AC-AISVC-RES-01~15] 策略化检索与嵌入模块。 + +核心组件: +- RetrievalStrategyConfig: 策略配置模型 +- BasePipeline: Pipeline 抽象基类 +- DefaultPipeline: 默认策略(复用现有逻辑) +- EnhancedPipeline: 增强策略(新端到端流程) +- MetadataInferenceService: 元数据推断统一入口 +- StrategyRouter: 策略路由器 +- ModeRouter: 模式路由器(direct/react/auto) +- RollbackManager: 回退管理器 +""" + +from app.services.retrieval.strategy.config import ( + FilterMode, + GrayscaleConfig, + HybridRetrievalConfig, + MetadataInferenceConfig, + ModeRouterConfig, + PipelineConfig, + RerankerConfig, + RetrievalStrategyConfig, + RuntimeMode, + StrategyType, + get_strategy_config, + set_strategy_config, +) +from app.services.retrieval.strategy.default_pipeline import ( + DefaultPipeline, + get_default_pipeline, +) +from app.services.retrieval.strategy.enhanced_pipeline import ( + EnhancedPipeline, + get_enhanced_pipeline, +) +from app.services.retrieval.strategy.metadata_inference import ( + InferenceContext, + InferenceResult, + MetadataInferenceService, +) +from app.services.retrieval.strategy.mode_router import ( + ModeDecision, + ModeRouter, + get_mode_router, +) +from app.services.retrieval.strategy.pipeline_base import ( + BasePipeline, + MetadataFilterResult, + PipelineContext, + PipelineResult, +) +from 
app.services.retrieval.strategy.rollback_manager import ( + AuditLog, + RollbackManager, + RollbackResult, + RollbackTrigger, + get_rollback_manager, +) +from app.services.retrieval.strategy.strategy_router import ( + RoutingDecision, + StrategyRouter, + get_strategy_router, +) + +__all__ = [ + "BasePipeline", + "PipelineContext", + "PipelineResult", + "MetadataFilterResult", + "DefaultPipeline", + "get_default_pipeline", + "EnhancedPipeline", + "get_enhanced_pipeline", + "RetrievalStrategyConfig", + "GrayscaleConfig", + "PipelineConfig", + "RerankerConfig", + "ModeRouterConfig", + "HybridRetrievalConfig", + "MetadataInferenceConfig", + "StrategyType", + "FilterMode", + "RuntimeMode", + "get_strategy_config", + "set_strategy_config", + "MetadataInferenceService", + "InferenceContext", + "InferenceResult", + "StrategyRouter", + "RoutingDecision", + "get_strategy_router", + "ModeRouter", + "ModeDecision", + "get_mode_router", + "RollbackManager", + "RollbackResult", + "RollbackTrigger", + "AuditLog", + "get_rollback_manager", +] diff --git a/ai-service/app/services/retrieval/strategy/config.py b/ai-service/app/services/retrieval/strategy/config.py new file mode 100644 index 0000000..0d04fa6 --- /dev/null +++ b/ai-service/app/services/retrieval/strategy/config.py @@ -0,0 +1,201 @@ +""" +Retrieval Strategy Configuration. 
+[AC-AISVC-RES-01~15] 检索策略配置模型。 +""" + +import random +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + + +class StrategyType(str, Enum): + """策略类型。""" + DEFAULT = "default" + ENHANCED = "enhanced" + + +class RuntimeMode(str, Enum): + """运行时模式。""" + DIRECT = "direct" + REACT = "react" + AUTO = "auto" + + +class FilterMode(str, Enum): + """过滤模式。""" + HARD = "hard" + SOFT = "soft" + NONE = "none" + + +@dataclass +class GrayscaleConfig: + """灰度发布配置。【AC-AISVC-RES-03】""" + enabled: bool = False + percentage: float = 0.0 + allowlist: list[str] = field(default_factory=list) + + def should_use_enhanced(self, tenant_id: str, user_id: str | None = None) -> bool: + """判断是否应该使用增强策略。""" + if not self.enabled: + return False + + if tenant_id in self.allowlist or (user_id and user_id in self.allowlist): + return True + + return random.random() * 100 < self.percentage + + +@dataclass +class HybridRetrievalConfig: + """混合检索配置。""" + dense_weight: float = 0.7 + keyword_weight: float = 0.3 + rrf_k: int = 60 + enable_keyword: bool = True + keyword_top_k_multiplier: int = 2 + + +@dataclass +class RerankerConfig: + """重排器配置。【AC-AISVC-RES-08】""" + enabled: bool = False + model: str = "cross-encoder" + top_k_after_rerank: int = 5 + min_score_threshold: float = 0.3 + + +@dataclass +class ModeRouterConfig: + """模式路由配置。【AC-AISVC-RES-09~15】""" + runtime_mode: RuntimeMode = RuntimeMode.DIRECT + react_trigger_confidence_threshold: float = 0.6 + react_trigger_complexity_score: float = 0.5 + react_max_steps: int = 5 + direct_fallback_on_low_confidence: bool = True + short_query_threshold: int = 20 + + def should_use_react( + self, + query: str, + confidence: float | None = None, + complexity_score: float | None = None, + ) -> bool: + """判断是否应该使用 ReAct 模式。【AC-AISVC-RES-11~13】""" + if self.runtime_mode == RuntimeMode.REACT: + return True + if self.runtime_mode == RuntimeMode.DIRECT: + return False + + if len(query) <= self.short_query_threshold and confidence 
and confidence >= self.react_trigger_confidence_threshold: + return False + + if complexity_score and complexity_score >= self.react_trigger_complexity_score: + return True + + if confidence and confidence < self.react_trigger_confidence_threshold: + return True + + return False + + +@dataclass +class MetadataInferenceConfig: + """元数据推断配置。""" + enabled: bool = True + confidence_high_threshold: float = 0.8 + confidence_low_threshold: float = 0.5 + default_filter_mode: FilterMode = FilterMode.SOFT + cache_ttl_seconds: int = 300 + + def determine_filter_mode(self, confidence: float | None) -> FilterMode: + """根据置信度确定过滤模式。""" + if confidence is None: + return FilterMode.NONE + if confidence >= self.confidence_high_threshold: + return FilterMode.HARD + if confidence >= self.confidence_low_threshold: + return FilterMode.SOFT + return FilterMode.NONE + + +@dataclass +class PipelineConfig: + """Pipeline 配置。""" + top_k: int = 5 + score_threshold: float = 0.01 + min_hits: int = 1 + two_stage_enabled: bool = True + two_stage_expand_factor: int = 10 + hybrid: HybridRetrievalConfig = field(default_factory=HybridRetrievalConfig) + + +@dataclass +class RetrievalStrategyConfig: + """检索策略顶层配置。【AC-AISVC-RES-01~15】""" + active_strategy: StrategyType = StrategyType.DEFAULT + grayscale: GrayscaleConfig = field(default_factory=GrayscaleConfig) + pipeline: PipelineConfig = field(default_factory=PipelineConfig) + reranker: RerankerConfig = field(default_factory=RerankerConfig) + mode_router: ModeRouterConfig = field(default_factory=ModeRouterConfig) + metadata_inference: MetadataInferenceConfig = field(default_factory=MetadataInferenceConfig) + performance_thresholds: dict[str, float] = field(default_factory=lambda: { + "max_latency_ms": 2000.0, + "min_success_rate": 0.95, + "max_error_rate": 0.05, + }) + + def is_enhanced_enabled(self, tenant_id: str, user_id: str | None = None) -> bool: + """判断是否启用增强策略。""" + if self.active_strategy == StrategyType.ENHANCED: + return True + return 
self.grayscale.should_use_enhanced(tenant_id, user_id) + + def to_dict(self) -> dict[str, Any]: + """转换为字典。""" + return { + "active_strategy": self.active_strategy.value, + "grayscale": { + "enabled": self.grayscale.enabled, + "percentage": self.grayscale.percentage, + "allowlist": self.grayscale.allowlist, + }, + "pipeline": { + "top_k": self.pipeline.top_k, + "score_threshold": self.pipeline.score_threshold, + "min_hits": self.pipeline.min_hits, + "two_stage_enabled": self.pipeline.two_stage_enabled, + }, + "reranker": { + "enabled": self.reranker.enabled, + "model": self.reranker.model, + "top_k_after_rerank": self.reranker.top_k_after_rerank, + }, + "mode_router": { + "runtime_mode": self.mode_router.runtime_mode.value, + "react_trigger_confidence_threshold": self.mode_router.react_trigger_confidence_threshold, + }, + "metadata_inference": { + "enabled": self.metadata_inference.enabled, + "confidence_high_threshold": self.metadata_inference.confidence_high_threshold, + }, + "performance_thresholds": self.performance_thresholds, + } + + +_global_config: RetrievalStrategyConfig | None = None + + +def get_strategy_config() -> RetrievalStrategyConfig: + """获取全局策略配置。""" + global _global_config + if _global_config is None: + _global_config = RetrievalStrategyConfig() + return _global_config + + +def set_strategy_config(config: RetrievalStrategyConfig) -> None: + """设置全局策略配置。""" + global _global_config + _global_config = config diff --git a/ai-service/app/services/retrieval/strategy/default_pipeline.py b/ai-service/app/services/retrieval/strategy/default_pipeline.py new file mode 100644 index 0000000..e4b8c38 --- /dev/null +++ b/ai-service/app/services/retrieval/strategy/default_pipeline.py @@ -0,0 +1,117 @@ +""" +Default Pipeline. 
logger = logging.getLogger(__name__)


class DefaultPipeline(BasePipeline):
    """Default strategy pipeline. [AC-AISVC-RES-01]

    Thin adapter that forwards every request to the existing
    ``OptimizedRetriever`` so current production behaviour stays unchanged.
    """

    def __init__(
        self,
        config: PipelineConfig | None = None,
        optimized_retriever: OptimizedRetriever | None = None,
    ):
        # The retriever may be injected; otherwise it is resolved on first use.
        self._config = config or PipelineConfig()
        self._optimized_retriever = optimized_retriever

    @property
    def name(self) -> str:
        """Identifier reported in pipeline results."""
        return "default_pipeline"

    @property
    def description(self) -> str:
        """Human-readable description of this pipeline."""
        return "默认检索策略,复用现有 OptimizedRetriever 逻辑。"

    async def _get_retriever(self) -> OptimizedRetriever:
        """Return the injected retriever, resolving and caching it if absent."""
        cached = self._optimized_retriever
        if cached is None:
            cached = await get_optimized_retriever()
            self._optimized_retriever = cached
        return cached

    async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
        """Run the default retrieval flow. [AC-AISVC-RES-01]

        Any exception is logged and converted into an empty result that
        carries the error text and the elapsed time.
        """
        started_at = time.time()

        logger.info(
            f"[DefaultPipeline] Starting retrieval: tenant={ctx.tenant_id}, "
            f"query={ctx.query[:50]}..."
        )

        try:
            retriever = await self._get_retriever()

            filter_dict = ctx.metadata_filter.filter_dict if ctx.metadata_filter else None

            result = await retriever.retrieve(
                RetrievalContext(
                    tenant_id=ctx.tenant_id,
                    query=ctx.query,
                    session_id=ctx.session_id,
                    metadata_filter=filter_dict,
                    kb_ids=ctx.kb_ids,
                )
            )

            elapsed_ms = (time.time() - started_at) * 1000

            logger.info(
                f"[DefaultPipeline] Retrieval completed: hits={len(result.hits)}, "
                f"latency_ms={elapsed_ms:.2f}"
            )

            # Retriever diagnostics are merged last, so they win on key clashes.
            merged_diagnostics = {"retriever": "OptimizedRetriever"}
            merged_diagnostics.update(result.diagnostics or {})

            return PipelineResult(
                retrieval_result=result,
                pipeline_name=self.name,
                metadata_filter_applied=filter_dict is not None,
                latency_ms=elapsed_ms,
                diagnostics=merged_diagnostics,
            )

        except Exception as e:
            elapsed_ms = (time.time() - started_at) * 1000
            logger.error(f"[DefaultPipeline] Retrieval error: {e}", exc_info=True)
            return self._create_empty_result(ctx, error=str(e), latency_ms=elapsed_ms)

    async def health_check(self) -> bool:
        """Report the retriever's health; any failure is logged and yields False."""
        try:
            return await (await self._get_retriever()).health_check()
        except Exception as e:
            logger.error(f"[DefaultPipeline] Health check failed: {e}")
            return False


_default_pipeline: DefaultPipeline | None = None


async def get_default_pipeline() -> DefaultPipeline:
    """Return the module-level DefaultPipeline singleton, creating it lazily."""
    global _default_pipeline
    if _default_pipeline is None:
        _default_pipeline = DefaultPipeline()
    return _default_pipeline
logger = logging.getLogger(__name__)
settings = get_settings()


@dataclass
class RetrievalCandidate:
    """A single retrieval candidate carried between pipeline stages."""

    id: str
    text: str
    score: float
    vector_score: float = 0.0
    keyword_score: float = 0.0
    # ``None`` is accepted for convenience and normalised to a fresh dict below,
    # so instances never share one mutable default.
    metadata: dict[str, Any] | None = None

    def __post_init__(self):
        if self.metadata is None:
            self.metadata = {}


class EnhancedPipeline(BasePipeline):
    """Enhanced strategy pipeline. [AC-AISVC-RES-02]

    End-to-end flow:
    1. Dense vector search
    2. Keyword search
    3. RRF fusion
    4. Optional rerank
    """

    # Compiled once; used to tokenise both the query and candidate texts.
    _TOKEN_RE = re.compile(r"\w+")

    def __init__(
        self,
        config: PipelineConfig | None = None,
        reranker_config: RerankerConfig | None = None,
        qdrant_client: QdrantClient | None = None,
    ):
        self._config = config or PipelineConfig()
        self._reranker_config = reranker_config or RerankerConfig()
        self._qdrant_client = qdrant_client
        self._rrf_combiner = RRFCombiner(k=self._config.hybrid.rrf_k)
        self._embedding_provider: NomicEmbeddingProvider | None = None

    @property
    def name(self) -> str:
        """Identifier reported in pipeline results and hit sources."""
        return "enhanced_pipeline"

    @property
    def description(self) -> str:
        """Human-readable description of this pipeline."""
        return "增强检索策略,支持 Dense + Keyword + RRF 组合检索。"

    async def _get_client(self) -> QdrantClient:
        """Return the injected Qdrant client, resolving and caching it if absent."""
        if self._qdrant_client is None:
            self._qdrant_client = await get_qdrant_client()
        return self._qdrant_client

    async def _get_embedding_provider(self) -> NomicEmbeddingProvider:
        """Return a cached Nomic embedding provider.

        Uses the provider from the embedding config manager when it already is
        a ``NomicEmbeddingProvider``; otherwise builds one from settings.
        """
        if self._embedding_provider is None:
            from app.services.embedding.factory import get_embedding_config_manager
            manager = get_embedding_config_manager()
            provider = await manager.get_provider()
            if isinstance(provider, NomicEmbeddingProvider):
                self._embedding_provider = provider
            else:
                self._embedding_provider = NomicEmbeddingProvider(
                    base_url=settings.ollama_base_url,
                    model=settings.ollama_embedding_model,
                    dimension=settings.qdrant_vector_size,
                )
        return self._embedding_provider

    async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
        """Run the enhanced retrieval flow. [AC-AISVC-RES-02]

        Any exception is logged and converted into an empty result that
        carries the error text and the elapsed time.
        """
        start_time = time.time()

        logger.info(
            f"[EnhancedPipeline] Starting retrieval: tenant={ctx.tenant_id}, "
            f"query={ctx.query[:50]}..."
        )

        try:
            provider = await self._get_embedding_provider()
            embedding_result = await provider.embed_query(ctx.query)

            candidates = await self._hybrid_retrieve(
                tenant_id=ctx.tenant_id,
                query=ctx.query,
                embedding_result=embedding_result,
                metadata_filter=ctx.metadata_filter.filter_dict if ctx.metadata_filter else None,
                kb_ids=ctx.kb_ids,
            )

            # Reranking needs both the global switch and the per-request flag.
            use_reranker = self._reranker_config.enabled and ctx.use_reranker
            if use_reranker:
                candidates = await self._rerank(
                    candidates=candidates,
                    query=ctx.query,
                )

            final_candidates = candidates[:self._config.top_k]

            hits = [
                RetrievalHit(
                    text=c.text,
                    score=c.score,
                    source=self.name,
                    metadata=c.metadata,
                )
                for c in final_candidates
                if c.score >= self._config.score_threshold
            ]

            latency_ms = (time.time() - start_time) * 1000

            logger.info(
                f"[EnhancedPipeline] Retrieval completed: hits={len(hits)}, "
                f"latency_ms={latency_ms:.2f}"
            )

            result = RetrievalResult(
                hits=hits,
                diagnostics={
                    "total_candidates": len(candidates),
                    "after_rerank": use_reranker,
                },
            )

            return PipelineResult(
                retrieval_result=result,
                pipeline_name=self.name,
                used_reranker=use_reranker,
                metadata_filter_applied=ctx.metadata_filter is not None,
                latency_ms=latency_ms,
                diagnostics={
                    "dense_weight": self._config.hybrid.dense_weight,
                    "keyword_weight": self._config.hybrid.keyword_weight,
                },
            )

        except Exception as e:
            latency_ms = (time.time() - start_time) * 1000
            logger.error(f"[EnhancedPipeline] Retrieval error: {e}", exc_info=True)
            return self._create_empty_result(ctx, error=str(e), latency_ms=latency_ms)

    async def _hybrid_retrieve(
        self,
        tenant_id: str,
        query: str,
        embedding_result: Any,
        metadata_filter: dict[str, Any] | None = None,
        kb_ids: list[str] | None = None,
    ) -> list[RetrievalCandidate]:
        """Hybrid retrieval: dense + keyword search fused with RRF.

        Both searches run concurrently; a failed branch is logged and treated
        as an empty result list so the other branch can still contribute.
        """
        client = await self._get_client()
        top_k = self._config.top_k
        expand_factor = self._config.hybrid.keyword_top_k_multiplier

        vector_task = self._dense_search(
            client=client,
            tenant_id=tenant_id,
            embedding=embedding_result.embedding_full,
            top_k=top_k * expand_factor,
            metadata_filter=metadata_filter,
            kb_ids=kb_ids,
        )

        # When keyword search is disabled, substitute an awaitable yielding [].
        keyword_task = self._keyword_search(
            client=client,
            tenant_id=tenant_id,
            query=query,
            top_k=top_k * expand_factor,
            metadata_filter=metadata_filter,
            kb_ids=kb_ids,
        ) if self._config.hybrid.enable_keyword else asyncio.sleep(0, result=[])

        vector_results, keyword_results = await asyncio.gather(
            vector_task, keyword_task, return_exceptions=True
        )

        if isinstance(vector_results, Exception):
            logger.warning(f"[EnhancedPipeline] Dense search failed: {vector_results}")
            vector_results = []

        if isinstance(keyword_results, Exception):
            logger.warning(f"[EnhancedPipeline] Keyword search failed: {keyword_results}")
            keyword_results = []

        combined = self._rrf_combiner.combine(
            vector_results=vector_results,
            bm25_results=keyword_results,
            vector_weight=self._config.hybrid.dense_weight,
            bm25_weight=self._config.hybrid.keyword_weight,
        )

        candidates = []
        for item in combined:
            # A present-but-None payload must not break ``.get`` below.
            payload = item.get("payload") or {}
            candidates.append(RetrievalCandidate(
                id=item.get("id", ""),
                text=payload.get("text", ""),
                score=item.get("score", 0.0),
                vector_score=item.get("vector_score", 0.0),
                keyword_score=item.get("bm25_score", 0.0),
                metadata=payload,
            ))

        return candidates

    async def _dense_search(
        self,
        client: QdrantClient,
        tenant_id: str,
        embedding: list[float],
        top_k: int,
        metadata_filter: dict[str, Any] | None = None,
        kb_ids: list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """Dense vector search against the ``full`` vector; returns [] on error."""
        try:
            results = await client.search(
                tenant_id=tenant_id,
                query_vector=embedding,
                limit=top_k,
                vector_name="full",
                metadata_filter=metadata_filter,
                kb_ids=kb_ids,
            )
            return results
        except Exception as e:
            logger.error(f"[EnhancedPipeline] Dense search error: {e}")
            return []

    async def _keyword_search(
        self,
        client: QdrantClient,
        tenant_id: str,
        query: str,
        top_k: int,
        metadata_filter: dict[str, Any] | None = None,
        kb_ids: list[str] | None = None,
    ) -> list[dict[str, Any]]:
        """Keyword search: Jaccard overlap of word tokens over scrolled points.

        Best-effort: any failure is logged and yields an empty list.
        """
        try:
            qdrant = await client.get_client()
            collection_name = client.get_collection_name(tenant_id)

            query_terms = set(self._TOKEN_RE.findall(query.lower()))

            # NOTE(review): ``metadata_filter`` and ``kb_ids`` are accepted but not
            # applied to this scroll, so keyword hits may come from outside the
            # filtered scope that the dense branch honours. Only the first
            # ``top_k * 3`` scrolled points are scored. Confirm and add a filter.
            results = await qdrant.scroll(
                collection_name=collection_name,
                limit=top_k * 3,
                with_payload=True,
            )

            scored_results = []
            for point in results[0]:
                # Guard before dereferencing: a point may carry no payload.
                payload = point.payload or {}
                text_terms = set(self._TOKEN_RE.findall(payload.get("text", "").lower()))
                overlap = len(query_terms & text_terms)

                if overlap > 0:
                    score = overlap / (len(query_terms) + len(text_terms) - overlap)
                    scored_results.append({
                        "id": str(point.id),
                        "score": score,
                        "payload": payload,
                    })

            scored_results.sort(key=lambda x: x["score"], reverse=True)
            return scored_results[:top_k]

        except Exception as e:
            # Raised from debug to warning so a broken keyword branch is visible.
            logger.warning(f"[EnhancedPipeline] Keyword search failed: {e}")
            return []

    async def _rerank(
        self,
        candidates: list[RetrievalCandidate],
        query: str,
    ) -> list[RetrievalCandidate]:
        """Optional rerank by cosine similarity of query and candidate embeddings.

        Candidates with empty text keep their previous score. On any failure the
        input list is returned with its original scores untouched.
        """
        if not candidates:
            return candidates

        try:
            provider = await self._get_embedding_provider()
            query_embedding = await provider.embed_query(query)

            # Compute every new score first and commit afterwards, so an embedding
            # failure half-way cannot leave a mix of fused and cosine scores.
            rescored: list[tuple[RetrievalCandidate, float]] = []
            for candidate in candidates:
                new_score = candidate.score
                candidate_text = candidate.text[:500]
                if candidate_text:
                    candidate_embedding = await provider.embed(candidate_text)
                    new_score = self._cosine_similarity(
                        query_embedding.embedding_full,
                        candidate_embedding,
                    )
                rescored.append((candidate, new_score))

            reranked = []
            for candidate, new_score in rescored:
                candidate.score = new_score
                if new_score >= self._reranker_config.min_score_threshold:
                    reranked.append(candidate)

            reranked.sort(key=lambda x: x.score, reverse=True)
            return reranked[:self._reranker_config.top_k_after_rerank]

        except Exception as e:
            logger.warning(f"[EnhancedPipeline] Rerank failed: {e}")
            return candidates

    def _cosine_similarity(self, vec1: list[float], vec2: list[float]) -> float:
        """Return the cosine similarity of two vectors, or 0.0 if either is all zeros."""
        a = np.asarray(vec1, dtype=float)
        b = np.asarray(vec2, dtype=float)
        denominator = float(np.linalg.norm(a) * np.linalg.norm(b))
        if denominator == 0.0:
            # A zero vector has no direction; avoid a NaN from 0/0.
            return 0.0
        return float(np.dot(a, b) / denominator)

    async def health_check(self) -> bool:
        """Return True when Qdrant answers a collection listing; False on any error."""
        try:
            client = await self._get_client()
            qdrant = await client.get_client()
            await qdrant.get_collections()
            return True
        except Exception as e:
            logger.error(f"[EnhancedPipeline] Health check failed: {e}")
            return False


_enhanced_pipeline: EnhancedPipeline | None = None


async def get_enhanced_pipeline() -> EnhancedPipeline:
    """Return the module-level EnhancedPipeline singleton, creating it lazily."""
    global _enhanced_pipeline
    if _enhanced_pipeline is None:
        _enhanced_pipeline = EnhancedPipeline()
    return _enhanced_pipeline
logger = logging.getLogger(__name__)


@dataclass
class InferenceContext:
    """Input for one metadata inference run."""

    tenant_id: str
    query: str
    session_id: str | None = None
    user_id: str | None = None
    channel_type: str | None = None
    existing_context: dict[str, Any] = field(default_factory=dict)
    slot_state: Any = None


@dataclass
class InferenceResult:
    """Outcome of one metadata inference run."""

    filter_result: MetadataFilterResult
    inferred_fields: dict[str, Any] = field(default_factory=dict)
    confidence_scores: dict[str, float] = field(default_factory=dict)
    overall_confidence: float | None = None
    inference_source: str = "unknown"


class MetadataInferenceService:
    """Single entry point for metadata inference. [AC-AISVC-RES-04]

    Responsibilities:
    1. Strategy-independent metadata inference.
    2. Choosing hard/soft filter mode from the computed confidence.
    3. Delegating filter construction to the existing MetadataFilterBuilder.
    """

    def __init__(
        self,
        session: AsyncSession,
        config: MetadataInferenceConfig | None = None,
    ):
        self._session = session
        self._config = config or MetadataInferenceConfig()
        # Created on the first ``infer`` call.
        self._filter_builder: MetadataFilterBuilder | None = None

    async def infer(self, ctx: InferenceContext) -> InferenceResult:
        """Build a metadata filter for ``ctx`` and grade it. [AC-AISVC-RES-04]"""
        logger.info(
            f"[MetadataInference] Starting inference: tenant={ctx.tenant_id}, "
            f"query={ctx.query[:50]}..."
        )

        if self._filter_builder is None:
            self._filter_builder = MetadataFilterBuilder(self._session)

        # Work on a copy so the caller's context dict is never mutated.
        merged_context = dict(ctx.existing_context)
        if ctx.slot_state:
            merged_context = await self._merge_slot_state(merged_context, ctx.slot_state)

        built = await self._filter_builder.build_filter(
            tenant_id=ctx.tenant_id,
            context=merged_context,
        )

        score = self._calculate_confidence(built, merged_context)
        mode = self._config.determine_filter_mode(score)

        outcome = MetadataFilterResult(
            filter_dict=built.applied_filter,
            filter_mode=mode,
            confidence=score,
            missing_required_slots=built.missing_required_slots,
            debug_info=built.debug_info,
        )

        logger.info(
            f"[MetadataInference] Inference completed: filter_mode={mode.value}, "
            f"confidence={score}"
        )

        return InferenceResult(
            filter_result=outcome,
            inferred_fields=built.applied_filter,
            overall_confidence=score,
            inference_source="metadata_filter_builder",
        )

    async def _merge_slot_state(
        self, context: dict[str, Any], slot_state: Any
    ) -> dict[str, Any]:
        """Copy filled slots into ``context`` without overriding existing keys."""
        if not hasattr(slot_state, 'filled_slots'):
            return context
        for key, value in slot_state.filled_slots.items():
            context.setdefault(key, value)
        return context

    def _calculate_confidence(
        self, build_result: FilterBuildResult, context: dict[str, Any]
    ) -> float | None:
        """Grade the built filter.

        Returns 0.3 when required slots are missing, ``None`` when no filter
        was applied, 0.5 for an empty context, otherwise a tier based on the
        ratio of applied filter keys to context keys.
        """
        if build_result.missing_required_slots:
            return 0.3
        applied = build_result.applied_filter
        if not applied:
            return None
        if not context:
            return 0.5

        ratio = len(applied) / max(len(context), 1)
        # (minimum ratio, confidence) tiers, checked from highest to lowest.
        for floor, tier_confidence in ((0.8, 0.9), (0.5, 0.7)):
            if ratio >= floor:
                return tier_confidence
        return 0.5
logger = logging.getLogger(__name__)


@dataclass
class ModeDecision:
    """Result of a mode routing decision."""

    mode: RuntimeMode
    reason: str
    confidence: float | None = None
    complexity_score: float | None = None


class ModeRouter:
    """Mode router. [AC-AISVC-RES-09~15]

    Responsibilities:
    1. Select direct/react/auto according to the configured runtime mode.
    2. In auto mode, choose the route from query complexity and confidence.
    3. Decide whether a weak direct result should fall back to react.
    """

    # Fallback triggers applied to a direct result in ``should_fallback_to_react``.
    _FALLBACK_MIN_SCORE = 0.3
    _FALLBACK_MIN_HITS = 2

    def __init__(self, config: ModeRouterConfig | None = None):
        self._config = config or ModeRouterConfig()

    def decide(
        self,
        query: str,
        confidence: float | None = None,
        complexity_score: float | None = None,
    ) -> ModeDecision:
        """Decide which mode to use. [AC-AISVC-RES-09~13]

        ``REACT`` and ``DIRECT`` runtime modes are returned unconditionally.
        In auto mode a missing ``complexity_score`` is estimated from the query.
        """
        if self._config.runtime_mode == RuntimeMode.REACT:
            return ModeDecision(mode=RuntimeMode.REACT, reason="runtime_mode=react")

        if self._config.runtime_mode == RuntimeMode.DIRECT:
            return ModeDecision(mode=RuntimeMode.DIRECT, reason="runtime_mode=direct")

        # Only estimate when the caller supplied nothing; an explicit 0.0 is a
        # real score and must not be replaced by the heuristic.
        if complexity_score is None:
            calculated_complexity = self._calculate_complexity(query)
        else:
            calculated_complexity = complexity_score

        if self._should_use_direct(query, confidence, calculated_complexity):
            return ModeDecision(
                mode=RuntimeMode.DIRECT,
                reason="auto: short_query_high_confidence",
                confidence=confidence,
                complexity_score=calculated_complexity,
            )

        return ModeDecision(
            mode=RuntimeMode.REACT,
            reason="auto: complex_or_low_confidence",
            confidence=confidence,
            complexity_score=calculated_complexity,
        )

    def should_fallback_to_react(self, direct_result: PipelineResult) -> bool:
        """Decide whether to fall back from direct to react. [AC-AISVC-RES-14]

        Fallback is only considered when enabled in the config; it triggers on
        an empty result, a low best score, or too few hits.
        """
        if not self._config.direct_fallback_on_low_confidence:
            return False
        if direct_result.is_empty:
            return True
        retrieval = direct_result.retrieval_result
        if retrieval.max_score < self._FALLBACK_MIN_SCORE:
            return True
        if retrieval.hit_count < self._FALLBACK_MIN_HITS:
            return True
        return False

    def _should_use_direct(
        self, query: str, confidence: float | None, complexity_score: float
    ) -> bool:
        """Return True when auto mode should answer directly.

        ``confidence`` is compared against ``None`` explicitly so that 0.0 is
        handled as a (very low) confidence rather than as "not provided".
        """
        threshold = self._config.react_trigger_confidence_threshold
        has_confidence = confidence is not None

        if (
            len(query) <= self._config.short_query_threshold
            and has_confidence
            and confidence >= threshold
        ):
            return True
        if has_confidence and confidence < threshold:
            return False
        if complexity_score >= self._config.react_trigger_complexity_score:
            return False
        return True

    def _calculate_complexity(self, query: str) -> float:
        """Estimate query complexity in [0, 1] from length and connective words."""
        score = 0.0
        if len(query) > 50:
            score += 0.2
        if len(query) > 100:
            score += 0.2
        # Chinese connectives hinting at multi-condition questions.
        condition_words = ["和", "或", "但是", "如果", "同时", "并且", "或者", "以及"]
        for word in condition_words:
            if word in query:
                score += 0.1
        return min(score, 1.0)

    def get_config(self) -> ModeRouterConfig:
        """Return the active router configuration."""
        return self._config

    def update_config(self, config: ModeRouterConfig) -> None:
        """Replace the active router configuration."""
        self._config = config


_mode_router: ModeRouter | None = None


def get_mode_router() -> ModeRouter:
    """Return the module-level ModeRouter singleton, creating it lazily."""
    global _mode_router
    if _mode_router is None:
        _mode_router = ModeRouter()
    return _mode_router
logger = logging.getLogger(__name__)


@dataclass
class MetadataFilterResult:
    """Metadata filter plus the information used to decide how to apply it."""

    filter_dict: dict[str, Any] = field(default_factory=dict)
    filter_mode: FilterMode = FilterMode.NONE
    confidence: float | None = None
    missing_required_slots: list[dict[str, str]] = field(default_factory=list)
    debug_info: dict[str, Any] = field(default_factory=dict)


@dataclass
class PipelineContext:
    """Execution context handed to a pipeline.

    Wraps the underlying ``RetrievalContext`` and exposes its most used
    fields as read-only properties.
    """

    retrieval_ctx: RetrievalContext
    metadata_filter: MetadataFilterResult | None = None
    use_reranker: bool = False
    use_react: bool = False
    react_iteration: int = 0
    previous_results: list[RetrievalHit] = field(default_factory=list)
    extra: dict[str, Any] = field(default_factory=dict)

    @property
    def tenant_id(self) -> str:
        """Tenant id of the wrapped retrieval context."""
        return self.retrieval_ctx.tenant_id

    @property
    def query(self) -> str:
        """Query text of the wrapped retrieval context."""
        return self.retrieval_ctx.query

    @property
    def session_id(self) -> str | None:
        """Session id of the wrapped retrieval context."""
        return self.retrieval_ctx.session_id

    @property
    def kb_ids(self) -> list[str] | None:
        """Knowledge-base ids of the wrapped retrieval context."""
        return self.retrieval_ctx.kb_ids


@dataclass
class PipelineResult:
    """Result of one pipeline execution."""

    retrieval_result: RetrievalResult
    pipeline_name: str = ""
    used_reranker: bool = False
    used_react: bool = False
    react_iterations: int = 0
    metadata_filter_applied: bool = False
    fallback_triggered: bool = False
    fallback_reason: str | None = None
    latency_ms: float = 0.0
    diagnostics: dict[str, Any] = field(default_factory=dict)

    @property
    def hits(self) -> list[RetrievalHit]:
        """Hits of the wrapped retrieval result."""
        return self.retrieval_result.hits

    @property
    def is_empty(self) -> bool:
        """Whether the wrapped retrieval result reports itself as empty."""
        return self.retrieval_result.is_empty


class BasePipeline(ABC):
    """Abstract base class for pipelines. [AC-AISVC-RES-01, AC-AISVC-RES-02]"""

    @property
    @abstractmethod
    def name(self) -> str:
        """Pipeline name."""

    @property
    @abstractmethod
    def description(self) -> str:
        """Pipeline description."""

    @abstractmethod
    async def retrieve(self, ctx: PipelineContext) -> PipelineResult:
        """Run retrieval for ``ctx``."""

    @abstractmethod
    async def health_check(self) -> bool:
        """Report whether the pipeline is healthy."""

    def _create_empty_result(
        self,
        ctx: PipelineContext,
        error: str | None = None,
        latency_ms: float = 0.0,
    ) -> PipelineResult:
        """Build a result with no hits, recording ``error`` in the diagnostics if given."""
        if error:
            diagnostics = {"error": error}
        else:
            diagnostics = {}
        # The same diagnostics dict is attached to both result objects.
        empty_retrieval = RetrievalResult(hits=[], diagnostics=diagnostics)
        return PipelineResult(
            retrieval_result=empty_retrieval,
            pipeline_name=self.name,
            latency_ms=latency_ms,
            diagnostics=diagnostics,
        )
# File: ai-service/app/services/retrieval/strategy/retrieval_strategy.py
"""
Retrieval Strategy - Unified Entry Point.
[AC-AISVC-RES-01~15] Unified entry point for the retrieval strategy.

Integrates:
- StrategyRouter: strategy routing (default/enhanced)
- ModeRouter: mode routing (direct/react/auto)
- MetadataInferenceService: metadata inference
- RollbackManager: rollback management
"""

import logging
import time
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from typing import Any

from sqlalchemy.ext.asyncio import AsyncSession

from app.services.retrieval.base import RetrievalContext, RetrievalResult
from app.services.retrieval.strategy.config import (
    RetrievalStrategyConfig,
    RuntimeMode,
    StrategyType,
)
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline
from app.services.retrieval.strategy.metadata_inference import (
    InferenceContext,
    MetadataInferenceService,
)
from app.services.retrieval.strategy.mode_router import ModeDecision, ModeRouter
from app.services.retrieval.strategy.pipeline_base import (
    MetadataFilterResult,
    PipelineContext,
    PipelineResult,
)
from app.services.retrieval.strategy.rollback_manager import (
    RollbackManager,
    RollbackTrigger,
)
from app.services.retrieval.strategy.strategy_router import (
    RoutingDecision,
    StrategyRouter,
)

logger = logging.getLogger(__name__)


@dataclass
class RetrievalStrategyResult:
    """Result of one retrieval-strategy execution."""

    retrieval_result: RetrievalResult
    strategy_used: StrategyType
    mode_used: RuntimeMode
    metadata_filter: MetadataFilterResult | None
    latency_ms: float
    diagnostics: dict[str, Any]


class RetrievalStrategy:
    """
    Unified entry point for the retrieval strategy. [AC-AISVC-RES-01~15]

    Integrates all strategy components:
    1. Metadata inference (MetadataInferenceService)
    2. Strategy routing (StrategyRouter)
    3. Mode routing (ModeRouter)
    4. Rollback management (RollbackManager)

    Usage:
    ```python
    strategy = RetrievalStrategy(session)
    result = await strategy.retrieve(
        tenant_id="tenant_1",
        query="user question",
        context={"user_id": "user_1"},
    )
    ```
    """

    def __init__(
        self,
        session: AsyncSession,
        config: RetrievalStrategyConfig | None = None,
    ):
        """Wire up the routers and the rollback manager around one config.

        Args:
            session: Database session handed to the metadata inference service.
            config: Strategy configuration; a default one is created when omitted.
        """
        self._session = session
        self._config = config or RetrievalStrategyConfig()

        # All three components receive the same config object.
        self._strategy_router = StrategyRouter(self._config)
        self._mode_router = ModeRouter(self._config.mode_router)
        self._rollback_manager = RollbackManager(self._config)
        # Created lazily on the first inference call.
        self._metadata_inference: MetadataInferenceService | None = None

    @staticmethod
    def _elapsed_ms(started: float) -> float:
        """Milliseconds elapsed since ``started`` (a ``time.perf_counter()`` value)."""
        return (time.perf_counter() - started) * 1000

    async def retrieve(
        self,
        tenant_id: str,
        query: str,
        context: dict[str, Any] | None = None,
        kb_ids: list[str] | None = None,
        session_id: str | None = None,
        user_id: str | None = None,
        use_reranker: bool = False,
        use_react: bool = False,
    ) -> RetrievalStrategyResult:
        """
        Run the retrieval strategy. [AC-AISVC-RES-01~15]

        Flow:
        1. Metadata inference
        2. Strategy routing
        3. Mode routing
        4. Retrieval
        5. Check whether a rollback is needed

        Args:
            tenant_id: Tenant ID.
            query: Query text.
            context: Context information.
            kb_ids: Knowledge-base ID list.
            session_id: Session ID.
            user_id: User ID.
            use_reranker: Whether to use reranking.
            use_react: Whether to use ReAct mode.

        Returns:
            RetrievalStrategyResult with the retrieval result and diagnostics.
            Exceptions are not propagated: on failure a rollback is requested
            and an empty result carrying the error message is returned.
        """
        # perf_counter is monotonic, so the measured latency cannot be skewed
        # (or go negative) when the wall clock is adjusted mid-request.
        started = time.perf_counter()
        context = context or {}

        logger.info(
            f"[RetrievalStrategy] Starting retrieval: tenant={tenant_id}, "
            f"query={query[:50]}..."
        )

        try:
            metadata_filter = await self._infer_metadata(
                tenant_id=tenant_id,
                query=query,
                context=context,
                session_id=session_id,
                user_id=user_id,
            )

            routing_decision = await self._strategy_router.route(tenant_id, user_id)

            mode_decision = self._mode_router.decide(
                query=query,
                confidence=metadata_filter.confidence if metadata_filter else None,
            )

            retrieval_ctx = RetrievalContext(
                tenant_id=tenant_id,
                query=query,
                session_id=session_id,
                metadata_filter=metadata_filter.filter_dict if metadata_filter else None,
                kb_ids=kb_ids,
            )

            pipeline_ctx = PipelineContext(
                retrieval_ctx=retrieval_ctx,
                metadata_filter=metadata_filter,
                use_reranker=use_reranker or self._config.reranker.enabled,
                use_react=use_react or mode_decision.mode == RuntimeMode.REACT,
            )

            pipeline = routing_decision.pipeline
            pipeline_result = await pipeline.retrieve(pipeline_ctx)

            # A direct-mode run that the mode router judges insufficient is
            # retried once on the same pipeline with ReAct switched on.
            if (
                mode_decision.mode == RuntimeMode.DIRECT
                and self._mode_router.should_fallback_to_react(pipeline_result)
            ):
                logger.info("[RetrievalStrategy] Falling back to react mode")
                pipeline_ctx.use_react = True
                pipeline_result = await pipeline.retrieve(pipeline_ctx)
                mode_decision = ModeDecision(
                    mode=RuntimeMode.REACT,
                    reason="fallback_from_direct",
                )

            latency_ms = self._elapsed_ms(started)

            self._check_performance(latency_ms, tenant_id)

            logger.info(
                f"[RetrievalStrategy] Retrieval completed: strategy={routing_decision.strategy.value}, "
                f"mode={mode_decision.mode.value}, hits={len(pipeline_result.hits)}, "
                f"latency_ms={latency_ms:.2f}"
            )

            return RetrievalStrategyResult(
                retrieval_result=pipeline_result.retrieval_result,
                strategy_used=routing_decision.strategy,
                mode_used=mode_decision.mode,
                metadata_filter=metadata_filter,
                latency_ms=latency_ms,
                diagnostics={
                    "routing_reason": routing_decision.reason,
                    "mode_reason": mode_decision.reason,
                    "grayscale_hit": routing_decision.grayscale_hit,
                    **pipeline_result.diagnostics,
                },
            )

        except Exception as e:
            latency_ms = self._elapsed_ms(started)
            logger.error(
                f"[RetrievalStrategy] Retrieval error: {e}",
                exc_info=True
            )

            self._rollback_manager.rollback(
                trigger=RollbackTrigger.ERROR,
                reason=str(e),
                tenant_id=tenant_id,
            )

            return RetrievalStrategyResult(
                retrieval_result=RetrievalResult(
                    hits=[],
                    diagnostics={"error": str(e)},
                ),
                strategy_used=StrategyType.DEFAULT,
                mode_used=RuntimeMode.DIRECT,
                metadata_filter=None,
                latency_ms=latency_ms,
                diagnostics={"error": str(e)},
            )

    async def _infer_metadata(
        self,
        tenant_id: str,
        query: str,
        context: dict[str, Any],
        session_id: str | None = None,
        user_id: str | None = None,
    ) -> MetadataFilterResult | None:
        """Run metadata inference.

        Best effort: any failure is logged as a warning and None is returned,
        so retrieval continues without a metadata filter.
        """
        try:
            if self._metadata_inference is None:
                self._metadata_inference = MetadataInferenceService(
                    self._session,
                    self._config.metadata_inference,
                )

            inference_ctx = InferenceContext(
                tenant_id=tenant_id,
                query=query,
                session_id=session_id,
                user_id=user_id,
                existing_context=context,
            )

            result = await self._metadata_inference.infer(inference_ctx)
            return result.filter_result

        except Exception as e:
            logger.warning(f"[RetrievalStrategy] Metadata inference failed: {e}")
            return None

    def _check_performance(self, latency_ms: float, tenant_id: str | None) -> None:
        """Hand the request latency to the rollback manager, which may roll back."""
        self._rollback_manager.check_and_rollback(
            metrics={"latency_ms": latency_ms},
            tenant_id=tenant_id,
        )

    def get_config(self) -> RetrievalStrategyConfig:
        """Return the current configuration."""
        return self._config

    def update_config(self, config: RetrievalStrategyConfig) -> None:
        """Replace the configuration and propagate it to every component."""
        self._config = config
        self._strategy_router.update_config(config)
        self._mode_router.update_config(config.mode_router)
        self._rollback_manager.update_config(config)

    async def _probe_pipeline(
        self,
        label: str,
        get_pipeline: Callable[[], Awaitable[Any]],
    ) -> bool:
        """Resolve one pipeline and run its health check; False on any error."""
        try:
            pipeline = await get_pipeline()
            return await pipeline.health_check()
        except Exception as exc:
            # Keep health_check non-raising, but leave a trace of the cause.
            logger.warning(
                "[RetrievalStrategy] Health check failed for %s: %s", label, exc
            )
            return False

    async def health_check(self) -> dict[str, bool]:
        """Health check.

        Returns:
            A dict with the keys ``"default_pipeline"`` and
            ``"enhanced_pipeline"``; a value is False when the pipeline could
            not be obtained or its own health check raised.
        """
        router = self._strategy_router
        return {
            "default_pipeline": await self._probe_pipeline(
                "default_pipeline", router._get_default_pipeline
            ),
            "enhanced_pipeline": await self._probe_pipeline(
                "enhanced_pipeline", router._get_enhanced_pipeline
            ),
        }
async def create_retrieval_strategy(
    session: AsyncSession,
    config: RetrievalStrategyConfig | None = None,
) -> RetrievalStrategy:
    """Create a RetrievalStrategy instance."""
    return RetrievalStrategy(session, config)


# File: ai-service/app/services/retrieval/strategy/rollback_manager.py
#
# Rollback Manager.
# [AC-AISVC-RES-07] Strategy rollback and audit manager.

import logging
from dataclasses import dataclass, field
from datetime import datetime, timezone
from enum import Enum
from typing import Any

from app.services.retrieval.strategy.config import RetrievalStrategyConfig, StrategyType

logger = logging.getLogger(__name__)


def _utc_timestamp() -> str:
    """Current UTC time as a timezone-aware ISO-8601 string."""
    # datetime.utcnow() is deprecated and returns a naive value; the aware
    # form makes the offset explicit in the rendered string.
    return datetime.now(timezone.utc).isoformat()


class RollbackTrigger(str, Enum):
    """Reason that triggered a rollback."""

    MANUAL = "manual"
    ERROR = "error"
    PERFORMANCE = "performance"
    TIMEOUT = "timeout"


@dataclass
class AuditLog:
    """Audit log record."""

    timestamp: str
    action: str
    from_strategy: str
    to_strategy: str
    trigger: str
    reason: str
    tenant_id: str | None = None
    details: dict[str, Any] = field(default_factory=dict)


@dataclass
class RollbackResult:
    """Outcome of a rollback attempt."""

    success: bool
    previous_strategy: StrategyType
    current_strategy: StrategyType
    trigger: RollbackTrigger
    reason: str
    audit_log: AuditLog | None = None


class RollbackManager:
    """
    Strategy rollback manager. [AC-AISVC-RES-07]

    Responsibilities:
    1. Roll back to the default strategy when the active strategy misbehaves
    2. Record audit logs
    3. Support manually triggered rollback
    """

    def __init__(
        self,
        config: RetrievalStrategyConfig | None = None,
        max_audit_logs: int = 1000,
    ):
        """Create a manager.

        Args:
            config: Strategy configuration whose ``active_strategy`` is read
                and rewritten by rollbacks; a default one is created when
                omitted.
            max_audit_logs: Maximum number of audit records kept in memory.
        """
        self._config = config or RetrievalStrategyConfig()
        self._max_audit_logs = max_audit_logs
        self._audit_logs: list[AuditLog] = []
        self._previous_strategy: StrategyType = StrategyType.DEFAULT

    def rollback(
        self,
        trigger: RollbackTrigger,
        reason: str,
        tenant_id: str | None = None,
        details: dict[str, Any] | None = None,
    ) -> RollbackResult:
        """Roll the active strategy back to the default one. [AC-AISVC-RES-07]

        Args:
            trigger: What caused the rollback.
            reason: Human-readable explanation stored in the audit log.
            tenant_id: Tenant on whose behalf the rollback happens, if any.
            details: Extra data stored in the audit log.

        Returns:
            A RollbackResult. ``success`` is False (and nothing is changed or
            audited) when the default strategy is already active.
        """
        previous = self._config.active_strategy
        current = StrategyType.DEFAULT

        if previous == StrategyType.DEFAULT:
            return RollbackResult(
                success=False,
                previous_strategy=previous,
                current_strategy=current,
                trigger=trigger,
                reason="Already on default strategy",
            )

        self._previous_strategy = previous
        # Mutates the config object in place, so every holder of the same
        # config instance sees the change.
        self._config.active_strategy = current

        audit_log = AuditLog(
            timestamp=_utc_timestamp(),
            action="rollback",
            from_strategy=previous.value,
            to_strategy=current.value,
            trigger=trigger.value,
            reason=reason,
            tenant_id=tenant_id,
            details=details or {},
        )

        self._add_audit_log(audit_log)

        logger.info(
            f"[RollbackManager] Rollback executed: from={previous.value}, "
            f"to={current.value}, trigger={trigger.value}"
        )

        return RollbackResult(
            success=True,
            previous_strategy=previous,
            current_strategy=current,
            trigger=trigger,
            reason=reason,
            audit_log=audit_log,
        )

    def check_and_rollback(
        self, metrics: dict[str, float], tenant_id: str | None = None
    ) -> RollbackResult | None:
        """Check performance metrics and roll back automatically. [AC-AISVC-RES-08]

        Args:
            metrics: May contain ``"latency_ms"`` and ``"error_rate"``; missing
                entries count as 0.
            tenant_id: Tenant recorded in the audit log when a rollback fires.

        Returns:
            The rollback result when a threshold is exceeded (latency is
            checked first), otherwise None.
        """
        thresholds = self._config.performance_thresholds

        latency = metrics.get("latency_ms", 0)
        if latency > thresholds.get("max_latency_ms", 2000):
            return self.rollback(
                trigger=RollbackTrigger.PERFORMANCE,
                reason=f"Latency {latency}ms exceeds threshold",
                tenant_id=tenant_id,
            )

        error_rate = metrics.get("error_rate", 0)
        if error_rate > thresholds.get("max_error_rate", 0.05):
            return self.rollback(
                trigger=RollbackTrigger.ERROR,
                reason=f"Error rate {error_rate} exceeds threshold",
                tenant_id=tenant_id,
            )

        return None

    def _add_audit_log(self, log: AuditLog) -> None:
        """Append ``log`` and drop the oldest records beyond the cap."""
        self._audit_logs.append(log)
        # Trim by overflow count rather than a negative slice: ``[-0:]`` is
        # the whole list, so a cap of 0 would otherwise never trim anything.
        overflow = len(self._audit_logs) - self._max_audit_logs
        if overflow > 0:
            del self._audit_logs[:overflow]

    def record_audit(
        self,
        action: str,
        details: dict[str, Any],
        tenant_id: str | None = None,
    ) -> AuditLog:
        """Record an audit entry that does not change the strategy.

        Args:
            action: Name of the audited action.
            details: Extra data; its ``"reason"`` entry (if any) becomes the
                record's reason.
            tenant_id: Tenant the action relates to, if any.

        Returns:
            The stored AuditLog.
        """
        active = self._config.active_strategy.value
        audit_log = AuditLog(
            timestamp=_utc_timestamp(),
            action=action,
            from_strategy=active,
            to_strategy=active,
            trigger="n/a",
            reason=details.get("reason", ""),
            tenant_id=tenant_id,
            details=details,
        )

        self._add_audit_log(audit_log)

        logger.info(
            f"[RollbackManager] Audit recorded: action={action}, "
            f"strategy={self._config.active_strategy.value}"
        )

        return audit_log

    def get_audit_logs(self, limit: int = 100) -> list[AuditLog]:
        """Return the most recent ``limit`` audit records, oldest first.

        A non-positive ``limit`` yields an empty list (a plain ``[-limit:]``
        slice would return every record for ``limit == 0``).
        """
        if limit <= 0:
            return []
        return self._audit_logs[-limit:]

    def get_config(self) -> RetrievalStrategyConfig:
        """Return the configuration currently in use."""
        return self._config

    def update_config(self, config: RetrievalStrategyConfig) -> None:
        """Replace the configuration."""
        self._config = config


_rollback_manager: RollbackManager | None = None


def get_rollback_manager() -> RollbackManager:
    """Return the module-wide RollbackManager, creating it on first use."""
    global _rollback_manager
    if _rollback_manager is None:
        _rollback_manager = RollbackManager()
    return _rollback_manager
# File: ai-service/app/services/retrieval/strategy/strategy_router.py
"""
Strategy Router.
[AC-AISVC-RES-01~03] Strategy router.
"""

import logging
from dataclasses import dataclass
from typing import Any

from app.services.retrieval.strategy.config import (
    RetrievalStrategyConfig,
    StrategyType,
)
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline, get_default_pipeline
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline, get_enhanced_pipeline
from app.services.retrieval.strategy.pipeline_base import BasePipeline

logger = logging.getLogger(__name__)


@dataclass
class RoutingDecision:
    """Outcome of a routing decision."""

    strategy: StrategyType
    pipeline: BasePipeline
    reason: str
    grayscale_hit: bool = False


class StrategyRouter:
    """
    Strategy router. [AC-AISVC-RES-01~03]

    Responsibilities:
    1. Pick the default or the enhanced strategy based on configuration
    2. Support grayscale rollout (percentage/allowlist)
    3. Leave in-flight default-strategy requests unaffected
    """

    def __init__(
        self,
        config: RetrievalStrategyConfig | None = None,
        default_pipeline: DefaultPipeline | None = None,
        enhanced_pipeline: EnhancedPipeline | None = None,
    ):
        """Store the config and any pre-built pipelines.

        Pipelines left as None are resolved lazily on first use.
        """
        self._config = config or RetrievalStrategyConfig()
        self._default_pipeline = default_pipeline
        self._enhanced_pipeline = enhanced_pipeline

    async def route(
        self, tenant_id: str, user_id: str | None = None
    ) -> RoutingDecision:
        """Route to the appropriate strategy. [AC-AISVC-RES-01~03]

        The globally active strategy is consulted first; the grayscale rule is
        only evaluated when the active strategy is not the enhanced one.
        """
        cfg = self._config

        if cfg.active_strategy == StrategyType.ENHANCED:
            why, via_grayscale = "active_strategy=enhanced", False
        elif cfg.grayscale.should_use_enhanced(tenant_id, user_id):
            why, via_grayscale = "grayscale_hit", True
        else:
            fallback = await self._get_default_pipeline()
            return RoutingDecision(
                strategy=StrategyType.DEFAULT,
                pipeline=fallback,
                reason="default_strategy",
            )

        enhanced = await self._get_enhanced_pipeline()
        return RoutingDecision(
            strategy=StrategyType.ENHANCED,
            pipeline=enhanced,
            reason=why,
            grayscale_hit=via_grayscale,
        )

    async def _get_default_pipeline(self) -> DefaultPipeline:
        """Return the default pipeline, resolving and caching it on first use."""
        cached = self._default_pipeline
        if cached is None:
            cached = self._default_pipeline = await get_default_pipeline()
        return cached

    async def _get_enhanced_pipeline(self) -> EnhancedPipeline:
        """Return the enhanced pipeline, resolving and caching it on first use."""
        cached = self._enhanced_pipeline
        if cached is None:
            cached = self._enhanced_pipeline = await get_enhanced_pipeline()
        return cached

    def get_config(self) -> RetrievalStrategyConfig:
        """Return the configuration currently in use."""
        return self._config

    def update_config(self, config: RetrievalStrategyConfig) -> None:
        """Replace the configuration and log the newly active strategy."""
        self._config = config
        logger.info(f"[StrategyRouter] Config updated: strategy={config.active_strategy.value}")


_strategy_router: StrategyRouter | None = None


def get_strategy_router() -> StrategyRouter:
    """Return the global strategy router, creating it on first use."""
    global _strategy_router
    if _strategy_router is not None:
        return _strategy_router
    _strategy_router = StrategyRouter()
    return _strategy_router


def set_strategy_router(router: StrategyRouter) -> None:
    """Install ``router`` as the global strategy router instance."""
    global _strategy_router
    _strategy_router = router
# File: ai-service/tests/test_retrieval_strategy_v2.py
"""
Unit tests for Retrieval Strategy Module.
[AC-AISVC-RES-01~15] Tests for strategy config, pipelines, routers, and rollback.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from dataclasses import asdict

from app.services.retrieval.strategy.config import (
    FilterMode,
    GrayscaleConfig,
    HybridRetrievalConfig,
    MetadataInferenceConfig,
    ModeRouterConfig,
    PipelineConfig,
    RerankerConfig,
    RetrievalStrategyConfig,
    RuntimeMode,
    StrategyType,
    get_strategy_config,
    set_strategy_config,
)
from app.services.retrieval.strategy.pipeline_base import (
    BasePipeline,
    MetadataFilterResult,
    PipelineContext,
    PipelineResult,
)
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline
from app.services.retrieval.strategy.strategy_router import (
    RoutingDecision,
    StrategyRouter,
    get_strategy_router,
)
from app.services.retrieval.strategy.mode_router import (
    ModeDecision,
    ModeRouter,
    get_mode_router,
)
from app.services.retrieval.strategy.rollback_manager import (
    AuditLog,
    RollbackManager,
    RollbackResult,
    RollbackTrigger,
    get_rollback_manager,
)
from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult


class TestStrategyConfig:
    """[AC-AISVC-RES-01~15] Tests for strategy configuration models."""

    def test_strategy_type_enum(self):
        """[AC-AISVC-RES-01] Strategy type should have default and enhanced values."""
        assert StrategyType.DEFAULT.value == "default"
        assert StrategyType.ENHANCED.value == "enhanced"

    def test_runtime_mode_enum(self):
        """[AC-AISVC-RES-09] Runtime mode should have direct, react, and auto values."""
        assert RuntimeMode.DIRECT.value == "direct"
        assert RuntimeMode.REACT.value == "react"
        assert RuntimeMode.AUTO.value == "auto"

    def test_filter_mode_enum(self):
        """[AC-AISVC-RES-04] Filter mode should have hard, soft, and none values."""
        assert FilterMode.HARD.value == "hard"
        assert FilterMode.SOFT.value == "soft"
        assert FilterMode.NONE.value == "none"

    def test_grayscale_config_default(self):
        """[AC-AISVC-RES-03] Default grayscale config should be disabled."""
        config = GrayscaleConfig()
        assert config.enabled is False
        assert config.percentage == 0.0
        assert config.allowlist == []

    def test_grayscale_config_should_use_enhanced_disabled(self):
        """[AC-AISVC-RES-03] Should not use enhanced when grayscale disabled."""
        config = GrayscaleConfig(enabled=False, percentage=50.0)
        assert config.should_use_enhanced("tenant_a") is False

    def test_grayscale_config_should_use_enhanced_allowlist(self):
        """[AC-AISVC-RES-03] Should use enhanced for tenants in allowlist."""
        config = GrayscaleConfig(enabled=True, allowlist=["tenant_a", "tenant_b"])
        assert config.should_use_enhanced("tenant_a") is True
        assert config.should_use_enhanced("tenant_b") is True
        assert config.should_use_enhanced("tenant_c") is False

    def test_grayscale_config_should_use_enhanced_percentage(self):
        """[AC-AISVC-RES-03] Should use enhanced based on percentage."""
        config = GrayscaleConfig(enabled=True, percentage=100.0)
        assert config.should_use_enhanced("any_tenant") is True

        config = GrayscaleConfig(enabled=True, percentage=0.0)
        assert config.should_use_enhanced("any_tenant") is False

    def test_reranker_config_default(self):
        """[AC-AISVC-RES-08] Default reranker config should be disabled."""
        config = RerankerConfig()
        assert config.enabled is False
        assert config.model == "cross-encoder"
        assert config.top_k_after_rerank == 5

    def test_mode_router_config_default(self):
        """[AC-AISVC-RES-09] Default mode router config should be direct."""
        config = ModeRouterConfig()
        assert config.runtime_mode == RuntimeMode.DIRECT
        assert config.react_trigger_confidence_threshold == 0.6
        assert config.react_max_steps == 5

    def test_mode_router_config_should_use_react_always(self):
        """[AC-AISVC-RES-10] React mode should always use react."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.REACT)
        assert config.should_use_react("any query") is True

    def test_mode_router_config_should_use_react_never(self):
        """[AC-AISVC-RES-09] Direct mode should never use react."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.DIRECT)
        assert config.should_use_react("any query") is False

    def test_mode_router_config_auto_short_query_high_confidence(self):
        """[AC-AISVC-RES-12] Auto mode with short query and high confidence should use direct."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
        assert config.should_use_react("短问题", confidence=0.8) is False

    def test_mode_router_config_auto_low_confidence(self):
        """[AC-AISVC-RES-13] Auto mode with low confidence should use react."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
        assert config.should_use_react("any query", confidence=0.3) is True

    def test_metadata_inference_config_determine_filter_mode(self):
        """[AC-AISVC-RES-04] Should determine filter mode based on confidence."""
        config = MetadataInferenceConfig()

        assert config.determine_filter_mode(0.9) == FilterMode.HARD
        assert config.determine_filter_mode(0.6) == FilterMode.SOFT
        assert config.determine_filter_mode(0.3) == FilterMode.NONE
        assert config.determine_filter_mode(None) == FilterMode.NONE

    def test_pipeline_config_default(self):
        """[AC-AISVC-RES-01] Default pipeline config should have sensible defaults."""
        config = PipelineConfig()
        assert config.top_k == 5
        assert config.score_threshold == 0.01
        assert config.two_stage_enabled is True

    def test_retrieval_strategy_config_default(self):
        """[AC-AISVC-RES-01] Default strategy config should use default strategy."""
        config = RetrievalStrategyConfig()
        assert config.active_strategy == StrategyType.DEFAULT
        assert config.grayscale.enabled is False
        assert config.mode_router.runtime_mode == RuntimeMode.DIRECT

    def test_retrieval_strategy_config_is_enhanced_enabled(self):
        """[AC-AISVC-RES-02] Should check if enhanced is enabled."""
        config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        assert config.is_enhanced_enabled("tenant_a") is True

        config = RetrievalStrategyConfig(
            active_strategy=StrategyType.DEFAULT,
            grayscale=GrayscaleConfig(enabled=True, allowlist=["tenant_a"]),
        )
        assert config.is_enhanced_enabled("tenant_a") is True
        assert config.is_enhanced_enabled("tenant_b") is False

    def test_retrieval_strategy_config_to_dict(self):
        """[AC-AISVC-RES-01] Should convert config to dictionary."""
        config = RetrievalStrategyConfig()
        d = config.to_dict()

        assert d["active_strategy"] == "default"
        assert "grayscale" in d
        assert "pipeline" in d
        assert "reranker" in d
        assert "mode_router" in d

    def test_global_config_functions(self):
        """[AC-AISVC-RES-01] Should get and set global config."""
        config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        set_strategy_config(config)
        try:
            retrieved = get_strategy_config()
            assert retrieved.active_strategy == StrategyType.ENHANCED
        finally:
            # Reset even when the assertion fails, so the process-wide config
            # never stays on ENHANCED for the tests that run afterwards.
            set_strategy_config(RetrievalStrategyConfig())


class TestPipelineBase:
    """[AC-AISVC-RES-01~02] Tests for pipeline base classes."""

    def test_metadata_filter_result_default(self):
        """[AC-AISVC-RES-04] Default metadata filter result should be empty."""
        result = MetadataFilterResult()
        assert result.filter_dict == {}
        assert result.filter_mode == FilterMode.NONE
        assert result.confidence is None

    def test_pipeline_context_properties(self):
        """[AC-AISVC-RES-01] Pipeline context should expose retrieval context properties."""
        retrieval_ctx = RetrievalContext(
            tenant_id="tenant_1",
            query="test query",
            session_id="session_1",
            kb_ids=["kb_1"],
        )
        pipeline_ctx = PipelineContext(retrieval_ctx=retrieval_ctx)

        assert pipeline_ctx.tenant_id == "tenant_1"
        assert pipeline_ctx.query == "test query"
        assert pipeline_ctx.session_id == "session_1"
        assert pipeline_ctx.kb_ids == ["kb_1"]

    def test_pipeline_result_properties(self):
        """[AC-AISVC-RES-01] Pipeline result should expose retrieval result properties."""
        hits = [
            RetrievalHit(text="hit 1", score=0.9, source="test", metadata={}),
            RetrievalHit(text="hit 2", score=0.8, source="test", metadata={}),
        ]
        retrieval_result = RetrievalResult(hits=hits)
        pipeline_result = PipelineResult(
            retrieval_result=retrieval_result,
            pipeline_name="test_pipeline",
        )

        assert pipeline_result.hits == hits
        assert pipeline_result.is_empty is False
        assert pipeline_result.pipeline_name == "test_pipeline"

    def test_pipeline_result_is_empty(self):
        """[AC-AISVC-RES-01] Pipeline result should detect empty results."""
        pipeline_result = PipelineResult(
            retrieval_result=RetrievalResult(hits=[]),
        )
        assert pipeline_result.is_empty is True
class TestDefaultPipeline:
    """[AC-AISVC-RES-01] Tests for default pipeline."""

    @pytest.fixture
    def mock_retriever(self):
        """Stub of the optimized retriever that always yields a single hit."""
        canned = RetrievalResult(
            hits=[
                RetrievalHit(text="result 1", score=0.9, source="default", metadata={}),
            ],
            diagnostics={"test": True},
        )
        stub = AsyncMock()
        stub.retrieve = AsyncMock(return_value=canned)
        stub.health_check = AsyncMock(return_value=True)
        stub._two_stage_enabled = True
        stub._hybrid_enabled = True
        return stub

    @pytest.fixture
    def pipeline(self, mock_retriever):
        """Default pipeline wired to the stub retriever."""
        return DefaultPipeline(optimized_retriever=mock_retriever)

    def test_pipeline_name(self, pipeline):
        """[AC-AISVC-RES-01] The pipeline reports its fixed name."""
        assert pipeline.name == "default_pipeline"

    def test_pipeline_description(self, pipeline):
        """[AC-AISVC-RES-01] The description mentions the default strategy."""
        assert "默认" in pipeline.description

    @pytest.mark.asyncio
    async def test_retrieve(self, pipeline, mock_retriever):
        """[AC-AISVC-RES-01] Retrieval is delegated to the optimized retriever."""
        ctx = PipelineContext(
            retrieval_ctx=RetrievalContext(
                tenant_id="tenant_1",
                query="test query",
            )
        )

        outcome = await pipeline.retrieve(ctx)

        assert outcome.pipeline_name == "default_pipeline"
        assert len(outcome.hits) == 1
        assert outcome.diagnostics["retriever"] == "OptimizedRetriever"
        mock_retriever.retrieve.assert_called_once()

    @pytest.mark.asyncio
    async def test_retrieve_with_metadata_filter(self, pipeline, mock_retriever):
        """[AC-AISVC-RES-04] The metadata filter reaches the retriever."""
        ctx = PipelineContext(
            retrieval_ctx=RetrievalContext(
                tenant_id="tenant_1",
                query="test query",
            ),
            metadata_filter=MetadataFilterResult(
                filter_dict={"grade": "初一"},
                filter_mode=FilterMode.HARD,
            ),
        )

        outcome = await pipeline.retrieve(ctx)

        assert outcome.metadata_filter_applied is True
        # First positional argument of the retriever call is the context.
        forwarded_ctx = mock_retriever.retrieve.call_args[0][0]
        assert forwarded_ctx.metadata_filter == {"grade": "初一"}

    @pytest.mark.asyncio
    async def test_health_check(self, pipeline, mock_retriever):
        """[AC-AISVC-RES-01] Health check is delegated to the retriever."""
        healthy = await pipeline.health_check()
        assert healthy is True
        mock_retriever.health_check.assert_called_once()
class TestEnhancedPipeline:
    """[AC-AISVC-RES-02] Tests for enhanced pipeline."""

    @pytest.fixture
    def mock_qdrant_client(self):
        """Stub Qdrant client whose search returns one scored point."""
        stub = AsyncMock()
        stub.search = AsyncMock(return_value=[
            {"id": "1", "score": 0.9, "payload": {"text": "result 1"}},
        ])
        stub.get_client = AsyncMock()
        return stub

    @pytest.fixture
    def mock_embedding_provider(self):
        """Stub embedding provider returning a constant 768-dim vector."""
        stub = AsyncMock()
        stub.embed_query = AsyncMock(
            return_value=MagicMock(embedding_full=[0.1] * 768)
        )
        stub.embed = AsyncMock(return_value=[0.1] * 768)
        return stub

    @pytest.fixture
    def pipeline(self, mock_qdrant_client, mock_embedding_provider):
        """Enhanced pipeline wired to the stub client and provider."""
        enhanced = EnhancedPipeline(qdrant_client=mock_qdrant_client)
        enhanced._embedding_provider = mock_embedding_provider
        return enhanced

    def test_pipeline_name(self, pipeline):
        """[AC-AISVC-RES-02] The pipeline reports its fixed name."""
        assert pipeline.name == "enhanced_pipeline"

    def test_pipeline_description(self, pipeline):
        """[AC-AISVC-RES-02] The description mentions the enhanced strategy."""
        assert "增强" in pipeline.description

    @pytest.mark.asyncio
    async def test_retrieve_basic(self, pipeline):
        """[AC-AISVC-RES-02] A basic retrieval completes and is labelled."""
        ctx = PipelineContext(
            retrieval_ctx=RetrievalContext(
                tenant_id="tenant_1",
                query="test query",
            )
        )

        outcome = await pipeline.retrieve(ctx)

        assert outcome.pipeline_name == "enhanced_pipeline"
        assert outcome.diagnostics is not None
class TestStrategyRouter:
    """[AC-AISVC-RES-01~03] Tests for strategy router."""

    @pytest.fixture
    def mock_default_pipeline(self):
        """Create a mock default pipeline."""
        pipeline = AsyncMock(spec=DefaultPipeline)
        pipeline.name = "default_pipeline"
        pipeline.retrieve = AsyncMock(return_value=PipelineResult(
            retrieval_result=RetrievalResult(hits=[]),
            pipeline_name="default_pipeline",
        ))
        return pipeline

    @pytest.fixture
    def mock_enhanced_pipeline(self):
        """Create a mock enhanced pipeline."""
        pipeline = AsyncMock(spec=EnhancedPipeline)
        pipeline.name = "enhanced_pipeline"
        pipeline.retrieve = AsyncMock(return_value=PipelineResult(
            retrieval_result=RetrievalResult(hits=[]),
            pipeline_name="enhanced_pipeline",
        ))
        return pipeline

    @pytest.fixture
    def router(self, mock_default_pipeline, mock_enhanced_pipeline):
        """Create a strategy router with mock pipelines."""
        config = RetrievalStrategyConfig()
        return StrategyRouter(
            config=config,
            default_pipeline=mock_default_pipeline,
            enhanced_pipeline=mock_enhanced_pipeline,
        )

    # The routing tests are coroutines driven by pytest-asyncio instead of
    # asyncio.get_event_loop().run_until_complete(): fetching an implicit
    # event loop is deprecated and fails when no current loop exists.

    @pytest.mark.asyncio
    async def test_route_default_strategy(self, router):
        """[AC-AISVC-RES-01] Should route to default strategy by default."""
        decision = await router.route("tenant_1")

        assert decision.strategy == StrategyType.DEFAULT
        assert decision.reason == "default_strategy"

    @pytest.mark.asyncio
    async def test_route_enhanced_strategy(self, mock_default_pipeline, mock_enhanced_pipeline):
        """[AC-AISVC-RES-02] Should route to enhanced strategy when configured."""
        config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        router = StrategyRouter(
            config=config,
            default_pipeline=mock_default_pipeline,
            enhanced_pipeline=mock_enhanced_pipeline,
        )

        decision = await router.route("tenant_1")

        assert decision.strategy == StrategyType.ENHANCED
        assert decision.reason == "active_strategy=enhanced"

    @pytest.mark.asyncio
    async def test_route_grayscale_allowlist(self, mock_default_pipeline, mock_enhanced_pipeline):
        """[AC-AISVC-RES-03] Should route to enhanced for allowlist tenants."""
        config = RetrievalStrategyConfig(
            active_strategy=StrategyType.DEFAULT,
            grayscale=GrayscaleConfig(enabled=True, allowlist=["tenant_a"]),
        )
        router = StrategyRouter(
            config=config,
            default_pipeline=mock_default_pipeline,
            enhanced_pipeline=mock_enhanced_pipeline,
        )

        decision = await router.route("tenant_a")
        assert decision.strategy == StrategyType.ENHANCED
        assert decision.grayscale_hit is True

        decision = await router.route("tenant_b")
        assert decision.strategy == StrategyType.DEFAULT

    def test_update_config(self, router):
        """[AC-AISVC-RES-02] Should update config."""
        new_config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        router.update_config(new_config)

        assert router.get_config().active_strategy == StrategyType.ENHANCED
router.update_config(new_config) + + assert router.get_config().active_strategy == StrategyType.ENHANCED + + +class TestModeRouter: + """[AC-AISVC-RES-09~15] Tests for mode router.""" + + @pytest.fixture + def router(self): + """Create a mode router.""" + return ModeRouter() + + def test_decide_react_mode(self): + """[AC-AISVC-RES-10] Should decide react when configured.""" + config = ModeRouterConfig(runtime_mode=RuntimeMode.REACT) + router = ModeRouter(config) + + decision = router.decide("any query") + + assert decision.mode == RuntimeMode.REACT + assert decision.reason == "runtime_mode=react" + + def test_decide_direct_mode(self, router): + """[AC-AISVC-RES-09] Should decide direct when configured.""" + decision = router.decide("any query") + + assert decision.mode == RuntimeMode.DIRECT + assert decision.reason == "runtime_mode=direct" + + def test_decide_auto_short_query_high_confidence(self): + """[AC-AISVC-RES-12] Auto with short query and high confidence should use direct.""" + config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO) + router = ModeRouter(config) + + decision = router.decide("短问题", confidence=0.8) + + assert decision.mode == RuntimeMode.DIRECT + + def test_decide_auto_low_confidence(self): + """[AC-AISVC-RES-13] Auto with low confidence should use react.""" + config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO) + router = ModeRouter(config) + + decision = router.decide("any query", confidence=0.3) + + assert decision.mode == RuntimeMode.REACT + + def test_should_fallback_to_react_empty_results(self, router): + """[AC-AISVC-RES-14] Should fallback to react on empty results.""" + result = PipelineResult(retrieval_result=RetrievalResult(hits=[])) + + assert router.should_fallback_to_react(result) is True + + def test_should_fallback_to_react_low_score(self, router): + """[AC-AISVC-RES-14] Should fallback to react on low score.""" + result = PipelineResult( + retrieval_result=RetrievalResult( + hits=[RetrievalHit(text="test", score=0.1, 
source="test", metadata={})], + ), + ) + + assert router.should_fallback_to_react(result) is True + + def test_should_not_fallback_to_react_disabled(self): + """[AC-AISVC-RES-14] Should not fallback when disabled.""" + config = ModeRouterConfig(direct_fallback_on_low_confidence=False) + router = ModeRouter(config) + + result = PipelineResult(retrieval_result=RetrievalResult(hits=[])) + + assert router.should_fallback_to_react(result) is False + + +class TestRollbackManager: + """[AC-AISVC-RES-07] Tests for rollback manager.""" + + @pytest.fixture + def manager(self): + """Create a rollback manager.""" + return RollbackManager() + + def test_rollback_from_enhanced(self, manager): + """[AC-AISVC-RES-07] Should rollback from enhanced to default.""" + config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED) + manager.update_config(config) + + result = manager.rollback( + trigger=RollbackTrigger.MANUAL, + reason="Testing rollback", + ) + + assert result.success is True + assert result.previous_strategy == StrategyType.ENHANCED + assert result.current_strategy == StrategyType.DEFAULT + assert result.audit_log is not None + + def test_rollback_already_default(self, manager): + """[AC-AISVC-RES-07] Should not rollback when already on default.""" + result = manager.rollback( + trigger=RollbackTrigger.MANUAL, + reason="Testing rollback", + ) + + assert result.success is False + assert result.reason == "Already on default strategy" + + def test_check_and_rollback_latency(self, manager): + """[AC-AISVC-RES-08] Should rollback on high latency.""" + config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED) + manager.update_config(config) + + result = manager.check_and_rollback( + metrics={"latency_ms": 3000.0}, + tenant_id="tenant_1", + ) + + assert result is not None + assert result.trigger == RollbackTrigger.PERFORMANCE + + def test_check_and_rollback_error_rate(self, manager): + """[AC-AISVC-RES-08] Should rollback on high error rate.""" + config 
= RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED) + manager.update_config(config) + + result = manager.check_and_rollback( + metrics={"error_rate": 0.1}, + tenant_id="tenant_1", + ) + + assert result is not None + assert result.trigger == RollbackTrigger.ERROR + + def test_check_and_rollback_ok(self, manager): + """[AC-AISVC-RES-08] Should not rollback when metrics are ok.""" + config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED) + manager.update_config(config) + + result = manager.check_and_rollback( + metrics={"latency_ms": 100.0, "error_rate": 0.01}, + tenant_id="tenant_1", + ) + + assert result is None + + def test_get_audit_logs(self, manager): + """[AC-AISVC-RES-07] Should get audit logs.""" + config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED) + manager.update_config(config) + manager.rollback(trigger=RollbackTrigger.MANUAL, reason="Test") + + logs = manager.get_audit_logs() + + assert len(logs) == 1 + assert logs[0].action == "rollback" + + def test_record_audit(self, manager): + """[AC-AISVC-RES-07] Should record audit log.""" + log = manager.record_audit( + action="test_action", + details={"reason": "Testing"}, + tenant_id="tenant_1", + ) + + assert log.action == "test_action" + assert log.tenant_id == "tenant_1" + + +class TestSingletonInstances: + """Tests for singleton instance getters.""" + + def test_get_mode_router_singleton(self, monkeypatch): + """Should return same mode router instance.""" + import app.services.retrieval.strategy.mode_router as module + + # Clear the cached singleton for this test only; monkeypatch restores it on teardown. + monkeypatch.setattr(module, "_mode_router", None) + + router1 = get_mode_router() + router2 = get_mode_router() + + assert router1 is router2 + + def test_get_rollback_manager_singleton(self, monkeypatch): + """Should return same rollback manager instance.""" + import app.services.retrieval.strategy.rollback_manager as module + + # Clear the cached singleton for this test only; monkeypatch restores it on teardown. + 
monkeypatch.setattr(module, "_rollback_manager", None) + + manager1 = get_rollback_manager() + manager2 = get_rollback_manager() + + assert manager1 is manager2 diff --git a/docs/progress/ai-service-AISVC-RES-progress.md b/docs/progress/ai-service-AISVC-RES-progress.md new file mode 100644 index 0000000..f478b29 --- /dev/null +++ b/docs/progress/ai-service-AISVC-RES-progress.md @@ -0,0 +1,167 @@ +--- +context: + module: "ai-service" + feature: "AISVC-RES" + status: "✅已完成" + version: "0.9.0" + active_ac_range: "AC-AISVC-RES-01~15" + +spec_references: + requirements: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/requirements.md" + design: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/design.md" + tasks: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/tasks.md" + openapi_provider: "spec/ai-service/iterations/v0.9.0-retrieval-embedding-strategy/openapi.provider.yaml" + active_version: "0.1.0" + +overall_progress: + - "[x] Phase 1: Schema与数据模型定义 (100%) [API与校验]" + - "[x] Phase 2: 策略服务层实现 (100%) [策略层与配置]" + - "[x] Phase 3: 审计日志与指标埋点 (100%) [观测与灰度验证]" + - "[x] Phase 4: API端点实现 (100%) [API与校验]" + - "[x] Phase 5: 单元测试与验证 (100%) [验收]" + - "[x] Phase 6: 检索策略Pipeline实现 (100%) [检索与嵌入策略]" + +current_phase: + goal: "检索与嵌入策略化改造已完成" + sub_tasks: + - "[x] 创建 Schema 模型 (app/schemas/retrieval_strategy.py)" + - "[x] 创建策略服务层 (app/services/retrieval/strategy_service.py)" + - "[x] 创建审计日志服务 (app/services/retrieval/strategy_audit.py)" + - "[x] 创建指标埋点服务 (app/services/retrieval/strategy_metrics.py)" + - "[x] 创建 API 端点 (app/api/admin/retrieval_strategy.py)" + - "[x] 创建策略配置模型 (app/services/retrieval/strategy/config.py)" + - "[x] 实现 DefaultPipeline(复用现有逻辑)" + - "[x] 实现 EnhancedPipeline(新端到端流程)" + - "[x] 实现元数据推断统一入口 (MetadataInferenceService)" + - "[x] 实现 StrategyRouter 和 ModeRouter" + - "[x] 实现 Dense + Keyword + RRF 组合检索" + - "[x] 实现可选重排与降级开关" + - "[x] 实现 RollbackManager(回退与审计)" + - "[x] 实现策略 API 接口" + - "[x] 创建单元测试 (tests/test_retrieval_strategy_v2.py)" + - "[x] 
运行单元测试验证 (51 passed)" + + next_action: + immediate: "任务已完成,可进行集成测试" + details: + file: "ai-service/app/services/retrieval/strategy/__init__.py:1" + action: "模块已完整实现,可通过 API 接口测试策略切换功能" + reference: "http://localhost:8000/docs" + constraints: "新策略可配置启用,不影响默认策略" + +technical_context: + module_structure: | + ai-service/app/ + ├── api/admin/retrieval_strategy.py (API端点) + ├── schemas/retrieval_strategy.py (Schema模型) + └── services/retrieval/ + ├── strategy_service.py (策略服务) + ├── strategy_audit.py (审计日志) + ├── strategy_metrics.py (指标埋点) + └── strategy/ (新增 - 策略模块) + ├── __init__.py (模块导出) + ├── config.py (策略配置模型) + ├── pipeline_base.py (Pipeline基类) + ├── default_pipeline.py (默认策略Pipeline) + ├── enhanced_pipeline.py (增强策略Pipeline) + ├── metadata_inference.py (元数据推断统一入口) + ├── strategy_router.py (策略路由器) + ├── mode_router.py (模式路由器) + └── rollback_manager.py (回退管理器) + └── tests/ + ├── test_retrieval_strategy.py (单元测试 - 原有) + └── test_retrieval_strategy_v2.py (单元测试 - 新增) + key_decisions: + - decision: "使用内存存储策略状态,后续可扩展为持久化" + reason: "快速实现,满足灰度验证需求" + impact: "服务重启后策略状态重置为默认值" + - decision: "审计日志使用结构化日志记录" + reason: "与现有日志体系一致,便于检索" + impact: "需要配置日志聚合系统收集审计日志" + - decision: "API与现有strategy_router.py互补" + reason: "strategy_router.py负责检索路由逻辑,新增的API负责策略管理" + impact: "两者协同工作,API提供管理界面" + - decision: "DefaultPipeline 复用现有 OptimizedRetriever 逻辑" + reason: "保持线上行为不变,最小化改动风险" + impact: "新策略与旧策略完全隔离,可独立灰度" + - decision: "EnhancedPipeline 实现新端到端流程" + reason: "支持 Dense + Keyword + RRF 组合检索,可选重排" + impact: "需要配置启用,不影响默认策略" + - decision: "元数据推断统一入口处理 hard/soft filter" + reason: "新旧策略共享同一推断逻辑,确保一致性" + impact: "置信度高用硬过滤,置信度低用软过滤/加权" + code_snippets: | + # 使用示例 + from app.services.retrieval.strategy import ( + get_strategy_router, + get_mode_router, + get_rollback_manager, + ) + + # 策略路由 + router = get_strategy_router() + decision = await router.route(tenant_id, user_id) + result = await decision.pipeline.retrieve(ctx) + + # 模式路由 + mode_router = get_mode_router() + mode_decision = 
mode_router.decide(query, confidence=0.8) + + # 回退管理 + rollback = get_rollback_manager() + rollback.rollback(trigger="manual", reason="测试回退") + +session_history: + - session: "Session #1 (2026-03-10)" + completed: + - "创建 Schema 模型 (app/schemas/retrieval_strategy.py)" + - "创建策略服务层 (app/services/retrieval/strategy_service.py)" + - "创建审计日志服务 (app/services/retrieval/strategy_audit.py)" + - "创建指标埋点服务 (app/services/retrieval/strategy_metrics.py)" + - "创建 API 端点 (app/api/admin/retrieval_strategy.py)" + - "更新 admin __init__.py 注册新路由" + - "更新 main.py 注册新路由" + - "创建单元测试 (tests/test_retrieval_strategy.py)" + - "运行单元测试验证 (46 passed)" + changes: + - "新增: ai-service/app/schemas/retrieval_strategy.py" + - "新增: ai-service/app/services/retrieval/strategy_service.py" + - "新增: ai-service/app/services/retrieval/strategy_audit.py" + - "新增: ai-service/app/services/retrieval/strategy_metrics.py" + - "新增: ai-service/app/api/admin/retrieval_strategy.py" + - "修改: ai-service/app/api/admin/__init__.py" + - "修改: ai-service/app/main.py" + - "新增: ai-service/tests/test_retrieval_strategy.py" + status: "✅ 任务完成" + - session: "Session #2 (2026-03-10) - 检索策略Pipeline实现" + completed: + - "实现策略配置模型 (config.py)" + - "实现 Pipeline 基类 (pipeline_base.py)" + - "实现 DefaultPipeline(复用现有逻辑)" + - "实现 EnhancedPipeline(新端到端流程)" + - "实现 MetadataInferenceService" + - "实现 StrategyRouter" + - "实现 ModeRouter" + - "实现 RollbackManager" + - "更新模块导出 (__init__.py)" + - "创建单元测试 (tests/test_retrieval_strategy_v2.py)" + - "运行单元测试验证 (51 passed)" + changes: + - "新增: ai-service/app/services/retrieval/strategy/__init__.py" + - "新增: ai-service/app/services/retrieval/strategy/config.py" + - "新增: ai-service/app/services/retrieval/strategy/pipeline_base.py" + - "新增: ai-service/app/services/retrieval/strategy/default_pipeline.py" + - "新增: ai-service/app/services/retrieval/strategy/enhanced_pipeline.py" + - "新增: ai-service/app/services/retrieval/strategy/metadata_inference.py" + - "新增: 
ai-service/app/services/retrieval/strategy/strategy_router.py" + - "新增: ai-service/app/services/retrieval/strategy/mode_router.py" + - "新增: ai-service/app/services/retrieval/strategy/rollback_manager.py" + - "新增: ai-service/tests/test_retrieval_strategy_v2.py" + status: "✅ 任务完成" + +startup_guide: + - "Step 1: 读取本进度文档(了解当前位置与下一步)" + - "Step 2: 读取 spec_references 中定义的模块规范(了解业务与接口约束)" + - "Step 3: 通过 API 接口测试策略切换功能" + - "Step 4: 运行单元测试验证: pytest tests/test_retrieval_strategy_v2.py -v" +---