# ai-robot-core/ai-service/tests/test_mode_router.py
# (298 lines, 9.7 KiB, Python)

"""
Unit tests for Mode Router.
[AC-AISVC-RES-09,10,11,12,13,14] Tests for mode routing, complexity analysis, and fallback.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.services.retrieval.routing_config import (
RagRuntimeMode,
RoutingConfig,
StrategyContext,
)
from app.services.retrieval.mode_router import (
ComplexityAnalyzer,
ModeRouteResult,
ModeRouter,
get_mode_router,
reset_mode_router,
)
class TestComplexityAnalyzer:
    """[AC-AISVC-RES-12,13] Score checks for ComplexityAnalyzer.analyze()."""

    @pytest.fixture
    def analyzer(self):
        """Provide a fresh ComplexityAnalyzer to every test."""
        instance = ComplexityAnalyzer()
        return instance

    def test_analyze_empty_query(self, analyzer):
        """An empty query scores exactly 0.0."""
        assert analyzer.analyze("") == 0.0

    def test_analyze_short_query(self, analyzer):
        """A very short query stays below 0.3."""
        short_score = analyzer.analyze("简单问题")
        assert short_score < 0.3

    def test_analyze_long_query(self, analyzer):
        """A query built from one phrase repeated 20 times reaches at least 0.3."""
        repeated_text = "".join(["这是一个很长的问题"] * 20)
        assert analyzer.analyze(repeated_text) >= 0.3

    def test_analyze_multiple_conditions(self, analyzer):
        """[AC-AISVC-RES-13] A query chaining several conditions scores at least 0.2."""
        multi_condition_text = "查询订单状态和物流信息以及退款进度"
        assert analyzer.analyze(multi_condition_text) >= 0.2

    def test_analyze_reasoning_patterns(self, analyzer):
        """A why/cause style query scores at least 0.15."""
        reasoning_text = "为什么订单会被取消?原因是什么?"
        assert analyzer.analyze(reasoning_text) >= 0.15

    def test_analyze_cross_domain_patterns(self, analyzer):
        """An aggregate-across-departments query scores at least 0.15."""
        cross_domain_text = "汇总所有部门的销售数据"
        assert analyzer.analyze(cross_domain_text) >= 0.15

    def test_analyze_multiple_questions(self, analyzer):
        """A query containing two questions scores at least 0.1."""
        two_questions = "什么是价格?如何购买?"
        assert analyzer.analyze(two_questions) >= 0.1
class TestModeRouteResult:
    """Construction checks for the ModeRouteResult value object."""

    def test_result_creation(self):
        """Every explicitly supplied field is stored as given."""
        kwargs = {
            "mode": RagRuntimeMode.REACT,
            "confidence": 0.5,
            "complexity_score": 0.7,
            "should_fallback_to_react": True,
            "fallback_reason": "Low confidence",
            "diagnostics": {"key": "value"},
        }
        built = ModeRouteResult(**kwargs)

        assert built.mode == RagRuntimeMode.REACT
        assert built.confidence == 0.5
        assert built.complexity_score == 0.7
        assert built.should_fallback_to_react is True
        assert built.fallback_reason == "Low confidence"
        assert built.diagnostics == {"key": "value"}

    def test_result_defaults(self):
        """Omitted optional fields fall back to no-fallback / None / empty dict."""
        built = ModeRouteResult(
            mode=RagRuntimeMode.DIRECT, confidence=0.8, complexity_score=0.2
        )

        assert built.should_fallback_to_react is False
        assert built.fallback_reason is None
        assert built.diagnostics == {}
class TestModeRouter:
    """[AC-AISVC-RES-09,10,11,12,13,14] Tests for ModeRouter.

    Each test gets its own ``ModeRouter`` and adjusts its private ``_config``
    in place to pick a runtime mode or threshold. In the ``execute_*`` tests the
    executors' ``execute`` methods are patched with ``AsyncMock``. Those tests
    check which executor was (or was not) awaited, not only the returned values.
    """

    @pytest.fixture
    def router(self):
        """Yield a fresh ModeRouter, resetting the global singleton around it."""
        reset_mode_router()
        yield ModeRouter()
        # Reset afterwards too, so no singleton state outlives this test.
        reset_mode_router()

    def test_initial_config(self, router):
        """Should initialize with default configuration."""
        assert router.config.rag_runtime_mode == RagRuntimeMode.AUTO

    def test_route_direct_mode(self, router):
        """[AC-AISVC-RES-09] Should route to direct mode when configured."""
        router._config.rag_runtime_mode = RagRuntimeMode.DIRECT
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        result = router.route(ctx)
        assert result.mode == RagRuntimeMode.DIRECT

    def test_route_react_mode(self, router):
        """[AC-AISVC-RES-10] Should route to react mode when configured."""
        router._config.rag_runtime_mode = RagRuntimeMode.REACT
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        result = router.route(ctx)
        assert result.mode == RagRuntimeMode.REACT

    def test_route_auto_mode_direct_conditions(self, router):
        """[AC-AISVC-RES-11,12] Auto mode should select direct for simple queries."""
        router._config.rag_runtime_mode = RagRuntimeMode.AUTO
        ctx = StrategyContext(
            tenant_id="tenant_a",
            query="简单问题",
            metadata_confidence=0.9,
            complexity_score=0.1,
        )
        result = router.route(ctx)
        assert result.mode == RagRuntimeMode.DIRECT

    def test_route_auto_mode_react_low_confidence(self, router):
        """[AC-AISVC-RES-11,12] Auto mode should select react for low confidence."""
        router._config.rag_runtime_mode = RagRuntimeMode.AUTO
        router._config.react_trigger_confidence_threshold = 0.6
        ctx = StrategyContext(
            tenant_id="tenant_a",
            query="Test query",
            metadata_confidence=0.4,  # below the 0.6 trigger threshold
            complexity_score=0.2,
        )
        result = router.route(ctx)
        assert result.mode == RagRuntimeMode.REACT

    def test_route_auto_mode_react_high_complexity(self, router):
        """[AC-AISVC-RES-11,13] Auto mode should select react for high complexity."""
        router._config.rag_runtime_mode = RagRuntimeMode.AUTO
        router._config.react_trigger_complexity_score = 0.5
        ctx = StrategyContext(
            tenant_id="tenant_a",
            query="Test query",
            metadata_confidence=0.8,
            complexity_score=0.7,  # above the 0.5 trigger score
        )
        result = router.route(ctx)
        assert result.mode == RagRuntimeMode.REACT

    def test_update_config(self, router):
        """[AC-AISVC-RES-15] Should update configuration."""
        new_config = RoutingConfig(
            rag_runtime_mode=RagRuntimeMode.REACT,
            react_max_steps=7,
        )
        router.update_config(new_config)
        assert router.config.rag_runtime_mode == RagRuntimeMode.REACT
        assert router.config.react_max_steps == 7

    @pytest.mark.asyncio
    async def test_execute_direct(self, router):
        """[AC-AISVC-RES-09] Should execute direct retrieval."""
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        mock_result = MagicMock()
        mock_result.hits = []
        with patch.object(
            router._direct_executor, "execute", new_callable=AsyncMock
        ) as mock_execute:
            mock_execute.return_value = mock_result
            result = await router.execute_direct(ctx)
        assert result == mock_result
        # The result must actually come from the direct executor.
        mock_execute.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_execute_with_fallback_no_fallback(self, router):
        """[AC-AISVC-RES-14] Should not fallback when confidence is high."""
        router._config.rag_runtime_mode = RagRuntimeMode.DIRECT
        router._config.direct_fallback_on_low_confidence = True
        router._config.direct_fallback_confidence_threshold = 0.4
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        mock_result = MagicMock()
        mock_hit = MagicMock()
        mock_hit.score = 0.8  # above the 0.4 fallback threshold
        mock_result.hits = [mock_hit]
        with patch.object(
            router._direct_executor, "execute", new_callable=AsyncMock
        ) as mock_direct:
            mock_direct.return_value = mock_result
            # Patch ReAct too, so the test can prove it was never invoked.
            with patch.object(
                router._react_executor, "execute", new_callable=AsyncMock
            ) as mock_react:
                result, answer, route_result = await router.execute_with_fallback(ctx)
        assert result == mock_result
        assert answer is None
        assert route_result.mode == RagRuntimeMode.DIRECT
        assert route_result.should_fallback_to_react is False
        mock_direct.assert_awaited()
        mock_react.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_execute_with_fallback_triggered(self, router):
        """[AC-AISVC-RES-14] Should fallback to react when confidence is low."""
        router._config.rag_runtime_mode = RagRuntimeMode.DIRECT
        router._config.direct_fallback_on_low_confidence = True
        router._config.direct_fallback_confidence_threshold = 0.4
        ctx = StrategyContext(tenant_id="tenant_a", query="Test query")
        mock_result = MagicMock()
        mock_hit = MagicMock()
        mock_hit.score = 0.2  # below the 0.4 fallback threshold
        mock_result.hits = [mock_hit]
        with patch.object(
            router._direct_executor, "execute", new_callable=AsyncMock
        ) as mock_direct:
            mock_direct.return_value = mock_result
            with patch.object(
                router._react_executor, "execute", new_callable=AsyncMock
            ) as mock_react:
                mock_react.return_value = ("Final answer", None, {})
                result, answer, route_result = await router.execute_with_fallback(ctx)
        assert result is None
        assert answer == "Final answer"
        assert route_result.mode == RagRuntimeMode.REACT
        assert route_result.should_fallback_to_react is True
        assert route_result.fallback_reason == "low_confidence"
        # Direct ran first; its low score then triggered exactly one ReAct run.
        mock_direct.assert_awaited()
        mock_react.assert_awaited_once()
class TestSingletonInstances:
    """Tests for the module-level router singleton getters."""

    @pytest.fixture(autouse=True)
    def _isolate_singleton(self):
        """Reset the global router before and after every test in this class.

        Without this, ``test_reset_mode_router`` would start from whatever
        singleton an earlier test left behind. Both tests would also leave a
        live global router behind for the rest of the session.
        """
        reset_mode_router()
        yield
        reset_mode_router()

    def test_get_mode_router_singleton(self):
        """Should return same router instance."""
        router1 = get_mode_router()
        router2 = get_mode_router()
        assert router1 is router2

    def test_reset_mode_router(self):
        """Should create new instance after reset."""
        router1 = get_mode_router()
        reset_mode_router()
        router2 = get_mode_router()
        assert router1 is not router2