257 lines
9.2 KiB
Python
257 lines
9.2 KiB
Python
"""
Unit tests for Retrieval Strategy Integration.

[AC-AISVC-RES-01~15] Tests for integrated strategy and mode routing.
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from app.services.retrieval.routing_config import (
|
|
RagRuntimeMode,
|
|
StrategyType,
|
|
RoutingConfig,
|
|
StrategyContext,
|
|
)
|
|
from app.services.retrieval.strategy_integration import (
|
|
RetrievalStrategyResult,
|
|
RetrievalStrategyIntegration,
|
|
get_retrieval_strategy_integration,
|
|
reset_retrieval_strategy_integration,
|
|
)
|
|
|
|
|
|
class TestRetrievalStrategyResult:
    """Tests for the RetrievalStrategyResult value object."""

    @staticmethod
    def _explicit_fields():
        """Return a freshly built mapping of the fields set explicitly in the creation test.

        A new dict (including a new ``diagnostics`` dict) is produced on every
        call, so the values passed to the constructor and the values compared
        against afterwards are distinct objects.
        """
        return {
            "final_answer": "Test answer",
            "strategy": StrategyType.ENHANCED,
            "mode": RagRuntimeMode.REACT,
            "fallback_reason": "Low confidence",
            "diagnostics": {"key": "value"},
            "duration_ms": 100,
        }

    def test_result_creation(self):
        """Should create result with all fields."""
        result = RetrievalStrategyResult(
            retrieval_result=None,
            should_fallback=True,
            **self._explicit_fields(),
        )

        # Singletons are checked by identity, everything else by equality.
        assert result.retrieval_result is None
        assert result.should_fallback is True
        for attr, expected in self._explicit_fields().items():
            assert getattr(result, attr) == expected, attr

    def test_result_defaults(self):
        """Should create result with default values."""
        result = RetrievalStrategyResult(
            retrieval_result=None,
            final_answer=None,
            strategy=StrategyType.DEFAULT,
            mode=RagRuntimeMode.DIRECT,
        )

        # Fields not passed to the constructor must take their declared defaults.
        assert result.should_fallback is False
        for attr in ("fallback_reason", "mode_route_result"):
            assert getattr(result, attr) is None, attr
        assert result.diagnostics == {}
        assert result.duration_ms == 0
|
|
|
|
|
|
class TestRetrievalStrategyIntegration:
    """[AC-AISVC-RES-01~15] Tests for RetrievalStrategyIntegration."""

    @pytest.fixture
    def integration(self):
        """Yield a fresh RetrievalStrategyIntegration for each test.

        The module-level singleton is reset before the instance is built and
        again on teardown, so singleton state created during one test does not
        leak into later tests.
        """
        reset_retrieval_strategy_integration()
        yield RetrievalStrategyIntegration()
        reset_retrieval_strategy_integration()

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _make_ctx() -> StrategyContext:
        """Build the strategy context shared by the execute() tests."""
        return StrategyContext(tenant_id="tenant_a", query="Test query")

    @staticmethod
    def _make_retrieval_result() -> MagicMock:
        """Build a stand-in retrieval result whose ``hits`` list is empty."""
        retrieval_result = MagicMock()
        retrieval_result.hits = []
        return retrieval_result

    @staticmethod
    def _make_route_result(mode, **attrs) -> MagicMock:
        """Build a stand-in route result reporting ``mode`` plus any extra ``attrs``."""
        route_result = MagicMock()
        route_result.mode = mode
        for name, value in attrs.items():
            setattr(route_result, name, value)
        return route_result

    @staticmethod
    async def _execute_with_patched_router(
        integration, ctx, route_result, executor_name, executor_return
    ):
        """Run ``integration.execute(ctx)`` with the mode router stubbed out.

        ``_mode_router.route`` is patched to return ``route_result``. The
        router method named ``executor_name`` is replaced by an AsyncMock
        returning ``executor_return``.

        Returns:
            Tuple ``(result, mock_execute)``: the value returned by
            ``integration.execute`` and the AsyncMock that replaced the
            executor, so callers can assert on how it was awaited.
        """
        with patch.object(
            integration._mode_router, "route", return_value=route_result
        ), patch.object(
            integration._mode_router, executor_name, new_callable=AsyncMock
        ) as mock_execute:
            mock_execute.return_value = executor_return
            result = await integration.execute(ctx)
        return result, mock_execute

    # ------------------------------------------------------------------
    # Configuration
    # ------------------------------------------------------------------

    def test_initial_state(self, integration):
        """Should initialize with default configuration."""
        assert integration.config.strategy == StrategyType.DEFAULT
        assert integration.config.rag_runtime_mode == RagRuntimeMode.AUTO

    def test_update_config(self, integration):
        """[AC-AISVC-RES-15] Should update all configurations."""
        new_config = RoutingConfig(
            strategy=StrategyType.ENHANCED,
            rag_runtime_mode=RagRuntimeMode.REACT,
            react_max_steps=7,
        )

        integration.update_config(new_config)

        assert integration.config.strategy == StrategyType.ENHANCED
        assert integration.config.rag_runtime_mode == RagRuntimeMode.REACT
        # All three supplied fields are checked, so "update all" is actually verified.
        assert integration.config.react_max_steps == 7

    def test_get_current_strategy(self, integration):
        """Should return current strategy from router."""
        strategy = integration.get_current_strategy()

        assert strategy == StrategyType.DEFAULT

    def test_get_rollback_records(self, integration):
        """Should return rollback records from router."""
        records = integration.get_rollback_records()

        assert isinstance(records, list)

    def test_validate_config(self, integration):
        """[AC-AISVC-RES-06] Should validate configuration."""
        is_valid, errors = integration.validate_config()

        assert is_valid is True
        assert not errors, errors

    # ------------------------------------------------------------------
    # Execution
    # ------------------------------------------------------------------

    @pytest.mark.asyncio
    async def test_execute_direct_mode(self, integration):
        """[AC-AISVC-RES-09] Should execute direct mode."""
        integration._config.rag_runtime_mode = RagRuntimeMode.DIRECT
        retrieval_result = self._make_retrieval_result()
        route_result = self._make_route_result(RagRuntimeMode.DIRECT)

        result, mock_execute = await self._execute_with_patched_router(
            integration,
            self._make_ctx(),
            route_result,
            "execute_with_fallback",
            (retrieval_result, None, route_result),
        )

        # Direct mode must dispatch through the fallback-aware executor exactly once.
        mock_execute.assert_awaited_once()
        assert result.retrieval_result == retrieval_result
        assert result.final_answer is None
        assert result.mode == RagRuntimeMode.DIRECT

    @pytest.mark.asyncio
    async def test_execute_react_mode(self, integration):
        """[AC-AISVC-RES-10] Should execute react mode."""
        integration._config.rag_runtime_mode = RagRuntimeMode.REACT
        route_result = self._make_route_result(RagRuntimeMode.REACT)

        result, mock_execute = await self._execute_with_patched_router(
            integration,
            self._make_ctx(),
            route_result,
            "execute_react",
            ("Final answer", None, {}),
        )

        # React mode must dispatch to the ReAct executor exactly once.
        mock_execute.assert_awaited_once()
        assert result.retrieval_result is None
        assert result.final_answer == "Final answer"
        assert result.mode == RagRuntimeMode.REACT

    @pytest.mark.asyncio
    async def test_execute_with_fallback(self, integration):
        """[AC-AISVC-RES-14] Should handle fallback from direct to react."""
        integration._config.rag_runtime_mode = RagRuntimeMode.DIRECT
        integration._config.direct_fallback_on_low_confidence = True
        integration._config.direct_fallback_confidence_threshold = 0.4
        route_result = self._make_route_result(
            RagRuntimeMode.DIRECT,
            should_fallback_to_react=True,
            fallback_reason="low_confidence",
        )

        result, _ = await self._execute_with_patched_router(
            integration,
            self._make_ctx(),
            route_result,
            "execute_with_fallback",
            (None, "Fallback answer", route_result),
        )

        assert result.retrieval_result is None
        assert result.final_answer == "Fallback answer"
        assert result.should_fallback is True

    @pytest.mark.asyncio
    async def test_execute_includes_diagnostics(self, integration):
        """Should include diagnostics in result."""
        retrieval_result = self._make_retrieval_result()
        route_result = self._make_route_result(
            RagRuntimeMode.DIRECT,
            diagnostics={"mode_key": "mode_value"},
        )

        result, _ = await self._execute_with_patched_router(
            integration,
            self._make_ctx(),
            route_result,
            "execute_with_fallback",
            (retrieval_result, None, route_result),
        )

        assert "strategy_diagnostics" in result.diagnostics
        assert "mode_diagnostics" in result.diagnostics
        assert "duration_ms" in result.diagnostics

    @pytest.mark.asyncio
    async def test_execute_tracks_duration(self, integration):
        """Should track execution duration."""
        retrieval_result = self._make_retrieval_result()
        route_result = self._make_route_result(RagRuntimeMode.DIRECT)

        result, _ = await self._execute_with_patched_router(
            integration,
            self._make_ctx(),
            route_result,
            "execute_with_fallback",
            (retrieval_result, None, route_result),
        )

        assert result.duration_ms >= 0
|
|
|
|
|
|
class TestSingletonInstances:
    """Tests for the module-level integration getter and its reset hook."""

    def test_get_retrieval_strategy_integration_singleton(self):
        """Should return same integration instance."""
        reset_retrieval_strategy_integration()

        # Two back-to-back lookups must hand out the very same object.
        first, second = (get_retrieval_strategy_integration() for _ in range(2))

        assert first is second

    def test_reset_retrieval_strategy_integration(self):
        """Should create new instance after reset."""
        before_reset = get_retrieval_strategy_integration()
        reset_retrieval_strategy_integration()
        after_reset = get_retrieval_strategy_integration()

        assert after_reset is not before_reset
|