diff --git a/ai-service/tests/test_mode_router.py b/ai-service/tests/test_mode_router.py new file mode 100644 index 0000000..f54da75 --- /dev/null +++ b/ai-service/tests/test_mode_router.py @@ -0,0 +1,297 @@ +""" +Unit tests for Mode Router. +[AC-AISVC-RES-09,10,11,12,13,14] Tests for mode routing, complexity analysis, and fallback. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.services.retrieval.routing_config import ( + RagRuntimeMode, + RoutingConfig, + StrategyContext, +) +from app.services.retrieval.mode_router import ( + ComplexityAnalyzer, + ModeRouteResult, + ModeRouter, + get_mode_router, + reset_mode_router, +) + + +class TestComplexityAnalyzer: + """[AC-AISVC-RES-12,13] Tests for ComplexityAnalyzer.""" + + @pytest.fixture + def analyzer(self): + return ComplexityAnalyzer() + + def test_analyze_empty_query(self, analyzer): + """Empty query should have zero complexity.""" + score = analyzer.analyze("") + + assert score == 0.0 + + def test_analyze_short_query(self, analyzer): + """Short query should have low complexity.""" + score = analyzer.analyze("简单问题") + + assert score < 0.3 + + def test_analyze_long_query(self, analyzer): + """Long query should have higher complexity.""" + long_query = "这是一个很长的问题" * 20 + score = analyzer.analyze(long_query) + + assert score >= 0.3 + + def test_analyze_multiple_conditions(self, analyzer): + """[AC-AISVC-RES-13] Query with multiple conditions should have high complexity.""" + query = "查询订单状态和物流信息以及退款进度" + score = analyzer.analyze(query) + + assert score >= 0.2 + + def test_analyze_reasoning_patterns(self, analyzer): + """Query with reasoning patterns should have higher complexity.""" + query = "为什么订单会被取消?原因是什么?" + score = analyzer.analyze(query) + + assert score >= 0.15 + + def test_analyze_cross_domain_patterns(self, analyzer): + """Cross-domain query should have higher complexity.""" + query = "汇总所有部门的销售数据" + score = analyzer.analyze(query) + + assert score >= 0.15 + + def test_analyze_multiple_questions(self, analyzer): + """Multiple questions should increase complexity.""" + query = "什么是价格?如何购买?" + score = analyzer.analyze(query) + + assert score >= 0.1 + + +class TestModeRouteResult: + """Tests for ModeRouteResult.""" + + def test_result_creation(self): + """Should create result with all fields.""" + result = ModeRouteResult( + mode=RagRuntimeMode.REACT, + confidence=0.5, + complexity_score=0.7, + should_fallback_to_react=True, + fallback_reason="Low confidence", + diagnostics={"key": "value"}, + ) + + assert result.mode == RagRuntimeMode.REACT + assert result.confidence == 0.5 + assert result.complexity_score == 0.7 + assert result.should_fallback_to_react is True + assert result.fallback_reason == "Low confidence" + assert result.diagnostics == {"key": "value"} + + def test_result_defaults(self): + """Should create result with default values.""" + result = ModeRouteResult( + mode=RagRuntimeMode.DIRECT, + confidence=0.8, + complexity_score=0.2, + ) + + assert result.should_fallback_to_react is False + assert result.fallback_reason is None + assert result.diagnostics == {} + + +class TestModeRouter: + """[AC-AISVC-RES-09,10,11,12,13,14] Tests for ModeRouter.""" + + @pytest.fixture + def router(self): + reset_mode_router() + return ModeRouter() + + def test_initial_config(self, router): + """Should initialize with default configuration.""" + assert router.config.rag_runtime_mode == RagRuntimeMode.AUTO + + def test_route_direct_mode(self, router): + """[AC-AISVC-RES-09] Should route to direct mode when configured.""" + router._config.rag_runtime_mode = RagRuntimeMode.DIRECT + + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + result = router.route(ctx) + + assert result.mode == RagRuntimeMode.DIRECT + + def test_route_react_mode(self, router): + """[AC-AISVC-RES-10] Should route to react mode when configured.""" + router._config.rag_runtime_mode = RagRuntimeMode.REACT + + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + result = router.route(ctx) + + assert result.mode == RagRuntimeMode.REACT + + def test_route_auto_mode_direct_conditions(self, router): + """[AC-AISVC-RES-11,12] Auto mode should select direct for simple queries.""" + router._config.rag_runtime_mode = RagRuntimeMode.AUTO + + ctx = StrategyContext( + tenant_id="tenant_a", + query="简单问题", + metadata_confidence=0.9, + complexity_score=0.1, + ) + + result = router.route(ctx) + + assert result.mode == RagRuntimeMode.DIRECT + + def test_route_auto_mode_react_low_confidence(self, router): + """[AC-AISVC-RES-11,12] Auto mode should select react for low confidence.""" + router._config.rag_runtime_mode = RagRuntimeMode.AUTO + router._config.react_trigger_confidence_threshold = 0.6 + + ctx = StrategyContext( + tenant_id="tenant_a", + query="Test query", + metadata_confidence=0.4, + complexity_score=0.2, + ) + + result = router.route(ctx) + + assert result.mode == RagRuntimeMode.REACT + + def test_route_auto_mode_react_high_complexity(self, router): + """[AC-AISVC-RES-11,13] Auto mode should select react for high complexity.""" + router._config.rag_runtime_mode = RagRuntimeMode.AUTO + router._config.react_trigger_complexity_score = 0.5 + + ctx = StrategyContext( + tenant_id="tenant_a", + query="Test query", + metadata_confidence=0.8, + complexity_score=0.7, + ) + + result = router.route(ctx) + + assert result.mode == RagRuntimeMode.REACT + + def test_update_config(self, router): + """[AC-AISVC-RES-15] Should update configuration.""" + new_config = RoutingConfig( + rag_runtime_mode=RagRuntimeMode.REACT, + react_max_steps=7, + ) + + router.update_config(new_config) + + assert router.config.rag_runtime_mode == RagRuntimeMode.REACT + assert router.config.react_max_steps == 7 + + @pytest.mark.asyncio + async def test_execute_direct(self, router): + """[AC-AISVC-RES-09] Should execute direct retrieval.""" + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + mock_result = MagicMock() + mock_result.hits = [] + + with patch.object( + router._direct_executor, "execute", new_callable=AsyncMock + ) as mock_execute: + mock_execute.return_value = mock_result + + result = await router.execute_direct(ctx) + + assert result == mock_result + + @pytest.mark.asyncio + async def test_execute_with_fallback_no_fallback(self, router): + """[AC-AISVC-RES-14] Should not fallback when confidence is high.""" + router._config.rag_runtime_mode = RagRuntimeMode.DIRECT + router._config.direct_fallback_on_low_confidence = True + router._config.direct_fallback_confidence_threshold = 0.4 + + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + mock_result = MagicMock() + mock_hit = MagicMock() + mock_hit.score = 0.8 + mock_result.hits = [mock_hit] + + with patch.object( + router._direct_executor, "execute", new_callable=AsyncMock + ) as mock_execute: + mock_execute.return_value = mock_result + + result, answer, route_result = await router.execute_with_fallback(ctx) + + assert result == mock_result + assert answer is None + assert route_result.mode == RagRuntimeMode.DIRECT + assert route_result.should_fallback_to_react is False + + @pytest.mark.asyncio + async def test_execute_with_fallback_triggered(self, router): + """[AC-AISVC-RES-14] Should fallback to react when confidence is low.""" + router._config.rag_runtime_mode = RagRuntimeMode.DIRECT + router._config.direct_fallback_on_low_confidence = True + router._config.direct_fallback_confidence_threshold = 0.4 + + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + mock_result = MagicMock() + mock_hit = MagicMock() + mock_hit.score = 0.2 + mock_result.hits = [mock_hit] + + with patch.object( + router._direct_executor, "execute", new_callable=AsyncMock + ) as mock_direct: + mock_direct.return_value = mock_result + + with patch.object( + router._react_executor, "execute", new_callable=AsyncMock + ) as mock_react: + mock_react.return_value = ("Final answer", None, {}) + + result, answer, route_result = await router.execute_with_fallback(ctx) + + assert result is None + assert answer == "Final answer" + assert route_result.mode == RagRuntimeMode.REACT + assert route_result.should_fallback_to_react is True + assert route_result.fallback_reason == "low_confidence" + + +class TestSingletonInstances: + """Tests for singleton instance getters.""" + + def test_get_mode_router_singleton(self): + """Should return same router instance.""" + reset_mode_router() + + router1 = get_mode_router() + router2 = get_mode_router() + + assert router1 is router2 + + def test_reset_mode_router(self): + """Should create new instance after reset.""" + router1 = get_mode_router() + reset_mode_router() + router2 = get_mode_router() + + assert router1 is not router2 diff --git a/ai-service/tests/test_routing_config.py b/ai-service/tests/test_routing_config.py new file mode 100644 index 0000000..adc16f5 --- /dev/null +++ b/ai-service/tests/test_routing_config.py @@ -0,0 +1,253 @@ +""" +Unit tests for Routing Configuration. +[AC-AISVC-RES-01~15] Tests for strategy routing configuration models. +""" + +import pytest + +from app.services.retrieval.routing_config import ( + RagRuntimeMode, + StrategyType, + RoutingConfig, + StrategyContext, + StrategyResult, +) + + +class TestStrategyType: + """[AC-AISVC-RES-01,02] Tests for StrategyType enum.""" + + def test_strategy_type_default(self): + """[AC-AISVC-RES-01] Default strategy type should exist.""" + assert StrategyType.DEFAULT.value == "default" + + def test_strategy_type_enhanced(self): + """[AC-AISVC-RES-02] Enhanced strategy type should exist.""" + assert StrategyType.ENHANCED.value == "enhanced" + + def test_strategy_type_from_string(self): + """Should create StrategyType from string.""" + assert StrategyType("default") == StrategyType.DEFAULT + assert StrategyType("enhanced") == StrategyType.ENHANCED + + +class TestRagRuntimeMode: + """[AC-AISVC-RES-09,10,11] Tests for RagRuntimeMode enum.""" + + def test_mode_direct(self): + """[AC-AISVC-RES-09] Direct mode should exist.""" + assert RagRuntimeMode.DIRECT.value == "direct" + + def test_mode_react(self): + """[AC-AISVC-RES-10] React mode should exist.""" + assert RagRuntimeMode.REACT.value == "react" + + def test_mode_auto(self): + """[AC-AISVC-RES-11] Auto mode should exist.""" + assert RagRuntimeMode.AUTO.value == "auto" + + def test_mode_from_string(self): + """Should create RagRuntimeMode from string.""" + assert RagRuntimeMode("direct") == RagRuntimeMode.DIRECT + assert RagRuntimeMode("react") == RagRuntimeMode.REACT + assert RagRuntimeMode("auto") == RagRuntimeMode.AUTO + + +class TestRoutingConfig: + """[AC-AISVC-RES-01~15] Tests for RoutingConfig.""" + + def test_default_config(self): + """[AC-AISVC-RES-01] Default config should use default strategy and auto mode.""" + config = RoutingConfig() + + assert config.enabled is True + assert config.strategy == StrategyType.DEFAULT + assert config.rag_runtime_mode == RagRuntimeMode.AUTO + assert config.grayscale_percentage == 0.0 + assert config.grayscale_allowlist == [] + + def test_config_with_custom_values(self): + """[AC-AISVC-RES-02,03] Config should accept custom values.""" + config = RoutingConfig( + enabled=False, + strategy=StrategyType.ENHANCED, + rag_runtime_mode=RagRuntimeMode.REACT, + grayscale_percentage=0.3, + grayscale_allowlist=["tenant_a", "tenant_b"], + ) + + assert config.enabled is False + assert config.strategy == StrategyType.ENHANCED + assert config.rag_runtime_mode == RagRuntimeMode.REACT + assert config.grayscale_percentage == 0.3 + assert "tenant_a" in config.grayscale_allowlist + + def test_should_use_enhanced_strategy_default(self): + """[AC-AISVC-RES-01] Default strategy should not use enhanced.""" + config = RoutingConfig() + + assert config.should_use_enhanced_strategy("tenant_a") is False + + def test_should_use_enhanced_strategy_explicit_enhanced(self): + """[AC-AISVC-RES-02] Explicit enhanced strategy should use enhanced for all tenants.""" + config = RoutingConfig(strategy=StrategyType.ENHANCED) + + assert config.should_use_enhanced_strategy("tenant_a") is True + assert config.should_use_enhanced_strategy("tenant_b") is True + assert config.should_use_enhanced_strategy(None) is True + + def test_should_fallback_direct_to_react_disabled(self): + """[AC-AISVC-RES-14] Fallback should be disabled when configured.""" + config = RoutingConfig(direct_fallback_on_low_confidence=False) + + assert config.should_fallback_direct_to_react(0.1) is False + assert config.should_fallback_direct_to_react(0.0) is False + + def test_should_fallback_direct_to_react_enabled(self): + """[AC-AISVC-RES-14] Fallback should trigger on low confidence.""" + config = RoutingConfig( + direct_fallback_on_low_confidence=True, + direct_fallback_confidence_threshold=0.4, + ) + + assert config.should_fallback_direct_to_react(0.3) is True + assert config.should_fallback_direct_to_react(0.4) is False + assert config.should_fallback_direct_to_react(0.5) is False + + def test_should_trigger_react_in_auto_mode_low_confidence(self): + """[AC-AISVC-RES-12] Low confidence should trigger react in auto mode.""" + config = RoutingConfig( + react_trigger_confidence_threshold=0.6, + ) + + assert config.should_trigger_react_in_auto_mode( + confidence=0.5, complexity_score=0.3 + ) is True + assert config.should_trigger_react_in_auto_mode( + confidence=0.7, complexity_score=0.3 + ) is False + + def test_should_trigger_react_in_auto_mode_high_complexity(self): + """[AC-AISVC-RES-13] High complexity should trigger react in auto mode.""" + config = RoutingConfig( + react_trigger_complexity_score=0.5, + ) + + assert config.should_trigger_react_in_auto_mode( + confidence=0.8, complexity_score=0.6 + ) is True + assert config.should_trigger_react_in_auto_mode( + confidence=0.8, complexity_score=0.4 + ) is False + + def test_validate_valid_config(self): + """[AC-AISVC-RES-06] Valid config should pass validation.""" + config = RoutingConfig() + + is_valid, errors = config.validate() + + assert is_valid is True + assert len(errors) == 0 + + def test_validate_invalid_grayscale_percentage(self): + """[AC-AISVC-RES-06] Invalid grayscale percentage should fail validation.""" + config = RoutingConfig(grayscale_percentage=1.5) + + is_valid, errors = config.validate() + + assert is_valid is False + assert any("grayscale_percentage" in e for e in errors) + + def test_validate_invalid_confidence_threshold(self): + """[AC-AISVC-RES-06] Invalid confidence threshold should fail validation.""" + config = RoutingConfig(react_trigger_confidence_threshold=1.5) + + is_valid, errors = config.validate() + + assert is_valid is False + assert any("react_trigger_confidence_threshold" in e for e in errors) + + def test_validate_invalid_react_max_steps(self): + """[AC-AISVC-RES-06] Invalid react max steps should fail validation.""" + config = RoutingConfig(react_max_steps=2) + + is_valid, errors = config.validate() + + assert is_valid is False + assert any("react_max_steps" in e for e in errors) + + def test_validate_invalid_performance_budget(self): + """[AC-AISVC-RES-06] Invalid performance budget should fail validation.""" + config = RoutingConfig(performance_budget_ms=500) + + is_valid, errors = config.validate() + + assert is_valid is False + assert any("performance_budget_ms" in e for e in errors) + + +class TestStrategyContext: + """[AC-AISVC-RES-01~15] Tests for StrategyContext.""" + + def test_context_creation(self): + """Should create context with all fields.""" + ctx = StrategyContext( + tenant_id="tenant_a", + query="Test query", + metadata_filter={"category": "product"}, + metadata_confidence=0.8, + complexity_score=0.3, + kb_ids=["kb_1", "kb_2"], + top_k=10, + ) + + assert ctx.tenant_id == "tenant_a" + assert ctx.query == "Test query" + assert ctx.metadata_filter == {"category": "product"} + assert ctx.metadata_confidence == 0.8 + assert ctx.complexity_score == 0.3 + assert ctx.kb_ids == ["kb_1", "kb_2"] + assert ctx.top_k == 10 + + def test_context_minimal(self): + """Should create context with minimal fields.""" + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + assert ctx.tenant_id == "tenant_a" + assert ctx.query == "Test query" + assert ctx.metadata_filter is None + assert ctx.metadata_confidence == 1.0 + assert ctx.complexity_score == 0.0 + assert ctx.kb_ids is None + assert ctx.top_k == 5 + + +class TestStrategyResult: + """[AC-AISVC-RES-01,02] Tests for StrategyResult.""" + + def test_result_creation(self): + """Should create result with all fields.""" + result = StrategyResult( + strategy=StrategyType.ENHANCED, + mode=RagRuntimeMode.REACT, + should_fallback=True, + fallback_reason="Low confidence", + diagnostics={"key": "value"}, + ) + + assert result.strategy == StrategyType.ENHANCED + assert result.mode == RagRuntimeMode.REACT + assert result.should_fallback is True + assert result.fallback_reason == "Low confidence" + assert result.diagnostics == {"key": "value"} + + def test_result_defaults(self): + """Should create result with default values.""" + result = StrategyResult( + strategy=StrategyType.DEFAULT, + mode=RagRuntimeMode.AUTO, + ) + + assert result.should_fallback is False + assert result.fallback_reason is None + assert result.diagnostics == {} diff --git a/ai-service/tests/test_strategy_integration.py b/ai-service/tests/test_strategy_integration.py new file mode 100644 index 0000000..affd181 --- /dev/null +++ b/ai-service/tests/test_strategy_integration.py @@ -0,0 +1,256 @@ +""" +Unit tests for Retrieval Strategy Integration. +[AC-AISVC-RES-01~15] Tests for integrated strategy and mode routing. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.services.retrieval.routing_config import ( + RagRuntimeMode, + StrategyType, + RoutingConfig, + StrategyContext, +) +from app.services.retrieval.strategy_integration import ( + RetrievalStrategyResult, + RetrievalStrategyIntegration, + get_retrieval_strategy_integration, + reset_retrieval_strategy_integration, +) + + +class TestRetrievalStrategyResult: + """Tests for RetrievalStrategyResult.""" + + def test_result_creation(self): + """Should create result with all fields.""" + result = RetrievalStrategyResult( + retrieval_result=None, + final_answer="Test answer", + strategy=StrategyType.ENHANCED, + mode=RagRuntimeMode.REACT, + should_fallback=True, + fallback_reason="Low confidence", + diagnostics={"key": "value"}, + duration_ms=100, + ) + + assert result.retrieval_result is None + assert result.final_answer == "Test answer" + assert result.strategy == StrategyType.ENHANCED + assert result.mode == RagRuntimeMode.REACT + assert result.should_fallback is True + assert result.fallback_reason == "Low confidence" + assert result.diagnostics == {"key": "value"} + assert result.duration_ms == 100 + + def test_result_defaults(self): + """Should create result with default values.""" + result = RetrievalStrategyResult( + retrieval_result=None, + final_answer=None, + strategy=StrategyType.DEFAULT, + mode=RagRuntimeMode.DIRECT, + ) + + assert result.should_fallback is False + assert result.fallback_reason is None + assert result.mode_route_result is None + assert result.diagnostics == {} + assert result.duration_ms == 0 + + +class TestRetrievalStrategyIntegration: + """[AC-AISVC-RES-01~15] Tests for RetrievalStrategyIntegration.""" + + @pytest.fixture + def integration(self): + reset_retrieval_strategy_integration() + return RetrievalStrategyIntegration() + + def test_initial_state(self, integration): + """Should initialize with default configuration.""" + assert integration.config.strategy == StrategyType.DEFAULT + assert integration.config.rag_runtime_mode == RagRuntimeMode.AUTO + + def test_update_config(self, integration): + """[AC-AISVC-RES-15] Should update all configurations.""" + new_config = RoutingConfig( + strategy=StrategyType.ENHANCED, + rag_runtime_mode=RagRuntimeMode.REACT, + react_max_steps=7, + ) + + integration.update_config(new_config) + + assert integration.config.strategy == StrategyType.ENHANCED + assert integration.config.rag_runtime_mode == RagRuntimeMode.REACT + + def test_get_current_strategy(self, integration): + """Should return current strategy from router.""" + strategy = integration.get_current_strategy() + + assert strategy == StrategyType.DEFAULT + + def test_get_rollback_records(self, integration): + """Should return rollback records from router.""" + records = integration.get_rollback_records() + + assert isinstance(records, list) + + def test_validate_config(self, integration): + """[AC-AISVC-RES-06] Should validate configuration.""" + is_valid, errors = integration.validate_config() + + assert is_valid is True + assert len(errors) == 0 + + @pytest.mark.asyncio + async def test_execute_direct_mode(self, integration): + """[AC-AISVC-RES-09] Should execute direct mode.""" + integration._config.rag_runtime_mode = RagRuntimeMode.DIRECT + + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + mock_result = MagicMock() + mock_result.hits = [] + + mock_route_result = MagicMock() + mock_route_result.mode = RagRuntimeMode.DIRECT + + with patch.object( + integration._mode_router, "route", return_value=mock_route_result + ): + with patch.object( + integration._mode_router, "execute_with_fallback", new_callable=AsyncMock + ) as mock_execute: + mock_execute.return_value = (mock_result, None, mock_route_result) + + result = await integration.execute(ctx) + + assert result.retrieval_result == mock_result + assert result.final_answer is None + assert result.mode == RagRuntimeMode.DIRECT + + @pytest.mark.asyncio + async def test_execute_react_mode(self, integration): + """[AC-AISVC-RES-10] Should execute react mode.""" + integration._config.rag_runtime_mode = RagRuntimeMode.REACT + + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + mock_route_result = MagicMock() + mock_route_result.mode = RagRuntimeMode.REACT + + with patch.object( + integration._mode_router, "route", return_value=mock_route_result + ): + with patch.object( + integration._mode_router, "execute_react", new_callable=AsyncMock + ) as mock_execute: + mock_execute.return_value = ("Final answer", None, {}) + + result = await integration.execute(ctx) + + assert result.retrieval_result is None + assert result.final_answer == "Final answer" + assert result.mode == RagRuntimeMode.REACT + + @pytest.mark.asyncio + async def test_execute_with_fallback(self, integration): + """[AC-AISVC-RES-14] Should handle fallback from direct to react.""" + integration._config.rag_runtime_mode = RagRuntimeMode.DIRECT + integration._config.direct_fallback_on_low_confidence = True + integration._config.direct_fallback_confidence_threshold = 0.4 + + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + mock_route_result = MagicMock() + mock_route_result.mode = RagRuntimeMode.DIRECT + mock_route_result.should_fallback_to_react = True + mock_route_result.fallback_reason = "low_confidence" + + with patch.object( + integration._mode_router, "route", return_value=mock_route_result + ): + with patch.object( + integration._mode_router, "execute_with_fallback", new_callable=AsyncMock + ) as mock_execute: + mock_execute.return_value = (None, "Fallback answer", mock_route_result) + + result = await integration.execute(ctx) + + assert result.retrieval_result is None + assert result.final_answer == "Fallback answer" + assert result.should_fallback is True + + @pytest.mark.asyncio + async def test_execute_includes_diagnostics(self, integration): + """Should include diagnostics in result.""" + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + mock_result = MagicMock() + mock_result.hits = [] + + mock_route_result = MagicMock() + mock_route_result.mode = RagRuntimeMode.DIRECT + mock_route_result.diagnostics = {"mode_key": "mode_value"} + + with patch.object( + integration._mode_router, "route", return_value=mock_route_result + ): + with patch.object( + integration._mode_router, "execute_with_fallback", new_callable=AsyncMock + ) as mock_execute: + mock_execute.return_value = (mock_result, None, mock_route_result) + + result = await integration.execute(ctx) + + assert "strategy_diagnostics" in result.diagnostics + assert "mode_diagnostics" in result.diagnostics + assert "duration_ms" in result.diagnostics + + @pytest.mark.asyncio + async def test_execute_tracks_duration(self, integration): + """Should track execution duration.""" + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + mock_result = MagicMock() + mock_result.hits = [] + + mock_route_result = MagicMock() + mock_route_result.mode = RagRuntimeMode.DIRECT + + with patch.object( + integration._mode_router, "route", return_value=mock_route_result + ): + with patch.object( + integration._mode_router, "execute_with_fallback", new_callable=AsyncMock + ) as mock_execute: + mock_execute.return_value = (mock_result, None, mock_route_result) + + result = await integration.execute(ctx) + + assert result.duration_ms >= 0 + + +class TestSingletonInstances: + """Tests for singleton instance getters.""" + + def test_get_retrieval_strategy_integration_singleton(self): + """Should return same integration instance.""" + reset_retrieval_strategy_integration() + + integration1 = get_retrieval_strategy_integration() + integration2 = get_retrieval_strategy_integration() + + assert integration1 is integration2 + + def test_reset_retrieval_strategy_integration(self): + """Should create new instance after reset.""" + integration1 = get_retrieval_strategy_integration() + reset_retrieval_strategy_integration() + integration2 = get_retrieval_strategy_integration() + + assert integration1 is not integration2 diff --git a/ai-service/tests/test_strategy_router.py b/ai-service/tests/test_strategy_router.py new file mode 100644 index 0000000..850b8e9 --- /dev/null +++ b/ai-service/tests/test_strategy_router.py @@ -0,0 +1,344 @@ +""" +Unit tests for Strategy Router. +[AC-AISVC-RES-01,02,03,07,08] Tests for strategy routing, rollback, and grayscale release. +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from app.services.retrieval.routing_config import ( + RagRuntimeMode, + StrategyType, + RoutingConfig, + StrategyContext, +) +from app.services.retrieval.strategy_router import ( + RollbackRecord, + RollbackManager, + DefaultPipeline, + EnhancedPipeline, + StrategyRouter, + get_strategy_router, + reset_strategy_router, +) + + +class TestRollbackRecord: + """[AC-AISVC-RES-07] Tests for RollbackRecord.""" + + def test_rollback_record_creation(self): + """Should create rollback record with all fields.""" + record = RollbackRecord( + timestamp=1234567890.0, + from_strategy=StrategyType.ENHANCED, + to_strategy=StrategyType.DEFAULT, + reason="Performance issue", + tenant_id="tenant_a", + request_id="req_123", + ) + + assert record.timestamp == 1234567890.0 + assert record.from_strategy == StrategyType.ENHANCED + assert record.to_strategy == StrategyType.DEFAULT + assert record.reason == "Performance issue" + assert record.tenant_id == "tenant_a" + assert record.request_id == "req_123" + + +class TestRollbackManager: + """[AC-AISVC-RES-07] Tests for RollbackManager.""" + + @pytest.fixture + def manager(self): + return RollbackManager(max_records=10) + + def test_record_rollback(self, manager): + """[AC-AISVC-RES-07] Should record rollback event.""" + manager.record_rollback( + from_strategy=StrategyType.ENHANCED, + to_strategy=StrategyType.DEFAULT, + reason="Test rollback", + tenant_id="tenant_a", + ) + + records = manager.get_recent_rollbacks() + assert len(records) == 1 + assert records[0].from_strategy == StrategyType.ENHANCED + assert records[0].to_strategy == StrategyType.DEFAULT + assert records[0].reason == "Test rollback" + + def test_max_records_limit(self, manager): + """Should limit number of records.""" + for i in range(15): + manager.record_rollback( + from_strategy=StrategyType.ENHANCED, + to_strategy=StrategyType.DEFAULT, + reason=f"Reason {i}", + ) + + records = manager.get_recent_rollbacks(limit=20) + assert len(records) == 10 + + def test_get_rollback_count(self, manager): + """Should count rollbacks correctly.""" + import time + + now = time.time() + + manager.record_rollback( + from_strategy=StrategyType.ENHANCED, + to_strategy=StrategyType.DEFAULT, + reason="Reason 1", + ) + + assert manager.get_rollback_count() == 1 + assert manager.get_rollback_count(since_timestamp=now) == 1 + assert manager.get_rollback_count(since_timestamp=now + 10) == 0 + + +class TestDefaultPipeline: + """[AC-AISVC-RES-01] Tests for DefaultPipeline.""" + + @pytest.mark.asyncio + async def test_execute_uses_optimized_retriever(self): + """[AC-AISVC-RES-01] Default pipeline should use OptimizedRetriever.""" + pipeline = DefaultPipeline() + + ctx = StrategyContext( + tenant_id="tenant_a", + query="Test query", + ) + + mock_result = MagicMock() + mock_result.hits = [] + + with patch( + "app.services.retrieval.optimized_retriever.get_optimized_retriever", + new_callable=AsyncMock, + ) as mock_get_retriever: + mock_retriever = AsyncMock() + mock_retriever.retrieve.return_value = mock_result + mock_get_retriever.return_value = mock_retriever + + result = await pipeline.execute(ctx) + + assert result == mock_result + mock_retriever.retrieve.assert_called_once() + + +class TestEnhancedPipeline: + """[AC-AISVC-RES-02] Tests for EnhancedPipeline.""" + + @pytest.mark.asyncio + async def test_execute_creates_optimized_retriever(self): + """[AC-AISVC-RES-02] Enhanced pipeline should create OptimizedRetriever with enhanced config.""" + config = RoutingConfig(strategy=StrategyType.ENHANCED) + pipeline = EnhancedPipeline(config=config) + + ctx = StrategyContext( + tenant_id="tenant_a", + query="Test query", + ) + + mock_result = MagicMock() + mock_result.hits = [] + + mock_retriever = AsyncMock() + mock_retriever.retrieve.return_value = mock_result + + with patch( + "app.services.retrieval.optimized_retriever.OptimizedRetriever", + return_value=mock_retriever, + ) as mock_retriever_class: + result = await pipeline.execute(ctx) + + mock_retriever_class.assert_called_once_with( + two_stage_enabled=True, + hybrid_enabled=True, + ) + assert result == mock_result + + +class TestStrategyRouter: + """[AC-AISVC-RES-01,02,03,07,08] Tests for StrategyRouter.""" + + @pytest.fixture + def router(self): + reset_strategy_router() + return StrategyRouter() + + def test_initial_state(self, router): + """[AC-AISVC-RES-01] Initial state should be default strategy.""" + assert router.current_strategy == StrategyType.DEFAULT + assert router.config.enabled is True + + def test_route_default_strategy(self, router): + """[AC-AISVC-RES-01] Should route to default strategy by default.""" + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + result = router.route(ctx) + + assert result.strategy == StrategyType.DEFAULT + assert router.current_strategy == StrategyType.DEFAULT + + def test_route_enhanced_strategy_explicit(self, router): + """[AC-AISVC-RES-02] Should route to enhanced when explicitly configured.""" + router._config.strategy = StrategyType.ENHANCED + + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + result = router.route(ctx) + + assert result.strategy == StrategyType.ENHANCED + assert router.current_strategy == StrategyType.ENHANCED + + def test_route_grayscale_allowlist(self, router): + """[AC-AISVC-RES-03] Should route to enhanced for tenants in allowlist.""" + router._config.strategy = StrategyType.ENHANCED + router._config.grayscale_allowlist = ["tenant_a", "tenant_b"] + + ctx_a = StrategyContext(tenant_id="tenant_a", query="Test query") + ctx_c = StrategyContext(tenant_id="tenant_c", query="Test query") + + result_a = router.route(ctx_a) + result_c = router.route(ctx_c) + + assert result_a.strategy == StrategyType.ENHANCED + assert result_c.strategy == StrategyType.ENHANCED + + def test_route_disabled_strategy(self, router): + """Should use default when strategy is disabled.""" + router._strategy_enabled = False + + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + result = router.route(ctx) + + assert result.strategy == StrategyType.DEFAULT + + def test_update_config(self, router): + """[AC-AISVC-RES-15] Should update configuration.""" + new_config = RoutingConfig( + strategy=StrategyType.ENHANCED, + rag_runtime_mode=RagRuntimeMode.REACT, + ) + + router.update_config(new_config) + + assert router.config.strategy == StrategyType.ENHANCED + assert router.config.rag_runtime_mode == RagRuntimeMode.REACT + + def test_rollback(self, router): + """[AC-AISVC-RES-07] Should rollback to default strategy.""" + router._current_strategy = StrategyType.ENHANCED + router._config.strategy = StrategyType.ENHANCED + + router.rollback(reason="Test rollback", tenant_id="tenant_a") + + assert router.current_strategy == StrategyType.DEFAULT + assert router.config.strategy == StrategyType.DEFAULT + + records = router.get_rollback_records() + assert len(records) == 1 + assert records[0].reason == "Test rollback" + + def test_rollback_from_default_no_op(self, router): + """[AC-AISVC-RES-07] Rollback from default should be no-op.""" + router.rollback(reason="Test rollback") + + assert router.current_strategy == StrategyType.DEFAULT + records = router.get_rollback_records() + assert len(records) == 0 + + def test_validate_config(self, router): + """[AC-AISVC-RES-06] Should validate configuration.""" + is_valid, errors = router.validate_config() + + assert is_valid is True + assert len(errors) == 0 + + @pytest.mark.asyncio + async def test_execute_default_strategy(self, router): + """[AC-AISVC-RES-01] Should execute default pipeline.""" + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + mock_result = MagicMock() + mock_result.hits = [] + + with patch.object( + router._default_pipeline, "execute", new_callable=AsyncMock + ) as mock_execute: + mock_execute.return_value = mock_result + + result, strategy_result = await router.execute(ctx) + + assert result == mock_result + assert strategy_result.strategy == StrategyType.DEFAULT + mock_execute.assert_called_once_with(ctx) + + @pytest.mark.asyncio + async def test_execute_enhanced_strategy(self, router): + """[AC-AISVC-RES-02] Should execute enhanced pipeline.""" + router._config.strategy = StrategyType.ENHANCED + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + mock_result = MagicMock() + mock_result.hits = [] + + with patch.object( + router._enhanced_pipeline, "execute", new_callable=AsyncMock + ) as mock_execute: + mock_execute.return_value = mock_result + + result, strategy_result = await router.execute(ctx) + + assert result == mock_result + assert strategy_result.strategy == StrategyType.ENHANCED + mock_execute.assert_called_once_with(ctx) + + @pytest.mark.asyncio + async def test_execute_fallback_on_error(self, router): + """[AC-AISVC-RES-07] Should fallback to default on enhanced failure.""" + router._config.strategy = StrategyType.ENHANCED + ctx = StrategyContext(tenant_id="tenant_a", query="Test query") + + mock_result = MagicMock() + mock_result.hits = [] + + with patch.object( + router._enhanced_pipeline, "execute", new_callable=AsyncMock + ) as mock_enhanced: + mock_enhanced.side_effect = Exception("Enhanced pipeline failed") + + with patch.object( + router._default_pipeline, "execute", new_callable=AsyncMock + ) as mock_default: + mock_default.return_value = mock_result + + result, strategy_result = await router.execute(ctx) + + assert result == mock_result + assert strategy_result.strategy == StrategyType.DEFAULT + assert strategy_result.should_fallback is True + assert "Enhanced pipeline failed" in strategy_result.fallback_reason + + +class TestSingletonInstances: + """Tests for singleton instance getters.""" + + def test_get_strategy_router_singleton(self): + """Should return same router instance.""" + reset_strategy_router() + + router1 = get_strategy_router() + router2 = get_strategy_router() + + assert router1 is router2 + + def test_reset_strategy_router(self): + """Should create new instance after reset.""" + router1 = get_strategy_router() + reset_strategy_router() + router2 = get_strategy_router() + + assert router1 is not router2