# Source: ai-robot-core/ai-service/tests/test_strategy_router.py (345 lines, 12 KiB, Python)

"""
Unit tests for Strategy Router.
[AC-AISVC-RES-01,02,03,07,08] Tests for strategy routing, rollback, and grayscale release.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.services.retrieval.routing_config import (
RagRuntimeMode,
StrategyType,
RoutingConfig,
StrategyContext,
)
from app.services.retrieval.strategy_router import (
RollbackRecord,
RollbackManager,
DefaultPipeline,
EnhancedPipeline,
StrategyRouter,
get_strategy_router,
reset_strategy_router,
)
class TestRollbackRecord:
    """[AC-AISVC-RES-07] Unit tests covering construction of RollbackRecord."""

    def test_rollback_record_creation(self):
        """Every constructor argument should be stored on the record unchanged."""
        expected_fields = {
            "timestamp": 1234567890.0,
            "from_strategy": StrategyType.ENHANCED,
            "to_strategy": StrategyType.DEFAULT,
            "reason": "Performance issue",
            "tenant_id": "tenant_a",
            "request_id": "req_123",
        }
        rollback_record = RollbackRecord(**expected_fields)

        # Check each field individually so a failure names the offending attribute.
        for attr_name, expected_value in expected_fields.items():
            assert getattr(rollback_record, attr_name) == expected_value
class TestRollbackManager:
    """[AC-AISVC-RES-07] Tests for RollbackManager."""

    @pytest.fixture
    def manager(self):
        """Provide a RollbackManager constructed with ``max_records=10``."""
        return RollbackManager(max_records=10)

    def test_record_rollback(self, manager):
        """[AC-AISVC-RES-07] Should record rollback event."""
        manager.record_rollback(
            from_strategy=StrategyType.ENHANCED,
            to_strategy=StrategyType.DEFAULT,
            reason="Test rollback",
            tenant_id="tenant_a",
        )

        records = manager.get_recent_rollbacks()
        assert len(records) == 1
        assert records[0].from_strategy == StrategyType.ENHANCED
        assert records[0].to_strategy == StrategyType.DEFAULT
        assert records[0].reason == "Test rollback"

    def test_max_records_limit(self, manager):
        """Should limit number of records."""
        # Record more events (15) than the fixture's cap (10), then ask for
        # more than the cap (20) so only the retention limit can bound the result.
        for i in range(15):
            manager.record_rollback(
                from_strategy=StrategyType.ENHANCED,
                to_strategy=StrategyType.DEFAULT,
                reason=f"Reason {i}",
            )

        records = manager.get_recent_rollbacks(limit=20)
        assert len(records) == 10

    def test_get_rollback_count(self, manager):
        """Should count rollbacks correctly, with and without a time filter."""
        import time

        now = time.time()
        # Use a reference point clearly *before* the record. A timestamp read
        # just before ``record_rollback`` can equal the record's own timestamp
        # on a coarse clock. The outcome would then hinge on whether the manager
        # filters with ``>`` or ``>=``, which this test does not intend to pin.
        before = now - 1.0
        after = now + 10

        manager.record_rollback(
            from_strategy=StrategyType.ENHANCED,
            to_strategy=StrategyType.DEFAULT,
            reason="Reason 1",
        )

        assert manager.get_rollback_count() == 1
        assert manager.get_rollback_count(since_timestamp=before) == 1
        assert manager.get_rollback_count(since_timestamp=after) == 0
class TestDefaultPipeline:
    """[AC-AISVC-RES-01] Unit tests for DefaultPipeline."""

    @pytest.mark.asyncio
    async def test_execute_uses_optimized_retriever(self):
        """[AC-AISVC-RES-01] Default pipeline should use OptimizedRetriever."""
        default_pipeline = DefaultPipeline()
        context = StrategyContext(tenant_id="tenant_a", query="Test query")

        expected = MagicMock()
        expected.hits = []

        fake_retriever = AsyncMock()
        fake_retriever.retrieve.return_value = expected

        # Patch the retriever factory at its defining module. ``patch`` forwards
        # ``return_value`` to the AsyncMock it builds, so calling and awaiting
        # the patched factory yields ``fake_retriever``.
        with patch(
            "app.services.retrieval.optimized_retriever.get_optimized_retriever",
            new_callable=AsyncMock,
            return_value=fake_retriever,
        ):
            outcome = await default_pipeline.execute(context)

        assert outcome == expected
        fake_retriever.retrieve.assert_called_once()
class TestEnhancedPipeline:
    """[AC-AISVC-RES-02] Unit tests for EnhancedPipeline."""

    @pytest.mark.asyncio
    async def test_execute_creates_optimized_retriever(self):
        """[AC-AISVC-RES-02] Enhanced pipeline should create OptimizedRetriever with enhanced config."""
        enhanced_pipeline = EnhancedPipeline(
            config=RoutingConfig(strategy=StrategyType.ENHANCED)
        )
        context = StrategyContext(tenant_id="tenant_a", query="Test query")

        expected = MagicMock()
        expected.hits = []

        fake_retriever = AsyncMock()
        fake_retriever.retrieve.return_value = expected

        # Replace the OptimizedRetriever class so that instantiating it returns
        # the fake, letting the test inspect the constructor kwargs.
        target = "app.services.retrieval.optimized_retriever.OptimizedRetriever"
        with patch(target, return_value=fake_retriever) as retriever_cls:
            outcome = await enhanced_pipeline.execute(context)

        retriever_cls.assert_called_once_with(
            two_stage_enabled=True,
            hybrid_enabled=True,
        )
        assert outcome == expected
class TestStrategyRouter:
    """[AC-AISVC-RES-01,02,03,06,07,08,15] Tests for StrategyRouter."""

    @pytest.fixture
    def router(self):
        """Provide a fresh StrategyRouter after clearing the module singleton."""
        reset_strategy_router()
        return StrategyRouter()

    def test_initial_state(self, router):
        """[AC-AISVC-RES-01] Initial state should be default strategy."""
        assert router.current_strategy == StrategyType.DEFAULT
        assert router.config.enabled is True

    def test_route_default_strategy(self, router):
        """[AC-AISVC-RES-01] Should route to default strategy by default."""
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")

        result = router.route(ctx)

        assert result.strategy == StrategyType.DEFAULT
        assert router.current_strategy == StrategyType.DEFAULT

    def test_route_enhanced_strategy_explicit(self, router):
        """[AC-AISVC-RES-02] Should route to enhanced when explicitly configured."""
        router._config.strategy = StrategyType.ENHANCED
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")

        result = router.route(ctx)

        assert result.strategy == StrategyType.ENHANCED
        assert router.current_strategy == StrategyType.ENHANCED

    def test_route_grayscale_allowlist(self, router):
        """[AC-AISVC-RES-03] With an explicit ENHANCED strategy, both an
        allowlisted and a non-allowlisted tenant route to enhanced.
        """
        # NOTE(review): tenant_c is not in the allowlist yet is expected to get
        # ENHANCED, so this test does not exercise allowlist gating itself.
        # Confirm against StrategyRouter.route whether a gated case (for example
        # a DEFAULT base strategy with grayscale enabled) should be added.
        router._config.strategy = StrategyType.ENHANCED
        router._config.grayscale_allowlist = ["tenant_a", "tenant_b"]

        ctx_a = StrategyContext(tenant_id="tenant_a", query="Test query")
        ctx_c = StrategyContext(tenant_id="tenant_c", query="Test query")

        result_a = router.route(ctx_a)
        result_c = router.route(ctx_c)

        assert result_a.strategy == StrategyType.ENHANCED
        assert result_c.strategy == StrategyType.ENHANCED

    def test_route_disabled_strategy(self, router):
        """Should use default when strategy is disabled."""
        # NOTE(review): the router already routes to DEFAULT before this flag is
        # flipped, so the assertion would also hold if ``_strategy_enabled``
        # were ignored. Consider setting ``router._config.strategy`` to ENHANCED
        # first; confirm which flag StrategyRouter.route honours.
        router._strategy_enabled = False
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")

        result = router.route(ctx)

        assert result.strategy == StrategyType.DEFAULT

    def test_update_config(self, router):
        """[AC-AISVC-RES-15] Should update configuration."""
        new_config = RoutingConfig(
            strategy=StrategyType.ENHANCED,
            rag_runtime_mode=RagRuntimeMode.REACT,
        )

        router.update_config(new_config)

        assert router.config.strategy == StrategyType.ENHANCED
        assert router.config.rag_runtime_mode == RagRuntimeMode.REACT

    def test_rollback(self, router):
        """[AC-AISVC-RES-07] Should rollback to default strategy."""
        router._current_strategy = StrategyType.ENHANCED
        router._config.strategy = StrategyType.ENHANCED

        router.rollback(reason="Test rollback", tenant_id="tenant_a")

        assert router.current_strategy == StrategyType.DEFAULT
        assert router.config.strategy == StrategyType.DEFAULT
        records = router.get_rollback_records()
        assert len(records) == 1
        assert records[0].reason == "Test rollback"

    def test_rollback_from_default_no_op(self, router):
        """[AC-AISVC-RES-07] Rollback from default should be no-op."""
        router.rollback(reason="Test rollback")

        assert router.current_strategy == StrategyType.DEFAULT
        records = router.get_rollback_records()
        assert len(records) == 0

    def test_validate_config(self, router):
        """[AC-AISVC-RES-06] Should validate configuration."""
        is_valid, errors = router.validate_config()

        assert is_valid is True
        assert len(errors) == 0

    @pytest.mark.asyncio
    async def test_execute_default_strategy(self, router):
        """[AC-AISVC-RES-01] Should execute default pipeline."""
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        mock_result = MagicMock()
        mock_result.hits = []

        with patch.object(
            router._default_pipeline, "execute", new_callable=AsyncMock
        ) as mock_execute:
            mock_execute.return_value = mock_result
            result, strategy_result = await router.execute(ctx)

        assert result == mock_result
        assert strategy_result.strategy == StrategyType.DEFAULT
        mock_execute.assert_called_once_with(ctx)

    @pytest.mark.asyncio
    async def test_execute_enhanced_strategy(self, router):
        """[AC-AISVC-RES-02] Should execute enhanced pipeline."""
        router._config.strategy = StrategyType.ENHANCED
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        mock_result = MagicMock()
        mock_result.hits = []

        with patch.object(
            router._enhanced_pipeline, "execute", new_callable=AsyncMock
        ) as mock_execute:
            mock_execute.return_value = mock_result
            result, strategy_result = await router.execute(ctx)

        assert result == mock_result
        assert strategy_result.strategy == StrategyType.ENHANCED
        mock_execute.assert_called_once_with(ctx)

    @pytest.mark.asyncio
    async def test_execute_fallback_on_error(self, router):
        """[AC-AISVC-RES-07] Should fallback to default on enhanced failure."""
        router._config.strategy = StrategyType.ENHANCED
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        mock_result = MagicMock()
        mock_result.hits = []

        # Patch both pipelines in one statement: the enhanced one raises, the
        # default one returns the canned result the fallback should surface.
        with patch.object(
            router._enhanced_pipeline, "execute", new_callable=AsyncMock
        ) as mock_enhanced, patch.object(
            router._default_pipeline, "execute", new_callable=AsyncMock
        ) as mock_default:
            mock_enhanced.side_effect = Exception("Enhanced pipeline failed")
            mock_default.return_value = mock_result
            result, strategy_result = await router.execute(ctx)

        assert result == mock_result
        assert strategy_result.strategy == StrategyType.DEFAULT
        assert strategy_result.should_fallback is True
        assert "Enhanced pipeline failed" in strategy_result.fallback_reason
class TestSingletonInstances:
    """Tests for the module-level strategy router singleton helpers."""

    @pytest.fixture(autouse=True)
    def _isolate_singleton(self):
        """Clear the global router before and after each test.

        Without this, each test would depend on whatever singleton an earlier
        test left behind, and would leak its own instance to later tests.
        """
        reset_strategy_router()
        yield
        reset_strategy_router()

    def test_get_strategy_router_singleton(self):
        """Should return same router instance."""
        reset_strategy_router()
        router1 = get_strategy_router()
        router2 = get_strategy_router()
        assert router1 is router2

    def test_reset_strategy_router(self):
        """Should create new instance after reset."""
        router1 = get_strategy_router()
        reset_strategy_router()
        router2 = get_strategy_router()
        assert router1 is not router2