345 lines
12 KiB
Python
345 lines
12 KiB
Python
"""
|
|
Unit tests for Strategy Router.

[AC-AISVC-RES-01,02,03,07,08] Tests for strategy routing, rollback, and grayscale release.
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from app.services.retrieval.routing_config import (
|
|
RagRuntimeMode,
|
|
StrategyType,
|
|
RoutingConfig,
|
|
StrategyContext,
|
|
)
|
|
from app.services.retrieval.strategy_router import (
|
|
RollbackRecord,
|
|
RollbackManager,
|
|
DefaultPipeline,
|
|
EnhancedPipeline,
|
|
StrategyRouter,
|
|
get_strategy_router,
|
|
reset_strategy_router,
|
|
)
|
|
|
|
|
|
class TestRollbackRecord:
    """[AC-AISVC-RES-07] Tests for RollbackRecord."""

    def test_rollback_record_creation(self):
        """A RollbackRecord should expose every value it was constructed with."""
        # Keep the constructor arguments in one table so construction and the
        # field-by-field checks below cannot drift apart.
        expected_fields = {
            "timestamp": 1234567890.0,
            "from_strategy": StrategyType.ENHANCED,
            "to_strategy": StrategyType.DEFAULT,
            "reason": "Performance issue",
            "tenant_id": "tenant_a",
            "request_id": "req_123",
        }

        record = RollbackRecord(**expected_fields)

        for field_name, expected_value in expected_fields.items():
            assert getattr(record, field_name) == expected_value, field_name
|
|
|
|
|
|
class TestRollbackManager:
    """[AC-AISVC-RES-07] Tests for RollbackManager."""

    @pytest.fixture
    def manager(self):
        # A small capacity keeps the eviction test short.
        return RollbackManager(max_records=10)

    def test_record_rollback(self, manager):
        """[AC-AISVC-RES-07] A recorded rollback should show up in the recent history."""
        manager.record_rollback(
            from_strategy=StrategyType.ENHANCED,
            to_strategy=StrategyType.DEFAULT,
            reason="Test rollback",
            tenant_id="tenant_a",
        )

        history = manager.get_recent_rollbacks()

        assert len(history) == 1
        entry = history[0]
        assert entry.from_strategy == StrategyType.ENHANCED
        assert entry.to_strategy == StrategyType.DEFAULT
        assert entry.reason == "Test rollback"

    def test_max_records_limit(self, manager):
        """Recording more rollbacks than max_records should cap the stored history."""
        for n in range(15):
            manager.record_rollback(
                from_strategy=StrategyType.ENHANCED,
                to_strategy=StrategyType.DEFAULT,
                reason=f"Reason {n}",
            )

        # Request more than the cap, so any truncation comes from max_records
        # rather than from the `limit` argument.
        assert len(manager.get_recent_rollbacks(limit=20)) == 10

    def test_get_rollback_count(self, manager):
        """Should count all rollbacks, or only those newer than a given cutoff."""
        import time

        cutoff = time.time()

        manager.record_rollback(
            from_strategy=StrategyType.ENHANCED,
            to_strategy=StrategyType.DEFAULT,
            reason="Reason 1",
        )

        assert manager.get_rollback_count() == 1
        assert manager.get_rollback_count(since_timestamp=cutoff) == 1
        # A cutoff ten seconds in the future must exclude the record just made.
        assert manager.get_rollback_count(since_timestamp=cutoff + 10) == 0
|
|
|
|
|
|
class TestDefaultPipeline:
    """[AC-AISVC-RES-01] Tests for DefaultPipeline."""

    @pytest.mark.asyncio
    async def test_execute_uses_optimized_retriever(self):
        """[AC-AISVC-RES-01] Default pipeline should use OptimizedRetriever."""
        pipeline = DefaultPipeline()
        context = StrategyContext(tenant_id="tenant_a", query="Test query")

        expected = MagicMock()
        expected.hits = []

        stub_retriever = AsyncMock()
        stub_retriever.retrieve.return_value = expected

        # The factory is patched with an AsyncMock, so awaiting it yields the
        # stub retriever (the extra kwarg is forwarded to AsyncMock's constructor).
        with patch(
            "app.services.retrieval.optimized_retriever.get_optimized_retriever",
            new_callable=AsyncMock,
            return_value=stub_retriever,
        ):
            outcome = await pipeline.execute(context)

        assert outcome == expected
        stub_retriever.retrieve.assert_called_once()
|
|
|
|
|
|
class TestEnhancedPipeline:
    """[AC-AISVC-RES-02] Tests for EnhancedPipeline."""

    @pytest.mark.asyncio
    async def test_execute_creates_optimized_retriever(self):
        """[AC-AISVC-RES-02] Enhanced pipeline should build an OptimizedRetriever
        with two-stage and hybrid retrieval both enabled."""
        enhanced_config = RoutingConfig(strategy=StrategyType.ENHANCED)
        pipeline = EnhancedPipeline(config=enhanced_config)
        context = StrategyContext(tenant_id="tenant_a", query="Test query")

        expected = MagicMock()
        expected.hits = []

        stub_retriever = AsyncMock()
        stub_retriever.retrieve.return_value = expected

        # Patch the class itself, so the pipeline's constructor call can be inspected.
        with patch(
            "app.services.retrieval.optimized_retriever.OptimizedRetriever",
            return_value=stub_retriever,
        ) as retriever_cls:
            outcome = await pipeline.execute(context)

        retriever_cls.assert_called_once_with(
            two_stage_enabled=True,
            hybrid_enabled=True,
        )
        assert outcome == expected
|
|
|
|
|
|
class TestStrategyRouter:
    """[AC-AISVC-RES-01,02,03,07,08] Tests for StrategyRouter."""

    @pytest.fixture
    def router(self):
        # Discard the shared router returned by get_strategy_router(), then give
        # each test a freshly constructed router of its own.
        reset_strategy_router()
        return StrategyRouter()

    @staticmethod
    def _context(tenant_id="tenant_a"):
        """Build the request context used throughout these tests."""
        return StrategyContext(tenant_id=tenant_id, query="Test query")

    @staticmethod
    def _stub_result():
        """Build a stand-in retrieval result with no hits."""
        stub = MagicMock()
        stub.hits = []
        return stub

    def test_initial_state(self, router):
        """[AC-AISVC-RES-01] A new router starts enabled and on the default strategy."""
        assert router.current_strategy == StrategyType.DEFAULT
        assert router.config.enabled is True

    def test_route_default_strategy(self, router):
        """[AC-AISVC-RES-01] Without extra configuration, requests route to default."""
        decision = router.route(self._context())

        assert decision.strategy == StrategyType.DEFAULT
        assert router.current_strategy == StrategyType.DEFAULT

    def test_route_enhanced_strategy_explicit(self, router):
        """[AC-AISVC-RES-02] An explicitly configured enhanced strategy is honored."""
        router._config.strategy = StrategyType.ENHANCED

        decision = router.route(self._context())

        assert decision.strategy == StrategyType.ENHANCED
        assert router.current_strategy == StrategyType.ENHANCED

    def test_route_grayscale_allowlist(self, router):
        """[AC-AISVC-RES-03] Should route to enhanced for tenants in allowlist."""
        router._config.strategy = StrategyType.ENHANCED
        router._config.grayscale_allowlist = ["tenant_a", "tenant_b"]

        listed_context = self._context("tenant_a")
        unlisted_context = self._context("tenant_c")

        listed = router.route(listed_context)
        unlisted = router.route(unlisted_context)

        assert listed.strategy == StrategyType.ENHANCED
        # NOTE(review): tenant_c is NOT in the allowlist, yet it is also expected
        # to get ENHANCED, so this test cannot detect a broken allowlist.
        # Confirm the intended grayscale semantics; if unlisted tenants should be
        # excluded, this assertion should expect StrategyType.DEFAULT.
        assert unlisted.strategy == StrategyType.ENHANCED

    def test_route_disabled_strategy(self, router):
        """Should use default when strategy is disabled."""
        # NOTE(review): the router already routes to DEFAULT out of the box (see
        # test_initial_state), and `_strategy_enabled` is not read anywhere else in
        # this file; the enabled flag asserted elsewhere is `router.config.enabled`.
        # As written, this test passes even if disabling has no effect. Consider
        # switching to ENHANCED first and disabling via the real flag — TODO confirm
        # which flag route() consults.
        router._strategy_enabled = False

        decision = router.route(self._context())

        assert decision.strategy == StrategyType.DEFAULT

    def test_update_config(self, router):
        """[AC-AISVC-RES-15] update_config should replace the active configuration."""
        replacement = RoutingConfig(
            strategy=StrategyType.ENHANCED,
            rag_runtime_mode=RagRuntimeMode.REACT,
        )

        router.update_config(replacement)

        assert router.config.strategy == StrategyType.ENHANCED
        assert router.config.rag_runtime_mode == RagRuntimeMode.REACT

    def test_rollback(self, router):
        """[AC-AISVC-RES-07] Rolling back from enhanced restores default and logs it."""
        router._current_strategy = StrategyType.ENHANCED
        router._config.strategy = StrategyType.ENHANCED

        router.rollback(reason="Test rollback", tenant_id="tenant_a")

        assert router.current_strategy == StrategyType.DEFAULT
        assert router.config.strategy == StrategyType.DEFAULT

        history = router.get_rollback_records()
        assert len(history) == 1
        assert history[0].reason == "Test rollback"

    def test_rollback_from_default_no_op(self, router):
        """[AC-AISVC-RES-07] Rolling back while already on default records nothing."""
        router.rollback(reason="Test rollback")

        assert router.current_strategy == StrategyType.DEFAULT
        assert len(router.get_rollback_records()) == 0

    def test_validate_config(self, router):
        """[AC-AISVC-RES-06] The out-of-the-box configuration validates cleanly."""
        ok, problems = router.validate_config()

        assert ok is True
        assert len(problems) == 0

    @pytest.mark.asyncio
    async def test_execute_default_strategy(self, router):
        """[AC-AISVC-RES-01] Should execute default pipeline."""
        context = self._context()
        expected = self._stub_result()

        with patch.object(
            router._default_pipeline,
            "execute",
            new_callable=AsyncMock,
            return_value=expected,
        ) as default_execute:
            outcome, decision = await router.execute(context)

        assert outcome == expected
        assert decision.strategy == StrategyType.DEFAULT
        default_execute.assert_called_once_with(context)

    @pytest.mark.asyncio
    async def test_execute_enhanced_strategy(self, router):
        """[AC-AISVC-RES-02] Should execute enhanced pipeline."""
        router._config.strategy = StrategyType.ENHANCED
        context = self._context()
        expected = self._stub_result()

        with patch.object(
            router._enhanced_pipeline,
            "execute",
            new_callable=AsyncMock,
            return_value=expected,
        ) as enhanced_execute:
            outcome, decision = await router.execute(context)

        assert outcome == expected
        assert decision.strategy == StrategyType.ENHANCED
        enhanced_execute.assert_called_once_with(context)

    @pytest.mark.asyncio
    async def test_execute_fallback_on_error(self, router):
        """[AC-AISVC-RES-07] A failing enhanced pipeline falls back to default."""
        router._config.strategy = StrategyType.ENHANCED
        context = self._context()
        expected = self._stub_result()

        # The enhanced pipeline raises; the default pipeline supplies the result.
        with patch.object(
            router._enhanced_pipeline,
            "execute",
            new_callable=AsyncMock,
            side_effect=Exception("Enhanced pipeline failed"),
        ), patch.object(
            router._default_pipeline,
            "execute",
            new_callable=AsyncMock,
            return_value=expected,
        ):
            outcome, decision = await router.execute(context)

        assert outcome == expected
        assert decision.strategy == StrategyType.DEFAULT
        assert decision.should_fallback is True
        assert "Enhanced pipeline failed" in decision.fallback_reason
|
|
|
|
|
|
class TestSingletonInstances:
    """Tests for the shared-router accessors get_strategy_router / reset_strategy_router."""

    @pytest.fixture(autouse=True)
    def _isolated_singleton(self):
        """Reset the shared router before and after every test in this class.

        The reset before the test makes each test independent of test ordering.
        The reset after the test prevents the router these tests create from
        being handed out by get_strategy_router() in later, unrelated tests.
        """
        reset_strategy_router()
        yield
        reset_strategy_router()

    def test_get_strategy_router_singleton(self):
        """Repeated calls should return the same router instance."""
        router1 = get_strategy_router()
        router2 = get_strategy_router()

        assert router1 is router2

    def test_reset_strategy_router(self):
        """After reset_strategy_router(), the next call should create a new instance."""
        router1 = get_strategy_router()
        reset_strategy_router()
        router2 = get_strategy_router()

        assert router1 is not router2
|