[AC-AISVC-RES-01~15] test(retrieval): 新增策略路由单元测试
- 新增 test_routing_config.py 路由配置测试 - TestStrategyType: 策略类型枚举测试 - TestRagRuntimeMode: 运行模式枚举测试 - TestRoutingConfig: 路由配置测试 - TestStrategyContext: 策略上下文测试 - TestStrategyResult: 策略结果测试 - 新增 test_strategy_router.py 策略路由器测试 - TestRollbackRecord: 回滚记录测试 - TestRollbackManager: 回滚管理器测试 - TestDefaultPipeline: 默认管道测试 - TestEnhancedPipeline: 增强管道测试 - TestStrategyRouter: 策略路由器测试 - 新增 test_mode_router.py 模式路由器测试 - TestComplexityAnalyzer: 复杂度分析器测试 - TestModeRouteResult: 模式路由结果测试 - TestModeRouter: 模式路由器测试 - 新增 test_strategy_integration.py 集成层测试 - TestRetrievalStrategyResult: 集成结果测试 - TestRetrievalStrategyIntegration: 集成器测试 - 79 个测试用例全部通过
This commit is contained in:
parent
c0688c2b13
commit
4de51bb18a
|
|
@ -0,0 +1,297 @@
|
||||||
|
"""
|
||||||
|
Unit tests for Mode Router.
|
||||||
|
[AC-AISVC-RES-09,10,11,12,13,14] Tests for mode routing, complexity analysis, and fallback.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from app.services.retrieval.routing_config import (
|
||||||
|
RagRuntimeMode,
|
||||||
|
RoutingConfig,
|
||||||
|
StrategyContext,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.mode_router import (
|
||||||
|
ComplexityAnalyzer,
|
||||||
|
ModeRouteResult,
|
||||||
|
ModeRouter,
|
||||||
|
get_mode_router,
|
||||||
|
reset_mode_router,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestComplexityAnalyzer:
|
||||||
|
"""[AC-AISVC-RES-12,13] Tests for ComplexityAnalyzer."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def analyzer(self):
|
||||||
|
return ComplexityAnalyzer()
|
||||||
|
|
||||||
|
def test_analyze_empty_query(self, analyzer):
|
||||||
|
"""Empty query should have zero complexity."""
|
||||||
|
score = analyzer.analyze("")
|
||||||
|
|
||||||
|
assert score == 0.0
|
||||||
|
|
||||||
|
def test_analyze_short_query(self, analyzer):
|
||||||
|
"""Short query should have low complexity."""
|
||||||
|
score = analyzer.analyze("简单问题")
|
||||||
|
|
||||||
|
assert score < 0.3
|
||||||
|
|
||||||
|
def test_analyze_long_query(self, analyzer):
|
||||||
|
"""Long query should have higher complexity."""
|
||||||
|
long_query = "这是一个很长的问题" * 20
|
||||||
|
score = analyzer.analyze(long_query)
|
||||||
|
|
||||||
|
assert score >= 0.3
|
||||||
|
|
||||||
|
def test_analyze_multiple_conditions(self, analyzer):
|
||||||
|
"""[AC-AISVC-RES-13] Query with multiple conditions should have high complexity."""
|
||||||
|
query = "查询订单状态和物流信息以及退款进度"
|
||||||
|
score = analyzer.analyze(query)
|
||||||
|
|
||||||
|
assert score >= 0.2
|
||||||
|
|
||||||
|
def test_analyze_reasoning_patterns(self, analyzer):
|
||||||
|
"""Query with reasoning patterns should have higher complexity."""
|
||||||
|
query = "为什么订单会被取消?原因是什么?"
|
||||||
|
score = analyzer.analyze(query)
|
||||||
|
|
||||||
|
assert score >= 0.15
|
||||||
|
|
||||||
|
def test_analyze_cross_domain_patterns(self, analyzer):
|
||||||
|
"""Cross-domain query should have higher complexity."""
|
||||||
|
query = "汇总所有部门的销售数据"
|
||||||
|
score = analyzer.analyze(query)
|
||||||
|
|
||||||
|
assert score >= 0.15
|
||||||
|
|
||||||
|
def test_analyze_multiple_questions(self, analyzer):
|
||||||
|
"""Multiple questions should increase complexity."""
|
||||||
|
query = "什么是价格?如何购买?"
|
||||||
|
score = analyzer.analyze(query)
|
||||||
|
|
||||||
|
assert score >= 0.1
|
||||||
|
|
||||||
|
|
||||||
|
class TestModeRouteResult:
|
||||||
|
"""Tests for ModeRouteResult."""
|
||||||
|
|
||||||
|
def test_result_creation(self):
|
||||||
|
"""Should create result with all fields."""
|
||||||
|
result = ModeRouteResult(
|
||||||
|
mode=RagRuntimeMode.REACT,
|
||||||
|
confidence=0.5,
|
||||||
|
complexity_score=0.7,
|
||||||
|
should_fallback_to_react=True,
|
||||||
|
fallback_reason="Low confidence",
|
||||||
|
diagnostics={"key": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.mode == RagRuntimeMode.REACT
|
||||||
|
assert result.confidence == 0.5
|
||||||
|
assert result.complexity_score == 0.7
|
||||||
|
assert result.should_fallback_to_react is True
|
||||||
|
assert result.fallback_reason == "Low confidence"
|
||||||
|
assert result.diagnostics == {"key": "value"}
|
||||||
|
|
||||||
|
def test_result_defaults(self):
|
||||||
|
"""Should create result with default values."""
|
||||||
|
result = ModeRouteResult(
|
||||||
|
mode=RagRuntimeMode.DIRECT,
|
||||||
|
confidence=0.8,
|
||||||
|
complexity_score=0.2,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.should_fallback_to_react is False
|
||||||
|
assert result.fallback_reason is None
|
||||||
|
assert result.diagnostics == {}
|
||||||
|
|
||||||
|
|
||||||
|
class TestModeRouter:
|
||||||
|
"""[AC-AISVC-RES-09,10,11,12,13,14] Tests for ModeRouter."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def router(self):
|
||||||
|
reset_mode_router()
|
||||||
|
return ModeRouter()
|
||||||
|
|
||||||
|
def test_initial_config(self, router):
|
||||||
|
"""Should initialize with default configuration."""
|
||||||
|
assert router.config.rag_runtime_mode == RagRuntimeMode.AUTO
|
||||||
|
|
||||||
|
def test_route_direct_mode(self, router):
|
||||||
|
"""[AC-AISVC-RES-09] Should route to direct mode when configured."""
|
||||||
|
router._config.rag_runtime_mode = RagRuntimeMode.DIRECT
|
||||||
|
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
result = router.route(ctx)
|
||||||
|
|
||||||
|
assert result.mode == RagRuntimeMode.DIRECT
|
||||||
|
|
||||||
|
def test_route_react_mode(self, router):
|
||||||
|
"""[AC-AISVC-RES-10] Should route to react mode when configured."""
|
||||||
|
router._config.rag_runtime_mode = RagRuntimeMode.REACT
|
||||||
|
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
result = router.route(ctx)
|
||||||
|
|
||||||
|
assert result.mode == RagRuntimeMode.REACT
|
||||||
|
|
||||||
|
def test_route_auto_mode_direct_conditions(self, router):
|
||||||
|
"""[AC-AISVC-RES-11,12] Auto mode should select direct for simple queries."""
|
||||||
|
router._config.rag_runtime_mode = RagRuntimeMode.AUTO
|
||||||
|
|
||||||
|
ctx = StrategyContext(
|
||||||
|
tenant_id="tenant_a",
|
||||||
|
query="简单问题",
|
||||||
|
metadata_confidence=0.9,
|
||||||
|
complexity_score=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = router.route(ctx)
|
||||||
|
|
||||||
|
assert result.mode == RagRuntimeMode.DIRECT
|
||||||
|
|
||||||
|
def test_route_auto_mode_react_low_confidence(self, router):
|
||||||
|
"""[AC-AISVC-RES-11,12] Auto mode should select react for low confidence."""
|
||||||
|
router._config.rag_runtime_mode = RagRuntimeMode.AUTO
|
||||||
|
router._config.react_trigger_confidence_threshold = 0.6
|
||||||
|
|
||||||
|
ctx = StrategyContext(
|
||||||
|
tenant_id="tenant_a",
|
||||||
|
query="Test query",
|
||||||
|
metadata_confidence=0.4,
|
||||||
|
complexity_score=0.2,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = router.route(ctx)
|
||||||
|
|
||||||
|
assert result.mode == RagRuntimeMode.REACT
|
||||||
|
|
||||||
|
def test_route_auto_mode_react_high_complexity(self, router):
|
||||||
|
"""[AC-AISVC-RES-11,13] Auto mode should select react for high complexity."""
|
||||||
|
router._config.rag_runtime_mode = RagRuntimeMode.AUTO
|
||||||
|
router._config.react_trigger_complexity_score = 0.5
|
||||||
|
|
||||||
|
ctx = StrategyContext(
|
||||||
|
tenant_id="tenant_a",
|
||||||
|
query="Test query",
|
||||||
|
metadata_confidence=0.8,
|
||||||
|
complexity_score=0.7,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = router.route(ctx)
|
||||||
|
|
||||||
|
assert result.mode == RagRuntimeMode.REACT
|
||||||
|
|
||||||
|
def test_update_config(self, router):
|
||||||
|
"""[AC-AISVC-RES-15] Should update configuration."""
|
||||||
|
new_config = RoutingConfig(
|
||||||
|
rag_runtime_mode=RagRuntimeMode.REACT,
|
||||||
|
react_max_steps=7,
|
||||||
|
)
|
||||||
|
|
||||||
|
router.update_config(new_config)
|
||||||
|
|
||||||
|
assert router.config.rag_runtime_mode == RagRuntimeMode.REACT
|
||||||
|
assert router.config.react_max_steps == 7
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_direct(self, router):
|
||||||
|
"""[AC-AISVC-RES-09] Should execute direct retrieval."""
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.hits = []
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
router._direct_executor, "execute", new_callable=AsyncMock
|
||||||
|
) as mock_execute:
|
||||||
|
mock_execute.return_value = mock_result
|
||||||
|
|
||||||
|
result = await router.execute_direct(ctx)
|
||||||
|
|
||||||
|
assert result == mock_result
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_with_fallback_no_fallback(self, router):
|
||||||
|
"""[AC-AISVC-RES-14] Should not fallback when confidence is high."""
|
||||||
|
router._config.rag_runtime_mode = RagRuntimeMode.DIRECT
|
||||||
|
router._config.direct_fallback_on_low_confidence = True
|
||||||
|
router._config.direct_fallback_confidence_threshold = 0.4
|
||||||
|
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_hit = MagicMock()
|
||||||
|
mock_hit.score = 0.8
|
||||||
|
mock_result.hits = [mock_hit]
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
router._direct_executor, "execute", new_callable=AsyncMock
|
||||||
|
) as mock_execute:
|
||||||
|
mock_execute.return_value = mock_result
|
||||||
|
|
||||||
|
result, answer, route_result = await router.execute_with_fallback(ctx)
|
||||||
|
|
||||||
|
assert result == mock_result
|
||||||
|
assert answer is None
|
||||||
|
assert route_result.mode == RagRuntimeMode.DIRECT
|
||||||
|
assert route_result.should_fallback_to_react is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_with_fallback_triggered(self, router):
|
||||||
|
"""[AC-AISVC-RES-14] Should fallback to react when confidence is low."""
|
||||||
|
router._config.rag_runtime_mode = RagRuntimeMode.DIRECT
|
||||||
|
router._config.direct_fallback_on_low_confidence = True
|
||||||
|
router._config.direct_fallback_confidence_threshold = 0.4
|
||||||
|
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_hit = MagicMock()
|
||||||
|
mock_hit.score = 0.2
|
||||||
|
mock_result.hits = [mock_hit]
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
router._direct_executor, "execute", new_callable=AsyncMock
|
||||||
|
) as mock_direct:
|
||||||
|
mock_direct.return_value = mock_result
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
router._react_executor, "execute", new_callable=AsyncMock
|
||||||
|
) as mock_react:
|
||||||
|
mock_react.return_value = ("Final answer", None, {})
|
||||||
|
|
||||||
|
result, answer, route_result = await router.execute_with_fallback(ctx)
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
assert answer == "Final answer"
|
||||||
|
assert route_result.mode == RagRuntimeMode.REACT
|
||||||
|
assert route_result.should_fallback_to_react is True
|
||||||
|
assert route_result.fallback_reason == "low_confidence"
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingletonInstances:
|
||||||
|
"""Tests for singleton instance getters."""
|
||||||
|
|
||||||
|
def test_get_mode_router_singleton(self):
|
||||||
|
"""Should return same router instance."""
|
||||||
|
reset_mode_router()
|
||||||
|
|
||||||
|
router1 = get_mode_router()
|
||||||
|
router2 = get_mode_router()
|
||||||
|
|
||||||
|
assert router1 is router2
|
||||||
|
|
||||||
|
def test_reset_mode_router(self):
|
||||||
|
"""Should create new instance after reset."""
|
||||||
|
router1 = get_mode_router()
|
||||||
|
reset_mode_router()
|
||||||
|
router2 = get_mode_router()
|
||||||
|
|
||||||
|
assert router1 is not router2
|
||||||
|
|
@ -0,0 +1,253 @@
|
||||||
|
"""
|
||||||
|
Unit tests for Routing Configuration.
|
||||||
|
[AC-AISVC-RES-01~15] Tests for strategy routing configuration models.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.services.retrieval.routing_config import (
|
||||||
|
RagRuntimeMode,
|
||||||
|
StrategyType,
|
||||||
|
RoutingConfig,
|
||||||
|
StrategyContext,
|
||||||
|
StrategyResult,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStrategyType:
|
||||||
|
"""[AC-AISVC-RES-01,02] Tests for StrategyType enum."""
|
||||||
|
|
||||||
|
def test_strategy_type_default(self):
|
||||||
|
"""[AC-AISVC-RES-01] Default strategy type should exist."""
|
||||||
|
assert StrategyType.DEFAULT.value == "default"
|
||||||
|
|
||||||
|
def test_strategy_type_enhanced(self):
|
||||||
|
"""[AC-AISVC-RES-02] Enhanced strategy type should exist."""
|
||||||
|
assert StrategyType.ENHANCED.value == "enhanced"
|
||||||
|
|
||||||
|
def test_strategy_type_from_string(self):
|
||||||
|
"""Should create StrategyType from string."""
|
||||||
|
assert StrategyType("default") == StrategyType.DEFAULT
|
||||||
|
assert StrategyType("enhanced") == StrategyType.ENHANCED
|
||||||
|
|
||||||
|
|
||||||
|
class TestRagRuntimeMode:
|
||||||
|
"""[AC-AISVC-RES-09,10,11] Tests for RagRuntimeMode enum."""
|
||||||
|
|
||||||
|
def test_mode_direct(self):
|
||||||
|
"""[AC-AISVC-RES-09] Direct mode should exist."""
|
||||||
|
assert RagRuntimeMode.DIRECT.value == "direct"
|
||||||
|
|
||||||
|
def test_mode_react(self):
|
||||||
|
"""[AC-AISVC-RES-10] React mode should exist."""
|
||||||
|
assert RagRuntimeMode.REACT.value == "react"
|
||||||
|
|
||||||
|
def test_mode_auto(self):
|
||||||
|
"""[AC-AISVC-RES-11] Auto mode should exist."""
|
||||||
|
assert RagRuntimeMode.AUTO.value == "auto"
|
||||||
|
|
||||||
|
def test_mode_from_string(self):
|
||||||
|
"""Should create RagRuntimeMode from string."""
|
||||||
|
assert RagRuntimeMode("direct") == RagRuntimeMode.DIRECT
|
||||||
|
assert RagRuntimeMode("react") == RagRuntimeMode.REACT
|
||||||
|
assert RagRuntimeMode("auto") == RagRuntimeMode.AUTO
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoutingConfig:
|
||||||
|
"""[AC-AISVC-RES-01~15] Tests for RoutingConfig."""
|
||||||
|
|
||||||
|
def test_default_config(self):
|
||||||
|
"""[AC-AISVC-RES-01] Default config should use default strategy and auto mode."""
|
||||||
|
config = RoutingConfig()
|
||||||
|
|
||||||
|
assert config.enabled is True
|
||||||
|
assert config.strategy == StrategyType.DEFAULT
|
||||||
|
assert config.rag_runtime_mode == RagRuntimeMode.AUTO
|
||||||
|
assert config.grayscale_percentage == 0.0
|
||||||
|
assert config.grayscale_allowlist == []
|
||||||
|
|
||||||
|
def test_config_with_custom_values(self):
|
||||||
|
"""[AC-AISVC-RES-02,03] Config should accept custom values."""
|
||||||
|
config = RoutingConfig(
|
||||||
|
enabled=False,
|
||||||
|
strategy=StrategyType.ENHANCED,
|
||||||
|
rag_runtime_mode=RagRuntimeMode.REACT,
|
||||||
|
grayscale_percentage=0.3,
|
||||||
|
grayscale_allowlist=["tenant_a", "tenant_b"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.enabled is False
|
||||||
|
assert config.strategy == StrategyType.ENHANCED
|
||||||
|
assert config.rag_runtime_mode == RagRuntimeMode.REACT
|
||||||
|
assert config.grayscale_percentage == 0.3
|
||||||
|
assert "tenant_a" in config.grayscale_allowlist
|
||||||
|
|
||||||
|
def test_should_use_enhanced_strategy_default(self):
|
||||||
|
"""[AC-AISVC-RES-01] Default strategy should not use enhanced."""
|
||||||
|
config = RoutingConfig()
|
||||||
|
|
||||||
|
assert config.should_use_enhanced_strategy("tenant_a") is False
|
||||||
|
|
||||||
|
def test_should_use_enhanced_strategy_explicit_enhanced(self):
|
||||||
|
"""[AC-AISVC-RES-02] Explicit enhanced strategy should use enhanced for all tenants."""
|
||||||
|
config = RoutingConfig(strategy=StrategyType.ENHANCED)
|
||||||
|
|
||||||
|
assert config.should_use_enhanced_strategy("tenant_a") is True
|
||||||
|
assert config.should_use_enhanced_strategy("tenant_b") is True
|
||||||
|
assert config.should_use_enhanced_strategy(None) is True
|
||||||
|
|
||||||
|
def test_should_fallback_direct_to_react_disabled(self):
|
||||||
|
"""[AC-AISVC-RES-14] Fallback should be disabled when configured."""
|
||||||
|
config = RoutingConfig(direct_fallback_on_low_confidence=False)
|
||||||
|
|
||||||
|
assert config.should_fallback_direct_to_react(0.1) is False
|
||||||
|
assert config.should_fallback_direct_to_react(0.0) is False
|
||||||
|
|
||||||
|
def test_should_fallback_direct_to_react_enabled(self):
|
||||||
|
"""[AC-AISVC-RES-14] Fallback should trigger on low confidence."""
|
||||||
|
config = RoutingConfig(
|
||||||
|
direct_fallback_on_low_confidence=True,
|
||||||
|
direct_fallback_confidence_threshold=0.4,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.should_fallback_direct_to_react(0.3) is True
|
||||||
|
assert config.should_fallback_direct_to_react(0.4) is False
|
||||||
|
assert config.should_fallback_direct_to_react(0.5) is False
|
||||||
|
|
||||||
|
def test_should_trigger_react_in_auto_mode_low_confidence(self):
|
||||||
|
"""[AC-AISVC-RES-12] Low confidence should trigger react in auto mode."""
|
||||||
|
config = RoutingConfig(
|
||||||
|
react_trigger_confidence_threshold=0.6,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.should_trigger_react_in_auto_mode(
|
||||||
|
confidence=0.5, complexity_score=0.3
|
||||||
|
) is True
|
||||||
|
assert config.should_trigger_react_in_auto_mode(
|
||||||
|
confidence=0.7, complexity_score=0.3
|
||||||
|
) is False
|
||||||
|
|
||||||
|
def test_should_trigger_react_in_auto_mode_high_complexity(self):
|
||||||
|
"""[AC-AISVC-RES-13] High complexity should trigger react in auto mode."""
|
||||||
|
config = RoutingConfig(
|
||||||
|
react_trigger_complexity_score=0.5,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert config.should_trigger_react_in_auto_mode(
|
||||||
|
confidence=0.8, complexity_score=0.6
|
||||||
|
) is True
|
||||||
|
assert config.should_trigger_react_in_auto_mode(
|
||||||
|
confidence=0.8, complexity_score=0.4
|
||||||
|
) is False
|
||||||
|
|
||||||
|
def test_validate_valid_config(self):
|
||||||
|
"""[AC-AISVC-RES-06] Valid config should pass validation."""
|
||||||
|
config = RoutingConfig()
|
||||||
|
|
||||||
|
is_valid, errors = config.validate()
|
||||||
|
|
||||||
|
assert is_valid is True
|
||||||
|
assert len(errors) == 0
|
||||||
|
|
||||||
|
def test_validate_invalid_grayscale_percentage(self):
|
||||||
|
"""[AC-AISVC-RES-06] Invalid grayscale percentage should fail validation."""
|
||||||
|
config = RoutingConfig(grayscale_percentage=1.5)
|
||||||
|
|
||||||
|
is_valid, errors = config.validate()
|
||||||
|
|
||||||
|
assert is_valid is False
|
||||||
|
assert any("grayscale_percentage" in e for e in errors)
|
||||||
|
|
||||||
|
def test_validate_invalid_confidence_threshold(self):
|
||||||
|
"""[AC-AISVC-RES-06] Invalid confidence threshold should fail validation."""
|
||||||
|
config = RoutingConfig(react_trigger_confidence_threshold=1.5)
|
||||||
|
|
||||||
|
is_valid, errors = config.validate()
|
||||||
|
|
||||||
|
assert is_valid is False
|
||||||
|
assert any("react_trigger_confidence_threshold" in e for e in errors)
|
||||||
|
|
||||||
|
def test_validate_invalid_react_max_steps(self):
|
||||||
|
"""[AC-AISVC-RES-06] Invalid react max steps should fail validation."""
|
||||||
|
config = RoutingConfig(react_max_steps=2)
|
||||||
|
|
||||||
|
is_valid, errors = config.validate()
|
||||||
|
|
||||||
|
assert is_valid is False
|
||||||
|
assert any("react_max_steps" in e for e in errors)
|
||||||
|
|
||||||
|
def test_validate_invalid_performance_budget(self):
|
||||||
|
"""[AC-AISVC-RES-06] Invalid performance budget should fail validation."""
|
||||||
|
config = RoutingConfig(performance_budget_ms=500)
|
||||||
|
|
||||||
|
is_valid, errors = config.validate()
|
||||||
|
|
||||||
|
assert is_valid is False
|
||||||
|
assert any("performance_budget_ms" in e for e in errors)
|
||||||
|
|
||||||
|
|
||||||
|
class TestStrategyContext:
|
||||||
|
"""[AC-AISVC-RES-01~15] Tests for StrategyContext."""
|
||||||
|
|
||||||
|
def test_context_creation(self):
|
||||||
|
"""Should create context with all fields."""
|
||||||
|
ctx = StrategyContext(
|
||||||
|
tenant_id="tenant_a",
|
||||||
|
query="Test query",
|
||||||
|
metadata_filter={"category": "product"},
|
||||||
|
metadata_confidence=0.8,
|
||||||
|
complexity_score=0.3,
|
||||||
|
kb_ids=["kb_1", "kb_2"],
|
||||||
|
top_k=10,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert ctx.tenant_id == "tenant_a"
|
||||||
|
assert ctx.query == "Test query"
|
||||||
|
assert ctx.metadata_filter == {"category": "product"}
|
||||||
|
assert ctx.metadata_confidence == 0.8
|
||||||
|
assert ctx.complexity_score == 0.3
|
||||||
|
assert ctx.kb_ids == ["kb_1", "kb_2"]
|
||||||
|
assert ctx.top_k == 10
|
||||||
|
|
||||||
|
def test_context_minimal(self):
|
||||||
|
"""Should create context with minimal fields."""
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
assert ctx.tenant_id == "tenant_a"
|
||||||
|
assert ctx.query == "Test query"
|
||||||
|
assert ctx.metadata_filter is None
|
||||||
|
assert ctx.metadata_confidence == 1.0
|
||||||
|
assert ctx.complexity_score == 0.0
|
||||||
|
assert ctx.kb_ids is None
|
||||||
|
assert ctx.top_k == 5
|
||||||
|
|
||||||
|
|
||||||
|
class TestStrategyResult:
|
||||||
|
"""[AC-AISVC-RES-01,02] Tests for StrategyResult."""
|
||||||
|
|
||||||
|
def test_result_creation(self):
|
||||||
|
"""Should create result with all fields."""
|
||||||
|
result = StrategyResult(
|
||||||
|
strategy=StrategyType.ENHANCED,
|
||||||
|
mode=RagRuntimeMode.REACT,
|
||||||
|
should_fallback=True,
|
||||||
|
fallback_reason="Low confidence",
|
||||||
|
diagnostics={"key": "value"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.strategy == StrategyType.ENHANCED
|
||||||
|
assert result.mode == RagRuntimeMode.REACT
|
||||||
|
assert result.should_fallback is True
|
||||||
|
assert result.fallback_reason == "Low confidence"
|
||||||
|
assert result.diagnostics == {"key": "value"}
|
||||||
|
|
||||||
|
def test_result_defaults(self):
|
||||||
|
"""Should create result with default values."""
|
||||||
|
result = StrategyResult(
|
||||||
|
strategy=StrategyType.DEFAULT,
|
||||||
|
mode=RagRuntimeMode.AUTO,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.should_fallback is False
|
||||||
|
assert result.fallback_reason is None
|
||||||
|
assert result.diagnostics == {}
|
||||||
|
|
@ -0,0 +1,256 @@
|
||||||
|
"""
|
||||||
|
Unit tests for Retrieval Strategy Integration.
|
||||||
|
[AC-AISVC-RES-01~15] Tests for integrated strategy and mode routing.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from app.services.retrieval.routing_config import (
|
||||||
|
RagRuntimeMode,
|
||||||
|
StrategyType,
|
||||||
|
RoutingConfig,
|
||||||
|
StrategyContext,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy_integration import (
|
||||||
|
RetrievalStrategyResult,
|
||||||
|
RetrievalStrategyIntegration,
|
||||||
|
get_retrieval_strategy_integration,
|
||||||
|
reset_retrieval_strategy_integration,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetrievalStrategyResult:
|
||||||
|
"""Tests for RetrievalStrategyResult."""
|
||||||
|
|
||||||
|
def test_result_creation(self):
|
||||||
|
"""Should create result with all fields."""
|
||||||
|
result = RetrievalStrategyResult(
|
||||||
|
retrieval_result=None,
|
||||||
|
final_answer="Test answer",
|
||||||
|
strategy=StrategyType.ENHANCED,
|
||||||
|
mode=RagRuntimeMode.REACT,
|
||||||
|
should_fallback=True,
|
||||||
|
fallback_reason="Low confidence",
|
||||||
|
diagnostics={"key": "value"},
|
||||||
|
duration_ms=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.retrieval_result is None
|
||||||
|
assert result.final_answer == "Test answer"
|
||||||
|
assert result.strategy == StrategyType.ENHANCED
|
||||||
|
assert result.mode == RagRuntimeMode.REACT
|
||||||
|
assert result.should_fallback is True
|
||||||
|
assert result.fallback_reason == "Low confidence"
|
||||||
|
assert result.diagnostics == {"key": "value"}
|
||||||
|
assert result.duration_ms == 100
|
||||||
|
|
||||||
|
def test_result_defaults(self):
|
||||||
|
"""Should create result with default values."""
|
||||||
|
result = RetrievalStrategyResult(
|
||||||
|
retrieval_result=None,
|
||||||
|
final_answer=None,
|
||||||
|
strategy=StrategyType.DEFAULT,
|
||||||
|
mode=RagRuntimeMode.DIRECT,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result.should_fallback is False
|
||||||
|
assert result.fallback_reason is None
|
||||||
|
assert result.mode_route_result is None
|
||||||
|
assert result.diagnostics == {}
|
||||||
|
assert result.duration_ms == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestRetrievalStrategyIntegration:
|
||||||
|
"""[AC-AISVC-RES-01~15] Tests for RetrievalStrategyIntegration."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def integration(self):
|
||||||
|
reset_retrieval_strategy_integration()
|
||||||
|
return RetrievalStrategyIntegration()
|
||||||
|
|
||||||
|
def test_initial_state(self, integration):
|
||||||
|
"""Should initialize with default configuration."""
|
||||||
|
assert integration.config.strategy == StrategyType.DEFAULT
|
||||||
|
assert integration.config.rag_runtime_mode == RagRuntimeMode.AUTO
|
||||||
|
|
||||||
|
def test_update_config(self, integration):
|
||||||
|
"""[AC-AISVC-RES-15] Should update all configurations."""
|
||||||
|
new_config = RoutingConfig(
|
||||||
|
strategy=StrategyType.ENHANCED,
|
||||||
|
rag_runtime_mode=RagRuntimeMode.REACT,
|
||||||
|
react_max_steps=7,
|
||||||
|
)
|
||||||
|
|
||||||
|
integration.update_config(new_config)
|
||||||
|
|
||||||
|
assert integration.config.strategy == StrategyType.ENHANCED
|
||||||
|
assert integration.config.rag_runtime_mode == RagRuntimeMode.REACT
|
||||||
|
|
||||||
|
def test_get_current_strategy(self, integration):
|
||||||
|
"""Should return current strategy from router."""
|
||||||
|
strategy = integration.get_current_strategy()
|
||||||
|
|
||||||
|
assert strategy == StrategyType.DEFAULT
|
||||||
|
|
||||||
|
def test_get_rollback_records(self, integration):
|
||||||
|
"""Should return rollback records from router."""
|
||||||
|
records = integration.get_rollback_records()
|
||||||
|
|
||||||
|
assert isinstance(records, list)
|
||||||
|
|
||||||
|
def test_validate_config(self, integration):
|
||||||
|
"""[AC-AISVC-RES-06] Should validate configuration."""
|
||||||
|
is_valid, errors = integration.validate_config()
|
||||||
|
|
||||||
|
assert is_valid is True
|
||||||
|
assert len(errors) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_direct_mode(self, integration):
|
||||||
|
"""[AC-AISVC-RES-09] Should execute direct mode."""
|
||||||
|
integration._config.rag_runtime_mode = RagRuntimeMode.DIRECT
|
||||||
|
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.hits = []
|
||||||
|
|
||||||
|
mock_route_result = MagicMock()
|
||||||
|
mock_route_result.mode = RagRuntimeMode.DIRECT
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
integration._mode_router, "route", return_value=mock_route_result
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
integration._mode_router, "execute_with_fallback", new_callable=AsyncMock
|
||||||
|
) as mock_execute:
|
||||||
|
mock_execute.return_value = (mock_result, None, mock_route_result)
|
||||||
|
|
||||||
|
result = await integration.execute(ctx)
|
||||||
|
|
||||||
|
assert result.retrieval_result == mock_result
|
||||||
|
assert result.final_answer is None
|
||||||
|
assert result.mode == RagRuntimeMode.DIRECT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_react_mode(self, integration):
|
||||||
|
"""[AC-AISVC-RES-10] Should execute react mode."""
|
||||||
|
integration._config.rag_runtime_mode = RagRuntimeMode.REACT
|
||||||
|
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
mock_route_result = MagicMock()
|
||||||
|
mock_route_result.mode = RagRuntimeMode.REACT
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
integration._mode_router, "route", return_value=mock_route_result
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
integration._mode_router, "execute_react", new_callable=AsyncMock
|
||||||
|
) as mock_execute:
|
||||||
|
mock_execute.return_value = ("Final answer", None, {})
|
||||||
|
|
||||||
|
result = await integration.execute(ctx)
|
||||||
|
|
||||||
|
assert result.retrieval_result is None
|
||||||
|
assert result.final_answer == "Final answer"
|
||||||
|
assert result.mode == RagRuntimeMode.REACT
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_with_fallback(self, integration):
|
||||||
|
"""[AC-AISVC-RES-14] Should handle fallback from direct to react."""
|
||||||
|
integration._config.rag_runtime_mode = RagRuntimeMode.DIRECT
|
||||||
|
integration._config.direct_fallback_on_low_confidence = True
|
||||||
|
integration._config.direct_fallback_confidence_threshold = 0.4
|
||||||
|
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
mock_route_result = MagicMock()
|
||||||
|
mock_route_result.mode = RagRuntimeMode.DIRECT
|
||||||
|
mock_route_result.should_fallback_to_react = True
|
||||||
|
mock_route_result.fallback_reason = "low_confidence"
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
integration._mode_router, "route", return_value=mock_route_result
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
integration._mode_router, "execute_with_fallback", new_callable=AsyncMock
|
||||||
|
) as mock_execute:
|
||||||
|
mock_execute.return_value = (None, "Fallback answer", mock_route_result)
|
||||||
|
|
||||||
|
result = await integration.execute(ctx)
|
||||||
|
|
||||||
|
assert result.retrieval_result is None
|
||||||
|
assert result.final_answer == "Fallback answer"
|
||||||
|
assert result.should_fallback is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_includes_diagnostics(self, integration):
|
||||||
|
"""Should include diagnostics in result."""
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.hits = []
|
||||||
|
|
||||||
|
mock_route_result = MagicMock()
|
||||||
|
mock_route_result.mode = RagRuntimeMode.DIRECT
|
||||||
|
mock_route_result.diagnostics = {"mode_key": "mode_value"}
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
integration._mode_router, "route", return_value=mock_route_result
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
integration._mode_router, "execute_with_fallback", new_callable=AsyncMock
|
||||||
|
) as mock_execute:
|
||||||
|
mock_execute.return_value = (mock_result, None, mock_route_result)
|
||||||
|
|
||||||
|
result = await integration.execute(ctx)
|
||||||
|
|
||||||
|
assert "strategy_diagnostics" in result.diagnostics
|
||||||
|
assert "mode_diagnostics" in result.diagnostics
|
||||||
|
assert "duration_ms" in result.diagnostics
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_tracks_duration(self, integration):
|
||||||
|
"""Should track execution duration."""
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.hits = []
|
||||||
|
|
||||||
|
mock_route_result = MagicMock()
|
||||||
|
mock_route_result.mode = RagRuntimeMode.DIRECT
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
integration._mode_router, "route", return_value=mock_route_result
|
||||||
|
):
|
||||||
|
with patch.object(
|
||||||
|
integration._mode_router, "execute_with_fallback", new_callable=AsyncMock
|
||||||
|
) as mock_execute:
|
||||||
|
mock_execute.return_value = (mock_result, None, mock_route_result)
|
||||||
|
|
||||||
|
result = await integration.execute(ctx)
|
||||||
|
|
||||||
|
assert result.duration_ms >= 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingletonInstances:
|
||||||
|
"""Tests for singleton instance getters."""
|
||||||
|
|
||||||
|
def test_get_retrieval_strategy_integration_singleton(self):
|
||||||
|
"""Should return same integration instance."""
|
||||||
|
reset_retrieval_strategy_integration()
|
||||||
|
|
||||||
|
integration1 = get_retrieval_strategy_integration()
|
||||||
|
integration2 = get_retrieval_strategy_integration()
|
||||||
|
|
||||||
|
assert integration1 is integration2
|
||||||
|
|
||||||
|
def test_reset_retrieval_strategy_integration(self):
|
||||||
|
"""Should create new instance after reset."""
|
||||||
|
integration1 = get_retrieval_strategy_integration()
|
||||||
|
reset_retrieval_strategy_integration()
|
||||||
|
integration2 = get_retrieval_strategy_integration()
|
||||||
|
|
||||||
|
assert integration1 is not integration2
|
||||||
|
|
@ -0,0 +1,344 @@
|
||||||
|
"""
|
||||||
|
Unit tests for Strategy Router.
|
||||||
|
[AC-AISVC-RES-01,02,03,07,08] Tests for strategy routing, rollback, and grayscale release.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
from app.services.retrieval.routing_config import (
|
||||||
|
RagRuntimeMode,
|
||||||
|
StrategyType,
|
||||||
|
RoutingConfig,
|
||||||
|
StrategyContext,
|
||||||
|
)
|
||||||
|
from app.services.retrieval.strategy_router import (
|
||||||
|
RollbackRecord,
|
||||||
|
RollbackManager,
|
||||||
|
DefaultPipeline,
|
||||||
|
EnhancedPipeline,
|
||||||
|
StrategyRouter,
|
||||||
|
get_strategy_router,
|
||||||
|
reset_strategy_router,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestRollbackRecord:
|
||||||
|
"""[AC-AISVC-RES-07] Tests for RollbackRecord."""
|
||||||
|
|
||||||
|
def test_rollback_record_creation(self):
|
||||||
|
"""Should create rollback record with all fields."""
|
||||||
|
record = RollbackRecord(
|
||||||
|
timestamp=1234567890.0,
|
||||||
|
from_strategy=StrategyType.ENHANCED,
|
||||||
|
to_strategy=StrategyType.DEFAULT,
|
||||||
|
reason="Performance issue",
|
||||||
|
tenant_id="tenant_a",
|
||||||
|
request_id="req_123",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert record.timestamp == 1234567890.0
|
||||||
|
assert record.from_strategy == StrategyType.ENHANCED
|
||||||
|
assert record.to_strategy == StrategyType.DEFAULT
|
||||||
|
assert record.reason == "Performance issue"
|
||||||
|
assert record.tenant_id == "tenant_a"
|
||||||
|
assert record.request_id == "req_123"
|
||||||
|
|
||||||
|
|
||||||
|
class TestRollbackManager:
|
||||||
|
"""[AC-AISVC-RES-07] Tests for RollbackManager."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def manager(self):
|
||||||
|
return RollbackManager(max_records=10)
|
||||||
|
|
||||||
|
def test_record_rollback(self, manager):
|
||||||
|
"""[AC-AISVC-RES-07] Should record rollback event."""
|
||||||
|
manager.record_rollback(
|
||||||
|
from_strategy=StrategyType.ENHANCED,
|
||||||
|
to_strategy=StrategyType.DEFAULT,
|
||||||
|
reason="Test rollback",
|
||||||
|
tenant_id="tenant_a",
|
||||||
|
)
|
||||||
|
|
||||||
|
records = manager.get_recent_rollbacks()
|
||||||
|
assert len(records) == 1
|
||||||
|
assert records[0].from_strategy == StrategyType.ENHANCED
|
||||||
|
assert records[0].to_strategy == StrategyType.DEFAULT
|
||||||
|
assert records[0].reason == "Test rollback"
|
||||||
|
|
||||||
|
def test_max_records_limit(self, manager):
|
||||||
|
"""Should limit number of records."""
|
||||||
|
for i in range(15):
|
||||||
|
manager.record_rollback(
|
||||||
|
from_strategy=StrategyType.ENHANCED,
|
||||||
|
to_strategy=StrategyType.DEFAULT,
|
||||||
|
reason=f"Reason {i}",
|
||||||
|
)
|
||||||
|
|
||||||
|
records = manager.get_recent_rollbacks(limit=20)
|
||||||
|
assert len(records) == 10
|
||||||
|
|
||||||
|
def test_get_rollback_count(self, manager):
|
||||||
|
"""Should count rollbacks correctly."""
|
||||||
|
import time
|
||||||
|
|
||||||
|
now = time.time()
|
||||||
|
|
||||||
|
manager.record_rollback(
|
||||||
|
from_strategy=StrategyType.ENHANCED,
|
||||||
|
to_strategy=StrategyType.DEFAULT,
|
||||||
|
reason="Reason 1",
|
||||||
|
)
|
||||||
|
|
||||||
|
assert manager.get_rollback_count() == 1
|
||||||
|
assert manager.get_rollback_count(since_timestamp=now) == 1
|
||||||
|
assert manager.get_rollback_count(since_timestamp=now + 10) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestDefaultPipeline:
|
||||||
|
"""[AC-AISVC-RES-01] Tests for DefaultPipeline."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_uses_optimized_retriever(self):
|
||||||
|
"""[AC-AISVC-RES-01] Default pipeline should use OptimizedRetriever."""
|
||||||
|
pipeline = DefaultPipeline()
|
||||||
|
|
||||||
|
ctx = StrategyContext(
|
||||||
|
tenant_id="tenant_a",
|
||||||
|
query="Test query",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.hits = []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.services.retrieval.optimized_retriever.get_optimized_retriever",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
) as mock_get_retriever:
|
||||||
|
mock_retriever = AsyncMock()
|
||||||
|
mock_retriever.retrieve.return_value = mock_result
|
||||||
|
mock_get_retriever.return_value = mock_retriever
|
||||||
|
|
||||||
|
result = await pipeline.execute(ctx)
|
||||||
|
|
||||||
|
assert result == mock_result
|
||||||
|
mock_retriever.retrieve.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnhancedPipeline:
|
||||||
|
"""[AC-AISVC-RES-02] Tests for EnhancedPipeline."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_creates_optimized_retriever(self):
|
||||||
|
"""[AC-AISVC-RES-02] Enhanced pipeline should create OptimizedRetriever with enhanced config."""
|
||||||
|
config = RoutingConfig(strategy=StrategyType.ENHANCED)
|
||||||
|
pipeline = EnhancedPipeline(config=config)
|
||||||
|
|
||||||
|
ctx = StrategyContext(
|
||||||
|
tenant_id="tenant_a",
|
||||||
|
query="Test query",
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.hits = []
|
||||||
|
|
||||||
|
mock_retriever = AsyncMock()
|
||||||
|
mock_retriever.retrieve.return_value = mock_result
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.services.retrieval.optimized_retriever.OptimizedRetriever",
|
||||||
|
return_value=mock_retriever,
|
||||||
|
) as mock_retriever_class:
|
||||||
|
result = await pipeline.execute(ctx)
|
||||||
|
|
||||||
|
mock_retriever_class.assert_called_once_with(
|
||||||
|
two_stage_enabled=True,
|
||||||
|
hybrid_enabled=True,
|
||||||
|
)
|
||||||
|
assert result == mock_result
|
||||||
|
|
||||||
|
|
||||||
|
class TestStrategyRouter:
|
||||||
|
"""[AC-AISVC-RES-01,02,03,07,08] Tests for StrategyRouter."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def router(self):
|
||||||
|
reset_strategy_router()
|
||||||
|
return StrategyRouter()
|
||||||
|
|
||||||
|
def test_initial_state(self, router):
|
||||||
|
"""[AC-AISVC-RES-01] Initial state should be default strategy."""
|
||||||
|
assert router.current_strategy == StrategyType.DEFAULT
|
||||||
|
assert router.config.enabled is True
|
||||||
|
|
||||||
|
def test_route_default_strategy(self, router):
|
||||||
|
"""[AC-AISVC-RES-01] Should route to default strategy by default."""
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
result = router.route(ctx)
|
||||||
|
|
||||||
|
assert result.strategy == StrategyType.DEFAULT
|
||||||
|
assert router.current_strategy == StrategyType.DEFAULT
|
||||||
|
|
||||||
|
def test_route_enhanced_strategy_explicit(self, router):
|
||||||
|
"""[AC-AISVC-RES-02] Should route to enhanced when explicitly configured."""
|
||||||
|
router._config.strategy = StrategyType.ENHANCED
|
||||||
|
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
result = router.route(ctx)
|
||||||
|
|
||||||
|
assert result.strategy == StrategyType.ENHANCED
|
||||||
|
assert router.current_strategy == StrategyType.ENHANCED
|
||||||
|
|
||||||
|
def test_route_grayscale_allowlist(self, router):
|
||||||
|
"""[AC-AISVC-RES-03] Should route to enhanced for tenants in allowlist."""
|
||||||
|
router._config.strategy = StrategyType.ENHANCED
|
||||||
|
router._config.grayscale_allowlist = ["tenant_a", "tenant_b"]
|
||||||
|
|
||||||
|
ctx_a = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
ctx_c = StrategyContext(tenant_id="tenant_c", query="Test query")
|
||||||
|
|
||||||
|
result_a = router.route(ctx_a)
|
||||||
|
result_c = router.route(ctx_c)
|
||||||
|
|
||||||
|
assert result_a.strategy == StrategyType.ENHANCED
|
||||||
|
assert result_c.strategy == StrategyType.ENHANCED
|
||||||
|
|
||||||
|
def test_route_disabled_strategy(self, router):
|
||||||
|
"""Should use default when strategy is disabled."""
|
||||||
|
router._strategy_enabled = False
|
||||||
|
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
result = router.route(ctx)
|
||||||
|
|
||||||
|
assert result.strategy == StrategyType.DEFAULT
|
||||||
|
|
||||||
|
def test_update_config(self, router):
|
||||||
|
"""[AC-AISVC-RES-15] Should update configuration."""
|
||||||
|
new_config = RoutingConfig(
|
||||||
|
strategy=StrategyType.ENHANCED,
|
||||||
|
rag_runtime_mode=RagRuntimeMode.REACT,
|
||||||
|
)
|
||||||
|
|
||||||
|
router.update_config(new_config)
|
||||||
|
|
||||||
|
assert router.config.strategy == StrategyType.ENHANCED
|
||||||
|
assert router.config.rag_runtime_mode == RagRuntimeMode.REACT
|
||||||
|
|
||||||
|
def test_rollback(self, router):
|
||||||
|
"""[AC-AISVC-RES-07] Should rollback to default strategy."""
|
||||||
|
router._current_strategy = StrategyType.ENHANCED
|
||||||
|
router._config.strategy = StrategyType.ENHANCED
|
||||||
|
|
||||||
|
router.rollback(reason="Test rollback", tenant_id="tenant_a")
|
||||||
|
|
||||||
|
assert router.current_strategy == StrategyType.DEFAULT
|
||||||
|
assert router.config.strategy == StrategyType.DEFAULT
|
||||||
|
|
||||||
|
records = router.get_rollback_records()
|
||||||
|
assert len(records) == 1
|
||||||
|
assert records[0].reason == "Test rollback"
|
||||||
|
|
||||||
|
def test_rollback_from_default_no_op(self, router):
|
||||||
|
"""[AC-AISVC-RES-07] Rollback from default should be no-op."""
|
||||||
|
router.rollback(reason="Test rollback")
|
||||||
|
|
||||||
|
assert router.current_strategy == StrategyType.DEFAULT
|
||||||
|
records = router.get_rollback_records()
|
||||||
|
assert len(records) == 0
|
||||||
|
|
||||||
|
def test_validate_config(self, router):
|
||||||
|
"""[AC-AISVC-RES-06] Should validate configuration."""
|
||||||
|
is_valid, errors = router.validate_config()
|
||||||
|
|
||||||
|
assert is_valid is True
|
||||||
|
assert len(errors) == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_default_strategy(self, router):
|
||||||
|
"""[AC-AISVC-RES-01] Should execute default pipeline."""
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.hits = []
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
router._default_pipeline, "execute", new_callable=AsyncMock
|
||||||
|
) as mock_execute:
|
||||||
|
mock_execute.return_value = mock_result
|
||||||
|
|
||||||
|
result, strategy_result = await router.execute(ctx)
|
||||||
|
|
||||||
|
assert result == mock_result
|
||||||
|
assert strategy_result.strategy == StrategyType.DEFAULT
|
||||||
|
mock_execute.assert_called_once_with(ctx)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_enhanced_strategy(self, router):
|
||||||
|
"""[AC-AISVC-RES-02] Should execute enhanced pipeline."""
|
||||||
|
router._config.strategy = StrategyType.ENHANCED
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.hits = []
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
router._enhanced_pipeline, "execute", new_callable=AsyncMock
|
||||||
|
) as mock_execute:
|
||||||
|
mock_execute.return_value = mock_result
|
||||||
|
|
||||||
|
result, strategy_result = await router.execute(ctx)
|
||||||
|
|
||||||
|
assert result == mock_result
|
||||||
|
assert strategy_result.strategy == StrategyType.ENHANCED
|
||||||
|
mock_execute.assert_called_once_with(ctx)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_execute_fallback_on_error(self, router):
|
||||||
|
"""[AC-AISVC-RES-07] Should fallback to default on enhanced failure."""
|
||||||
|
router._config.strategy = StrategyType.ENHANCED
|
||||||
|
ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
|
||||||
|
|
||||||
|
mock_result = MagicMock()
|
||||||
|
mock_result.hits = []
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
router._enhanced_pipeline, "execute", new_callable=AsyncMock
|
||||||
|
) as mock_enhanced:
|
||||||
|
mock_enhanced.side_effect = Exception("Enhanced pipeline failed")
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
router._default_pipeline, "execute", new_callable=AsyncMock
|
||||||
|
) as mock_default:
|
||||||
|
mock_default.return_value = mock_result
|
||||||
|
|
||||||
|
result, strategy_result = await router.execute(ctx)
|
||||||
|
|
||||||
|
assert result == mock_result
|
||||||
|
assert strategy_result.strategy == StrategyType.DEFAULT
|
||||||
|
assert strategy_result.should_fallback is True
|
||||||
|
assert "Enhanced pipeline failed" in strategy_result.fallback_reason
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingletonInstances:
|
||||||
|
"""Tests for singleton instance getters."""
|
||||||
|
|
||||||
|
def test_get_strategy_router_singleton(self):
|
||||||
|
"""Should return same router instance."""
|
||||||
|
reset_strategy_router()
|
||||||
|
|
||||||
|
router1 = get_strategy_router()
|
||||||
|
router2 = get_strategy_router()
|
||||||
|
|
||||||
|
assert router1 is router2
|
||||||
|
|
||||||
|
def test_reset_strategy_router(self):
|
||||||
|
"""Should create new instance after reset."""
|
||||||
|
router1 = get_strategy_router()
|
||||||
|
reset_strategy_router()
|
||||||
|
router2 = get_strategy_router()
|
||||||
|
|
||||||
|
assert router1 is not router2
|
||||||
Loading…
Reference in New Issue