""" Unit tests for Strategy Router. [AC-AISVC-RES-01,02,03,07,08] Tests for strategy routing, rollback, and grayscale release. """ import pytest from unittest.mock import AsyncMock, MagicMock, patch from app.services.retrieval.routing_config import ( RagRuntimeMode, StrategyType, RoutingConfig, StrategyContext, ) from app.services.retrieval.strategy_router import ( RollbackRecord, RollbackManager, DefaultPipeline, EnhancedPipeline, StrategyRouter, get_strategy_router, reset_strategy_router, ) class TestRollbackRecord: """[AC-AISVC-RES-07] Tests for RollbackRecord.""" def test_rollback_record_creation(self): """Should create rollback record with all fields.""" record = RollbackRecord( timestamp=1234567890.0, from_strategy=StrategyType.ENHANCED, to_strategy=StrategyType.DEFAULT, reason="Performance issue", tenant_id="tenant_a", request_id="req_123", ) assert record.timestamp == 1234567890.0 assert record.from_strategy == StrategyType.ENHANCED assert record.to_strategy == StrategyType.DEFAULT assert record.reason == "Performance issue" assert record.tenant_id == "tenant_a" assert record.request_id == "req_123" class TestRollbackManager: """[AC-AISVC-RES-07] Tests for RollbackManager.""" @pytest.fixture def manager(self): return RollbackManager(max_records=10) def test_record_rollback(self, manager): """[AC-AISVC-RES-07] Should record rollback event.""" manager.record_rollback( from_strategy=StrategyType.ENHANCED, to_strategy=StrategyType.DEFAULT, reason="Test rollback", tenant_id="tenant_a", ) records = manager.get_recent_rollbacks() assert len(records) == 1 assert records[0].from_strategy == StrategyType.ENHANCED assert records[0].to_strategy == StrategyType.DEFAULT assert records[0].reason == "Test rollback" def test_max_records_limit(self, manager): """Should limit number of records.""" for i in range(15): manager.record_rollback( from_strategy=StrategyType.ENHANCED, to_strategy=StrategyType.DEFAULT, reason=f"Reason {i}", ) records = manager.get_recent_rollbacks(limit=20) assert len(records) == 10 def test_get_rollback_count(self, manager): """Should count rollbacks correctly.""" import time now = time.time() manager.record_rollback( from_strategy=StrategyType.ENHANCED, to_strategy=StrategyType.DEFAULT, reason="Reason 1", ) assert manager.get_rollback_count() == 1 assert manager.get_rollback_count(since_timestamp=now) == 1 assert manager.get_rollback_count(since_timestamp=now + 10) == 0 class TestDefaultPipeline: """[AC-AISVC-RES-01] Tests for DefaultPipeline.""" @pytest.mark.asyncio async def test_execute_uses_optimized_retriever(self): """[AC-AISVC-RES-01] Default pipeline should use OptimizedRetriever.""" pipeline = DefaultPipeline() ctx = StrategyContext( tenant_id="tenant_a", query="Test query", ) mock_result = MagicMock() mock_result.hits = [] with patch( "app.services.retrieval.optimized_retriever.get_optimized_retriever", new_callable=AsyncMock, ) as mock_get_retriever: mock_retriever = AsyncMock() mock_retriever.retrieve.return_value = mock_result mock_get_retriever.return_value = mock_retriever result = await pipeline.execute(ctx) assert result == mock_result mock_retriever.retrieve.assert_called_once() class TestEnhancedPipeline: """[AC-AISVC-RES-02] Tests for EnhancedPipeline.""" @pytest.mark.asyncio async def test_execute_creates_optimized_retriever(self): """[AC-AISVC-RES-02] Enhanced pipeline should create OptimizedRetriever with enhanced config.""" config = RoutingConfig(strategy=StrategyType.ENHANCED) pipeline = EnhancedPipeline(config=config) ctx = StrategyContext( tenant_id="tenant_a", query="Test query", ) mock_result = MagicMock() mock_result.hits = [] mock_retriever = AsyncMock() mock_retriever.retrieve.return_value = mock_result with patch( "app.services.retrieval.optimized_retriever.OptimizedRetriever", return_value=mock_retriever, ) as mock_retriever_class: result = await pipeline.execute(ctx) mock_retriever_class.assert_called_once_with( two_stage_enabled=True, hybrid_enabled=True, ) assert result == mock_result class TestStrategyRouter: """[AC-AISVC-RES-01,02,03,07,08] Tests for StrategyRouter.""" @pytest.fixture def router(self): reset_strategy_router() return StrategyRouter() def test_initial_state(self, router): """[AC-AISVC-RES-01] Initial state should be default strategy.""" assert router.current_strategy == StrategyType.DEFAULT assert router.config.enabled is True def test_route_default_strategy(self, router): """[AC-AISVC-RES-01] Should route to default strategy by default.""" ctx = StrategyContext(tenant_id="tenant_a", query="Test query") result = router.route(ctx) assert result.strategy == StrategyType.DEFAULT assert router.current_strategy == StrategyType.DEFAULT def test_route_enhanced_strategy_explicit(self, router): """[AC-AISVC-RES-02] Should route to enhanced when explicitly configured.""" router._config.strategy = StrategyType.ENHANCED ctx = StrategyContext(tenant_id="tenant_a", query="Test query") result = router.route(ctx) assert result.strategy == StrategyType.ENHANCED assert router.current_strategy == StrategyType.ENHANCED def test_route_grayscale_allowlist(self, router): """[AC-AISVC-RES-03] Should route to enhanced for tenants in allowlist.""" router._config.strategy = StrategyType.ENHANCED router._config.grayscale_allowlist = ["tenant_a", "tenant_b"] ctx_a = StrategyContext(tenant_id="tenant_a", query="Test query") ctx_c = StrategyContext(tenant_id="tenant_c", query="Test query") result_a = router.route(ctx_a) result_c = router.route(ctx_c) assert result_a.strategy == StrategyType.ENHANCED assert result_c.strategy == StrategyType.ENHANCED def test_route_disabled_strategy(self, router): """Should use default when strategy is disabled.""" router._strategy_enabled = False ctx = StrategyContext(tenant_id="tenant_a", query="Test query") result = router.route(ctx) assert result.strategy == StrategyType.DEFAULT def test_update_config(self, router): """[AC-AISVC-RES-15] Should update configuration.""" new_config = RoutingConfig( strategy=StrategyType.ENHANCED, rag_runtime_mode=RagRuntimeMode.REACT, ) router.update_config(new_config) assert router.config.strategy == StrategyType.ENHANCED assert router.config.rag_runtime_mode == RagRuntimeMode.REACT def test_rollback(self, router): """[AC-AISVC-RES-07] Should rollback to default strategy.""" router._current_strategy = StrategyType.ENHANCED router._config.strategy = StrategyType.ENHANCED router.rollback(reason="Test rollback", tenant_id="tenant_a") assert router.current_strategy == StrategyType.DEFAULT assert router.config.strategy == StrategyType.DEFAULT records = router.get_rollback_records() assert len(records) == 1 assert records[0].reason == "Test rollback" def test_rollback_from_default_no_op(self, router): """[AC-AISVC-RES-07] Rollback from default should be no-op.""" router.rollback(reason="Test rollback") assert router.current_strategy == StrategyType.DEFAULT records = router.get_rollback_records() assert len(records) == 0 def test_validate_config(self, router): """[AC-AISVC-RES-06] Should validate configuration.""" is_valid, errors = router.validate_config() assert is_valid is True assert len(errors) == 0 @pytest.mark.asyncio async def test_execute_default_strategy(self, router): """[AC-AISVC-RES-01] Should execute default pipeline.""" ctx = StrategyContext(tenant_id="tenant_a", query="Test query") mock_result = MagicMock() mock_result.hits = [] with patch.object( router._default_pipeline, "execute", new_callable=AsyncMock ) as mock_execute: mock_execute.return_value = mock_result result, strategy_result = await router.execute(ctx) assert result == mock_result assert strategy_result.strategy == StrategyType.DEFAULT mock_execute.assert_called_once_with(ctx) @pytest.mark.asyncio async def test_execute_enhanced_strategy(self, router): """[AC-AISVC-RES-02] Should execute enhanced pipeline.""" router._config.strategy = StrategyType.ENHANCED ctx = StrategyContext(tenant_id="tenant_a", query="Test query") mock_result = MagicMock() mock_result.hits = [] with patch.object( router._enhanced_pipeline, "execute", new_callable=AsyncMock ) as mock_execute: mock_execute.return_value = mock_result result, strategy_result = await router.execute(ctx) assert result == mock_result assert strategy_result.strategy == StrategyType.ENHANCED mock_execute.assert_called_once_with(ctx) @pytest.mark.asyncio async def test_execute_fallback_on_error(self, router): """[AC-AISVC-RES-07] Should fallback to default on enhanced failure.""" router._config.strategy = StrategyType.ENHANCED ctx = StrategyContext(tenant_id="tenant_a", query="Test query") mock_result = MagicMock() mock_result.hits = [] with patch.object( router._enhanced_pipeline, "execute", new_callable=AsyncMock ) as mock_enhanced: mock_enhanced.side_effect = Exception("Enhanced pipeline failed") with patch.object( router._default_pipeline, "execute", new_callable=AsyncMock ) as mock_default: mock_default.return_value = mock_result result, strategy_result = await router.execute(ctx) assert result == mock_result assert strategy_result.strategy == StrategyType.DEFAULT assert strategy_result.should_fallback is True assert "Enhanced pipeline failed" in strategy_result.fallback_reason class TestSingletonInstances: """Tests for singleton instance getters.""" def test_get_strategy_router_singleton(self): """Should return same router instance.""" reset_strategy_router() router1 = get_strategy_router() router2 = get_strategy_router() assert router1 is router2 def test_reset_strategy_router(self): """Should create new instance after reset.""" router1 = get_strategy_router() reset_strategy_router() router2 = get_strategy_router() assert router1 is not router2