298 lines
9.7 KiB
Python
298 lines
9.7 KiB
Python
"""
Unit tests for Mode Router.

[AC-AISVC-RES-09,10,11,12,13,14,15] Tests for mode routing, complexity analysis,
fallback, and configuration updates.
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
from app.services.retrieval.routing_config import (
|
|
RagRuntimeMode,
|
|
RoutingConfig,
|
|
StrategyContext,
|
|
)
|
|
from app.services.retrieval.mode_router import (
|
|
ComplexityAnalyzer,
|
|
ModeRouteResult,
|
|
ModeRouter,
|
|
get_mode_router,
|
|
reset_mode_router,
|
|
)
|
|
|
|
|
|
class TestComplexityAnalyzer:
    """[AC-AISVC-RES-12,13] Tests for ComplexityAnalyzer."""

    @pytest.fixture
    def analyzer(self):
        """Provide a fresh ComplexityAnalyzer for each test."""
        return ComplexityAnalyzer()

    def test_analyze_empty_query(self, analyzer):
        """Empty query should have zero complexity."""
        score = analyzer.analyze("")

        assert score == 0.0

    def test_analyze_short_query(self, analyzer):
        """Short query should have low complexity."""
        # "Simple question"
        score = analyzer.analyze("简单问题")

        assert score < 0.3

    def test_analyze_long_query(self, analyzer):
        """Long query should have higher complexity."""
        # "This is a very long question", repeated 20 times.
        long_query = "这是一个很长的问题" * 20
        score = analyzer.analyze(long_query)

        assert score >= 0.3

    def test_analyze_multiple_conditions(self, analyzer):
        """[AC-AISVC-RES-13] Query joining several conditions should raise complexity."""
        # "Check order status and shipping info as well as refund progress"
        query = "查询订单状态和物流信息以及退款进度"
        score = analyzer.analyze(query)

        assert score >= 0.2

    def test_analyze_reasoning_patterns(self, analyzer):
        """Query with reasoning patterns should have higher complexity."""
        # "Why would the order be cancelled? What is the reason?"
        query = "为什么订单会被取消?原因是什么?"
        score = analyzer.analyze(query)

        assert score >= 0.15

    def test_analyze_cross_domain_patterns(self, analyzer):
        """Cross-domain query should have higher complexity."""
        # "Summarize the sales data of all departments"
        query = "汇总所有部门的销售数据"
        score = analyzer.analyze(query)

        assert score >= 0.15

    def test_analyze_multiple_questions(self, analyzer):
        """Multiple questions should increase complexity."""
        # "What is the price? How do I buy it?"
        query = "什么是价格?如何购买?"
        score = analyzer.analyze(query)

        assert score >= 0.1
|
|
|
|
|
|
class TestModeRouteResult:
    """Tests for ModeRouteResult."""

    def test_result_creation(self):
        """Should create result with all fields."""
        # Every field, including the optional ones, is set explicitly.
        result = ModeRouteResult(
            mode=RagRuntimeMode.REACT,
            confidence=0.5,
            complexity_score=0.7,
            should_fallback_to_react=True,
            fallback_reason="Low confidence",
            diagnostics={"key": "value"},
        )

        assert result.mode == RagRuntimeMode.REACT
        assert result.confidence == 0.5
        assert result.complexity_score == 0.7
        assert result.should_fallback_to_react is True
        assert result.fallback_reason == "Low confidence"
        assert result.diagnostics == {"key": "value"}

    def test_result_defaults(self):
        """Should create result with default values."""
        # Only the required fields are passed; the rest must take their defaults.
        result = ModeRouteResult(
            mode=RagRuntimeMode.DIRECT,
            confidence=0.8,
            complexity_score=0.2,
        )

        assert result.should_fallback_to_react is False
        assert result.fallback_reason is None
        assert result.diagnostics == {}
|
|
|
|
|
|
class TestModeRouter:
    """[AC-AISVC-RES-09,10,11,12,13,14,15] Tests for ModeRouter.

    Most tests mutate ``router._config`` directly so that each routing
    decision can be exercised in isolation on a freshly built router.
    """

    @pytest.fixture
    def router(self):
        """Reset the module-level router singleton and return a new ModeRouter."""
        reset_mode_router()
        return ModeRouter()

    def test_initial_config(self, router):
        """Should initialize with default configuration."""
        assert router.config.rag_runtime_mode == RagRuntimeMode.AUTO

    def test_route_direct_mode(self, router):
        """[AC-AISVC-RES-09] Should route to direct mode when configured."""
        router._config.rag_runtime_mode = RagRuntimeMode.DIRECT

        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        result = router.route(ctx)

        assert result.mode == RagRuntimeMode.DIRECT

    def test_route_react_mode(self, router):
        """[AC-AISVC-RES-10] Should route to react mode when configured."""
        router._config.rag_runtime_mode = RagRuntimeMode.REACT

        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        result = router.route(ctx)

        assert result.mode == RagRuntimeMode.REACT

    def test_route_auto_mode_direct_conditions(self, router):
        """[AC-AISVC-RES-11,12] Auto mode should select direct for simple queries."""
        router._config.rag_runtime_mode = RagRuntimeMode.AUTO

        # High metadata confidence and low complexity: both favour direct mode.
        ctx = StrategyContext(
            tenant_id="tenant_a",
            query="简单问题",  # "Simple question"
            metadata_confidence=0.9,
            complexity_score=0.1,
        )
        result = router.route(ctx)

        assert result.mode == RagRuntimeMode.DIRECT

    def test_route_auto_mode_react_low_confidence(self, router):
        """[AC-AISVC-RES-11,12] Auto mode should select react for low confidence."""
        router._config.rag_runtime_mode = RagRuntimeMode.AUTO
        router._config.react_trigger_confidence_threshold = 0.6

        # Confidence 0.4 is below the 0.6 trigger threshold configured above.
        ctx = StrategyContext(
            tenant_id="tenant_a",
            query="Test query",
            metadata_confidence=0.4,
            complexity_score=0.2,
        )
        result = router.route(ctx)

        assert result.mode == RagRuntimeMode.REACT

    def test_route_auto_mode_react_high_complexity(self, router):
        """[AC-AISVC-RES-11,13] Auto mode should select react for high complexity."""
        router._config.rag_runtime_mode = RagRuntimeMode.AUTO
        router._config.react_trigger_complexity_score = 0.5

        # Complexity 0.7 exceeds the 0.5 trigger score; confidence alone is fine.
        ctx = StrategyContext(
            tenant_id="tenant_a",
            query="Test query",
            metadata_confidence=0.8,
            complexity_score=0.7,
        )
        result = router.route(ctx)

        assert result.mode == RagRuntimeMode.REACT

    def test_update_config(self, router):
        """[AC-AISVC-RES-15] Should update configuration."""
        new_config = RoutingConfig(
            rag_runtime_mode=RagRuntimeMode.REACT,
            react_max_steps=7,
        )

        router.update_config(new_config)

        assert router.config.rag_runtime_mode == RagRuntimeMode.REACT
        assert router.config.react_max_steps == 7

    @pytest.mark.asyncio
    async def test_execute_direct(self, router):
        """[AC-AISVC-RES-09] Should execute direct retrieval."""
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")

        mock_result = MagicMock()
        mock_result.hits = []

        with patch.object(
            router._direct_executor, "execute", new_callable=AsyncMock
        ) as mock_execute:
            mock_execute.return_value = mock_result

            result = await router.execute_direct(ctx)

        assert result == mock_result

    @pytest.mark.asyncio
    async def test_execute_with_fallback_no_fallback(self, router):
        """[AC-AISVC-RES-14] Should not fallback when confidence is high."""
        router._config.rag_runtime_mode = RagRuntimeMode.DIRECT
        router._config.direct_fallback_on_low_confidence = True
        router._config.direct_fallback_confidence_threshold = 0.4

        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")

        # One hit scoring 0.8, comfortably above the 0.4 fallback threshold.
        mock_hit = MagicMock()
        mock_hit.score = 0.8
        mock_result = MagicMock()
        mock_result.hits = [mock_hit]

        # Patch the ReAct executor too, so we can prove it was never invoked.
        with patch.object(
            router._direct_executor, "execute", new_callable=AsyncMock
        ) as mock_direct, patch.object(
            router._react_executor, "execute", new_callable=AsyncMock
        ) as mock_react:
            mock_direct.return_value = mock_result

            result, answer, route_result = await router.execute_with_fallback(ctx)

        assert result == mock_result
        assert answer is None
        assert route_result.mode == RagRuntimeMode.DIRECT
        assert route_result.should_fallback_to_react is False
        mock_react.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_execute_with_fallback_triggered(self, router):
        """[AC-AISVC-RES-14] Should fallback to react when confidence is low."""
        router._config.rag_runtime_mode = RagRuntimeMode.DIRECT
        router._config.direct_fallback_on_low_confidence = True
        router._config.direct_fallback_confidence_threshold = 0.4

        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")

        # One hit scoring 0.2, below the 0.4 fallback threshold.
        mock_hit = MagicMock()
        mock_hit.score = 0.2
        mock_result = MagicMock()
        mock_result.hits = [mock_hit]

        # Both patches must be active while execute_with_fallback runs.
        # Sequential (non-nested) `with` blocks would restore the real direct
        # executor before the call, so the low-score hit would never be seen.
        with patch.object(
            router._direct_executor, "execute", new_callable=AsyncMock
        ) as mock_direct, patch.object(
            router._react_executor, "execute", new_callable=AsyncMock
        ) as mock_react:
            mock_direct.return_value = mock_result
            mock_react.return_value = ("Final answer", None, {})

            result, answer, route_result = await router.execute_with_fallback(ctx)

        assert result is None
        assert answer == "Final answer"
        assert route_result.mode == RagRuntimeMode.REACT
        assert route_result.should_fallback_to_react is True
        assert route_result.fallback_reason == "low_confidence"
        # Direct retrieval ran first, and its low score triggered ReAct.
        mock_direct.assert_awaited_once()
        mock_react.assert_awaited_once()
|
|
|
|
|
|
class TestSingletonInstances:
    """Tests for singleton instance getters."""

    @pytest.fixture(autouse=True)
    def _isolate_singleton(self):
        """Reset the module-level router before and after every test.

        Without this, results depend on test order, and the router created
        here stays cached for whatever test runs next.
        """
        reset_mode_router()
        yield
        reset_mode_router()

    def test_get_mode_router_singleton(self):
        """Should return same router instance."""
        router1 = get_mode_router()
        router2 = get_mode_router()

        assert router1 is router2

    def test_reset_mode_router(self):
        """Should create new instance after reset."""
        router1 = get_mode_router()
        reset_mode_router()
        router2 = get_mode_router()

        assert router1 is not router2
|