"""
Unit tests for Retrieval Strategy Module.
[AC-AISVC-RES-01~15] Tests for strategy config, pipelines, routers, and rollback.

Async code paths are exercised with ``@pytest.mark.asyncio`` tests. Tests that
touch process-global state (the strategy config and the module-level
singletons) restore that state afterwards so test order does not matter.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from dataclasses import asdict

from app.services.retrieval.strategy.config import (
    FilterMode,
    GrayscaleConfig,
    HybridRetrievalConfig,
    MetadataInferenceConfig,
    ModeRouterConfig,
    PipelineConfig,
    RerankerConfig,
    RetrievalStrategyConfig,
    RuntimeMode,
    StrategyType,
    get_strategy_config,
    set_strategy_config,
)
from app.services.retrieval.strategy.pipeline_base import (
    BasePipeline,
    MetadataFilterResult,
    PipelineContext,
    PipelineResult,
)
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline
from app.services.retrieval.strategy.strategy_router import (
    RoutingDecision,
    StrategyRouter,
    get_strategy_router,
)
from app.services.retrieval.strategy.mode_router import (
    ModeDecision,
    ModeRouter,
    get_mode_router,
)
from app.services.retrieval.strategy.rollback_manager import (
    AuditLog,
    RollbackManager,
    RollbackResult,
    RollbackTrigger,
    get_rollback_manager,
)
from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult


class TestStrategyConfig:
    """[AC-AISVC-RES-01~15] Tests for strategy configuration models."""

    def test_strategy_type_enum(self):
        """[AC-AISVC-RES-01] Strategy type should have default and enhanced values."""
        assert StrategyType.DEFAULT.value == "default"
        assert StrategyType.ENHANCED.value == "enhanced"

    def test_runtime_mode_enum(self):
        """[AC-AISVC-RES-09] Runtime mode should have direct, react, and auto values."""
        assert RuntimeMode.DIRECT.value == "direct"
        assert RuntimeMode.REACT.value == "react"
        assert RuntimeMode.AUTO.value == "auto"

    def test_filter_mode_enum(self):
        """[AC-AISVC-RES-04] Filter mode should have hard, soft, and none values."""
        assert FilterMode.HARD.value == "hard"
        assert FilterMode.SOFT.value == "soft"
        assert FilterMode.NONE.value == "none"

    def test_grayscale_config_default(self):
        """[AC-AISVC-RES-03] Default grayscale config should be disabled."""
        config = GrayscaleConfig()
        assert config.enabled is False
        assert config.percentage == 0.0
        assert config.allowlist == []

    def test_grayscale_config_should_use_enhanced_disabled(self):
        """[AC-AISVC-RES-03] Should not use enhanced when grayscale disabled."""
        config = GrayscaleConfig(enabled=False, percentage=50.0)
        assert config.should_use_enhanced("tenant_a") is False

    def test_grayscale_config_should_use_enhanced_allowlist(self):
        """[AC-AISVC-RES-03] Should use enhanced for tenants in allowlist."""
        config = GrayscaleConfig(enabled=True, allowlist=["tenant_a", "tenant_b"])
        assert config.should_use_enhanced("tenant_a") is True
        assert config.should_use_enhanced("tenant_b") is True
        assert config.should_use_enhanced("tenant_c") is False

    def test_grayscale_config_should_use_enhanced_percentage(self):
        """[AC-AISVC-RES-03] Should use enhanced based on percentage."""
        # Only the 0% / 100% extremes are deterministic regardless of how
        # tenants are bucketed, so those are the ones asserted here.
        config = GrayscaleConfig(enabled=True, percentage=100.0)
        assert config.should_use_enhanced("any_tenant") is True

        config = GrayscaleConfig(enabled=True, percentage=0.0)
        assert config.should_use_enhanced("any_tenant") is False

    def test_reranker_config_default(self):
        """[AC-AISVC-RES-08] Default reranker config should be disabled."""
        config = RerankerConfig()
        assert config.enabled is False
        assert config.model == "cross-encoder"
        assert config.top_k_after_rerank == 5

    def test_mode_router_config_default(self):
        """[AC-AISVC-RES-09] Default mode router config should be direct."""
        config = ModeRouterConfig()
        assert config.runtime_mode == RuntimeMode.DIRECT
        assert config.react_trigger_confidence_threshold == 0.6
        assert config.react_max_steps == 5

    def test_mode_router_config_should_use_react_always(self):
        """[AC-AISVC-RES-10] React mode should always use react."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.REACT)
        assert config.should_use_react("any query") is True

    def test_mode_router_config_should_use_react_never(self):
        """[AC-AISVC-RES-09] Direct mode should never use react."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.DIRECT)
        assert config.should_use_react("any query") is False

    def test_mode_router_config_auto_short_query_high_confidence(self):
        """[AC-AISVC-RES-12] Auto mode with short query and high confidence should use direct."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
        # "短问题" means "short question".
        assert config.should_use_react("短问题", confidence=0.8) is False

    def test_mode_router_config_auto_low_confidence(self):
        """[AC-AISVC-RES-13] Auto mode with low confidence should use react."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
        assert config.should_use_react("any query", confidence=0.3) is True

    def test_metadata_inference_config_determine_filter_mode(self):
        """[AC-AISVC-RES-04] Should determine filter mode based on confidence."""
        config = MetadataInferenceConfig()
        assert config.determine_filter_mode(0.9) == FilterMode.HARD
        assert config.determine_filter_mode(0.6) == FilterMode.SOFT
        assert config.determine_filter_mode(0.3) == FilterMode.NONE
        assert config.determine_filter_mode(None) == FilterMode.NONE

    def test_pipeline_config_default(self):
        """[AC-AISVC-RES-01] Default pipeline config should have sensible defaults."""
        config = PipelineConfig()
        assert config.top_k == 5
        assert config.score_threshold == 0.01
        assert config.two_stage_enabled is True

    def test_retrieval_strategy_config_default(self):
        """[AC-AISVC-RES-01] Default strategy config should use default strategy."""
        config = RetrievalStrategyConfig()
        assert config.active_strategy == StrategyType.DEFAULT
        assert config.grayscale.enabled is False
        assert config.mode_router.runtime_mode == RuntimeMode.DIRECT

    def test_retrieval_strategy_config_is_enhanced_enabled(self):
        """[AC-AISVC-RES-02] Should check if enhanced is enabled."""
        config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        assert config.is_enhanced_enabled("tenant_a") is True

        config = RetrievalStrategyConfig(
            active_strategy=StrategyType.DEFAULT,
            grayscale=GrayscaleConfig(enabled=True, allowlist=["tenant_a"]),
        )
        assert config.is_enhanced_enabled("tenant_a") is True
        assert config.is_enhanced_enabled("tenant_b") is False

    def test_retrieval_strategy_config_to_dict(self):
        """[AC-AISVC-RES-01] Should convert config to dictionary."""
        config = RetrievalStrategyConfig()
        d = config.to_dict()
        assert d["active_strategy"] == "default"
        assert "grayscale" in d
        assert "pipeline" in d
        assert "reranker" in d
        assert "mode_router" in d

    def test_global_config_functions(self):
        """[AC-AISVC-RES-01] Should get and set global config."""
        # The strategy config is process-global: remember the current value
        # and restore it even if an assertion fails, so that a failure here
        # cannot leak the ENHANCED strategy into unrelated tests.
        previous = get_strategy_config()
        try:
            config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
            set_strategy_config(config)
            retrieved = get_strategy_config()
            assert retrieved.active_strategy == StrategyType.ENHANCED
        finally:
            set_strategy_config(previous)


class TestPipelineBase:
    """[AC-AISVC-RES-01~02] Tests for pipeline base classes."""

    def test_metadata_filter_result_default(self):
        """[AC-AISVC-RES-04] Default metadata filter result should be empty."""
        result = MetadataFilterResult()
        assert result.filter_dict == {}
        assert result.filter_mode == FilterMode.NONE
        assert result.confidence is None

    def test_pipeline_context_properties(self):
        """[AC-AISVC-RES-01] Pipeline context should expose retrieval context properties."""
        retrieval_ctx = RetrievalContext(
            tenant_id="tenant_1",
            query="test query",
            session_id="session_1",
            kb_ids=["kb_1"],
        )
        pipeline_ctx = PipelineContext(retrieval_ctx=retrieval_ctx)

        assert pipeline_ctx.tenant_id == "tenant_1"
        assert pipeline_ctx.query == "test query"
        assert pipeline_ctx.session_id == "session_1"
        assert pipeline_ctx.kb_ids == ["kb_1"]

    def test_pipeline_result_properties(self):
        """[AC-AISVC-RES-01] Pipeline result should expose retrieval result properties."""
        hits = [
            RetrievalHit(text="hit 1", score=0.9, source="test", metadata={}),
            RetrievalHit(text="hit 2", score=0.8, source="test", metadata={}),
        ]
        retrieval_result = RetrievalResult(hits=hits)
        pipeline_result = PipelineResult(
            retrieval_result=retrieval_result,
            pipeline_name="test_pipeline",
        )

        assert pipeline_result.hits == hits
        assert pipeline_result.is_empty is False
        assert pipeline_result.pipeline_name == "test_pipeline"

    def test_pipeline_result_is_empty(self):
        """[AC-AISVC-RES-01] Pipeline result should detect empty results."""
        pipeline_result = PipelineResult(
            retrieval_result=RetrievalResult(hits=[]),
        )
        assert pipeline_result.is_empty is True


class TestDefaultPipeline:
    """[AC-AISVC-RES-01] Tests for default pipeline."""

    @pytest.fixture
    def mock_retriever(self):
        """Create a mock optimized retriever."""
        retriever = AsyncMock()
        retriever.retrieve = AsyncMock(return_value=RetrievalResult(
            hits=[
                RetrievalHit(text="result 1", score=0.9, source="default", metadata={}),
            ],
            diagnostics={"test": True},
        ))
        retriever.health_check = AsyncMock(return_value=True)
        retriever._two_stage_enabled = True
        retriever._hybrid_enabled = True
        return retriever

    @pytest.fixture
    def pipeline(self, mock_retriever):
        """Create a default pipeline with mock retriever."""
        return DefaultPipeline(optimized_retriever=mock_retriever)

    def test_pipeline_name(self, pipeline):
        """[AC-AISVC-RES-01] Pipeline should have correct name."""
        assert pipeline.name == "default_pipeline"

    def test_pipeline_description(self, pipeline):
        """[AC-AISVC-RES-01] Pipeline should have description."""
        # "默认" means "default".
        assert "默认" in pipeline.description

    @pytest.mark.asyncio
    async def test_retrieve(self, pipeline, mock_retriever):
        """[AC-AISVC-RES-01] Should retrieve results using optimized retriever."""
        retrieval_ctx = RetrievalContext(
            tenant_id="tenant_1",
            query="test query",
        )
        pipeline_ctx = PipelineContext(retrieval_ctx=retrieval_ctx)

        result = await pipeline.retrieve(pipeline_ctx)

        assert result.pipeline_name == "default_pipeline"
        assert len(result.hits) == 1
        assert result.diagnostics["retriever"] == "OptimizedRetriever"
        mock_retriever.retrieve.assert_called_once()

    @pytest.mark.asyncio
    async def test_retrieve_with_metadata_filter(self, pipeline, mock_retriever):
        """[AC-AISVC-RES-04] Should apply metadata filter."""
        retrieval_ctx = RetrievalContext(
            tenant_id="tenant_1",
            query="test query",
        )
        # "初一" is a grade-level value ("first year of junior high").
        metadata_filter = MetadataFilterResult(
            filter_dict={"grade": "初一"},
            filter_mode=FilterMode.HARD,
        )
        pipeline_ctx = PipelineContext(
            retrieval_ctx=retrieval_ctx,
            metadata_filter=metadata_filter,
        )

        result = await pipeline.retrieve(pipeline_ctx)

        assert result.metadata_filter_applied is True
        # The filter must be forwarded to the underlying retriever via the
        # first positional argument (the retrieval context).
        call_args = mock_retriever.retrieve.call_args[0][0]
        assert call_args.metadata_filter == {"grade": "初一"}

    @pytest.mark.asyncio
    async def test_health_check(self, pipeline, mock_retriever):
        """[AC-AISVC-RES-01] Should check health."""
        result = await pipeline.health_check()
        assert result is True
        mock_retriever.health_check.assert_called_once()


class TestEnhancedPipeline:
    """[AC-AISVC-RES-02] Tests for enhanced pipeline."""

    @pytest.fixture
    def mock_qdrant_client(self):
        """Create a mock Qdrant client."""
        client = AsyncMock()
        client.search = AsyncMock(return_value=[
            {"id": "1", "score": 0.9, "payload": {"text": "result 1"}},
        ])
        client.get_client = AsyncMock()
        return client

    @pytest.fixture
    def mock_embedding_provider(self):
        """Create a mock embedding provider."""
        provider = AsyncMock()
        provider.embed_query = AsyncMock()
        provider.embed_query.return_value = MagicMock(
            embedding_full=[0.1] * 768,
        )
        provider.embed = AsyncMock(return_value=[0.1] * 768)
        return provider

    @pytest.fixture
    def pipeline(self, mock_qdrant_client, mock_embedding_provider):
        """Create an enhanced pipeline with mocks."""
        pipeline = EnhancedPipeline(qdrant_client=mock_qdrant_client)
        # Inject the mock provider directly so no real embedding backend is used.
        pipeline._embedding_provider = mock_embedding_provider
        return pipeline

    def test_pipeline_name(self, pipeline):
        """[AC-AISVC-RES-02] Pipeline should have correct name."""
        assert pipeline.name == "enhanced_pipeline"

    def test_pipeline_description(self, pipeline):
        """[AC-AISVC-RES-02] Pipeline should have description."""
        # "增强" means "enhanced".
        assert "增强" in pipeline.description

    @pytest.mark.asyncio
    async def test_retrieve_basic(self, pipeline):
        """[AC-AISVC-RES-02] Should retrieve results using hybrid search."""
        retrieval_ctx = RetrievalContext(
            tenant_id="tenant_1",
            query="test query",
        )
        pipeline_ctx = PipelineContext(retrieval_ctx=retrieval_ctx)

        result = await pipeline.retrieve(pipeline_ctx)

        assert result.pipeline_name == "enhanced_pipeline"
        assert result.diagnostics is not None


class TestStrategyRouter:
    """[AC-AISVC-RES-01~03] Tests for strategy router.

    ``StrategyRouter.route`` returns an awaitable, so routing tests are async
    and rely on pytest-asyncio's managed event loop rather than on the
    deprecated ``asyncio.get_event_loop()`` call from synchronous code.
    """

    @pytest.fixture
    def mock_default_pipeline(self):
        """Create a mock default pipeline."""
        pipeline = AsyncMock(spec=DefaultPipeline)
        pipeline.name = "default_pipeline"
        pipeline.retrieve = AsyncMock(return_value=PipelineResult(
            retrieval_result=RetrievalResult(hits=[]),
            pipeline_name="default_pipeline",
        ))
        return pipeline

    @pytest.fixture
    def mock_enhanced_pipeline(self):
        """Create a mock enhanced pipeline."""
        pipeline = AsyncMock(spec=EnhancedPipeline)
        pipeline.name = "enhanced_pipeline"
        pipeline.retrieve = AsyncMock(return_value=PipelineResult(
            retrieval_result=RetrievalResult(hits=[]),
            pipeline_name="enhanced_pipeline",
        ))
        return pipeline

    @pytest.fixture
    def router(self, mock_default_pipeline, mock_enhanced_pipeline):
        """Create a strategy router with mock pipelines."""
        config = RetrievalStrategyConfig()
        return StrategyRouter(
            config=config,
            default_pipeline=mock_default_pipeline,
            enhanced_pipeline=mock_enhanced_pipeline,
        )

    @pytest.mark.asyncio
    async def test_route_default_strategy(self, router):
        """[AC-AISVC-RES-01] Should route to default strategy by default."""
        decision = await router.route("tenant_1")

        assert decision.strategy == StrategyType.DEFAULT
        assert decision.reason == "default_strategy"

    @pytest.mark.asyncio
    async def test_route_enhanced_strategy(self, mock_default_pipeline, mock_enhanced_pipeline):
        """[AC-AISVC-RES-02] Should route to enhanced strategy when configured."""
        config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        router = StrategyRouter(
            config=config,
            default_pipeline=mock_default_pipeline,
            enhanced_pipeline=mock_enhanced_pipeline,
        )

        decision = await router.route("tenant_1")

        assert decision.strategy == StrategyType.ENHANCED
        assert decision.reason == "active_strategy=enhanced"

    @pytest.mark.asyncio
    async def test_route_grayscale_allowlist(self, mock_default_pipeline, mock_enhanced_pipeline):
        """[AC-AISVC-RES-03] Should route to enhanced for allowlist tenants."""
        config = RetrievalStrategyConfig(
            active_strategy=StrategyType.DEFAULT,
            grayscale=GrayscaleConfig(enabled=True, allowlist=["tenant_a"]),
        )
        router = StrategyRouter(
            config=config,
            default_pipeline=mock_default_pipeline,
            enhanced_pipeline=mock_enhanced_pipeline,
        )

        decision = await router.route("tenant_a")
        assert decision.strategy == StrategyType.ENHANCED
        assert decision.grayscale_hit is True

        decision = await router.route("tenant_b")
        assert decision.strategy == StrategyType.DEFAULT

    def test_update_config(self, router):
        """[AC-AISVC-RES-02] Should update config."""
        new_config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        router.update_config(new_config)

        assert router.get_config().active_strategy == StrategyType.ENHANCED


class TestModeRouter:
    """[AC-AISVC-RES-09~15] Tests for mode router."""

    @pytest.fixture
    def router(self):
        """Create a mode router."""
        return ModeRouter()

    def test_decide_react_mode(self):
        """[AC-AISVC-RES-10] Should decide react when configured."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.REACT)
        router = ModeRouter(config)

        decision = router.decide("any query")

        assert decision.mode == RuntimeMode.REACT
        assert decision.reason == "runtime_mode=react"

    def test_decide_direct_mode(self, router):
        """[AC-AISVC-RES-09] Should decide direct when configured."""
        decision = router.decide("any query")

        assert decision.mode == RuntimeMode.DIRECT
        assert decision.reason == "runtime_mode=direct"

    def test_decide_auto_short_query_high_confidence(self):
        """[AC-AISVC-RES-12] Auto with short query and high confidence should use direct."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
        router = ModeRouter(config)

        # "短问题" means "short question".
        decision = router.decide("短问题", confidence=0.8)

        assert decision.mode == RuntimeMode.DIRECT

    def test_decide_auto_low_confidence(self):
        """[AC-AISVC-RES-13] Auto with low confidence should use react."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
        router = ModeRouter(config)

        decision = router.decide("any query", confidence=0.3)

        assert decision.mode == RuntimeMode.REACT

    def test_should_fallback_to_react_empty_results(self, router):
        """[AC-AISVC-RES-14] Should fallback to react on empty results."""
        result = PipelineResult(retrieval_result=RetrievalResult(hits=[]))

        assert router.should_fallback_to_react(result) is True

    def test_should_fallback_to_react_low_score(self, router):
        """[AC-AISVC-RES-14] Should fallback to react on low score."""
        result = PipelineResult(
            retrieval_result=RetrievalResult(
                hits=[RetrievalHit(text="test", score=0.1, source="test", metadata={})],
            ),
        )

        assert router.should_fallback_to_react(result) is True

    def test_should_not_fallback_to_react_disabled(self):
        """[AC-AISVC-RES-14] Should not fallback when disabled."""
        config = ModeRouterConfig(direct_fallback_on_low_confidence=False)
        router = ModeRouter(config)

        result = PipelineResult(retrieval_result=RetrievalResult(hits=[]))

        assert router.should_fallback_to_react(result) is False


class TestRollbackManager:
    """[AC-AISVC-RES-07] Tests for rollback manager."""

    @pytest.fixture
    def manager(self):
        """Create a rollback manager."""
        return RollbackManager()

    def test_rollback_from_enhanced(self, manager):
        """[AC-AISVC-RES-07] Should rollback from enhanced to default."""
        config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        manager.update_config(config)

        result = manager.rollback(
            trigger=RollbackTrigger.MANUAL,
            reason="Testing rollback",
        )

        assert result.success is True
        assert result.previous_strategy == StrategyType.ENHANCED
        assert result.current_strategy == StrategyType.DEFAULT
        assert result.audit_log is not None

    def test_rollback_already_default(self, manager):
        """[AC-AISVC-RES-07] Should not rollback when already on default."""
        result = manager.rollback(
            trigger=RollbackTrigger.MANUAL,
            reason="Testing rollback",
        )

        assert result.success is False
        assert result.reason == "Already on default strategy"

    def test_check_and_rollback_latency(self, manager):
        """[AC-AISVC-RES-08] Should rollback on high latency."""
        config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        manager.update_config(config)

        result = manager.check_and_rollback(
            metrics={"latency_ms": 3000.0},
            tenant_id="tenant_1",
        )

        assert result is not None
        assert result.trigger == RollbackTrigger.PERFORMANCE

    def test_check_and_rollback_error_rate(self, manager):
        """[AC-AISVC-RES-08] Should rollback on high error rate."""
        config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        manager.update_config(config)

        result = manager.check_and_rollback(
            metrics={"error_rate": 0.1},
            tenant_id="tenant_1",
        )

        assert result is not None
        assert result.trigger == RollbackTrigger.ERROR

    def test_check_and_rollback_ok(self, manager):
        """[AC-AISVC-RES-08] Should not rollback when metrics are ok."""
        config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        manager.update_config(config)

        result = manager.check_and_rollback(
            metrics={"latency_ms": 100.0, "error_rate": 0.01},
            tenant_id="tenant_1",
        )

        assert result is None

    def test_get_audit_logs(self, manager):
        """[AC-AISVC-RES-07] Should get audit logs."""
        config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        manager.update_config(config)
        manager.rollback(trigger=RollbackTrigger.MANUAL, reason="Test")

        logs = manager.get_audit_logs()

        assert len(logs) == 1
        assert logs[0].action == "rollback"

    def test_record_audit(self, manager):
        """[AC-AISVC-RES-07] Should record audit log."""
        log = manager.record_audit(
            action="test_action",
            details={"reason": "Testing"},
            tenant_id="tenant_1",
        )

        assert log.action == "test_action"
        assert log.tenant_id == "tenant_1"


class TestSingletonInstances:
    """Tests for singleton instance getters.

    Each test clears the module-level cached instance through ``monkeypatch``
    so the getter is forced to build a fresh one; ``monkeypatch`` restores the
    original cached value when the test ends, keeping other tests isolated.
    """

    def test_get_mode_router_singleton(self, monkeypatch):
        """Should return same mode router instance."""
        import app.services.retrieval.strategy.mode_router as module

        monkeypatch.setattr(module, "_mode_router", None)

        router1 = get_mode_router()
        router2 = get_mode_router()

        assert router1 is router2

    def test_get_rollback_manager_singleton(self, monkeypatch):
        """Should return same rollback manager instance."""
        import app.services.retrieval.strategy.rollback_manager as module

        monkeypatch.setattr(module, "_rollback_manager", None)

        manager1 = get_rollback_manager()
        manager2 = get_rollback_manager()

        assert manager1 is manager2