# Source path: ai-robot-core/ai-service/tests/test_retrieval_strategy_v2.py
#
# Listing metadata: 646 lines, 24 KiB, Python

"""
Unit tests for Retrieval Strategy Module.
[AC-AISVC-RES-01~15] Tests for strategy config, pipelines, routers, and rollback.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from dataclasses import asdict
from app.services.retrieval.strategy.config import (
FilterMode,
GrayscaleConfig,
HybridRetrievalConfig,
MetadataInferenceConfig,
ModeRouterConfig,
PipelineConfig,
RerankerConfig,
RetrievalStrategyConfig,
RuntimeMode,
StrategyType,
get_strategy_config,
set_strategy_config,
)
from app.services.retrieval.strategy.pipeline_base import (
BasePipeline,
MetadataFilterResult,
PipelineContext,
PipelineResult,
)
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline
from app.services.retrieval.strategy.strategy_router import (
RoutingDecision,
StrategyRouter,
get_strategy_router,
)
from app.services.retrieval.strategy.mode_router import (
ModeDecision,
ModeRouter,
get_mode_router,
)
from app.services.retrieval.strategy.rollback_manager import (
AuditLog,
RollbackManager,
RollbackResult,
RollbackTrigger,
get_rollback_manager,
)
from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult
class TestStrategyConfig:
    """[AC-AISVC-RES-01~15] Tests for strategy configuration models.

    Covers the enum values, the default field values of each config object,
    the grayscale / mode-router / metadata-inference decision helpers, and
    the module-level get/set accessors for the global strategy config.
    """

    def test_strategy_type_enum(self):
        """[AC-AISVC-RES-01] Strategy type should have default and enhanced values."""
        assert StrategyType.DEFAULT.value == "default"
        assert StrategyType.ENHANCED.value == "enhanced"

    def test_runtime_mode_enum(self):
        """[AC-AISVC-RES-09] Runtime mode should have direct, react, and auto values."""
        assert RuntimeMode.DIRECT.value == "direct"
        assert RuntimeMode.REACT.value == "react"
        assert RuntimeMode.AUTO.value == "auto"

    def test_filter_mode_enum(self):
        """[AC-AISVC-RES-04] Filter mode should have hard, soft, and none values."""
        assert FilterMode.HARD.value == "hard"
        assert FilterMode.SOFT.value == "soft"
        assert FilterMode.NONE.value == "none"

    def test_grayscale_config_default(self):
        """[AC-AISVC-RES-03] Default grayscale config should be disabled."""
        config = GrayscaleConfig()
        assert config.enabled is False
        assert config.percentage == 0.0
        assert config.allowlist == []

    def test_grayscale_config_should_use_enhanced_disabled(self):
        """[AC-AISVC-RES-03] Should not use enhanced when grayscale disabled."""
        # A non-zero percentage must be ignored while ``enabled`` is False.
        config = GrayscaleConfig(enabled=False, percentage=50.0)
        assert config.should_use_enhanced("tenant_a") is False

    def test_grayscale_config_should_use_enhanced_allowlist(self):
        """[AC-AISVC-RES-03] Should use enhanced for tenants in allowlist."""
        config = GrayscaleConfig(enabled=True, allowlist=["tenant_a", "tenant_b"])
        assert config.should_use_enhanced("tenant_a") is True
        assert config.should_use_enhanced("tenant_b") is True
        assert config.should_use_enhanced("tenant_c") is False

    def test_grayscale_config_should_use_enhanced_percentage(self):
        """[AC-AISVC-RES-03] Should use enhanced based on percentage."""
        # Only the two boundary percentages are checked, so the outcome does
        # not depend on how a tenant is bucketed.
        config = GrayscaleConfig(enabled=True, percentage=100.0)
        assert config.should_use_enhanced("any_tenant") is True
        config = GrayscaleConfig(enabled=True, percentage=0.0)
        assert config.should_use_enhanced("any_tenant") is False

    def test_reranker_config_default(self):
        """[AC-AISVC-RES-08] Default reranker config should be disabled."""
        config = RerankerConfig()
        assert config.enabled is False
        assert config.model == "cross-encoder"
        assert config.top_k_after_rerank == 5

    def test_mode_router_config_default(self):
        """[AC-AISVC-RES-09] Default mode router config should be direct."""
        config = ModeRouterConfig()
        assert config.runtime_mode == RuntimeMode.DIRECT
        assert config.react_trigger_confidence_threshold == 0.6
        assert config.react_max_steps == 5

    def test_mode_router_config_should_use_react_always(self):
        """[AC-AISVC-RES-10] React mode should always use react."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.REACT)
        assert config.should_use_react("any query") is True

    def test_mode_router_config_should_use_react_never(self):
        """[AC-AISVC-RES-09] Direct mode should never use react."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.DIRECT)
        assert config.should_use_react("any query") is False

    def test_mode_router_config_auto_short_query_high_confidence(self):
        """[AC-AISVC-RES-12] Auto mode with short query and high confidence should use direct."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
        assert config.should_use_react("短问题", confidence=0.8) is False

    def test_mode_router_config_auto_low_confidence(self):
        """[AC-AISVC-RES-13] Auto mode with low confidence should use react."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
        assert config.should_use_react("any query", confidence=0.3) is True

    def test_metadata_inference_config_determine_filter_mode(self):
        """[AC-AISVC-RES-04] Should determine filter mode based on confidence."""
        config = MetadataInferenceConfig()
        assert config.determine_filter_mode(0.9) == FilterMode.HARD
        assert config.determine_filter_mode(0.6) == FilterMode.SOFT
        assert config.determine_filter_mode(0.3) == FilterMode.NONE
        # A missing confidence is expected to behave like "no filter".
        assert config.determine_filter_mode(None) == FilterMode.NONE

    def test_pipeline_config_default(self):
        """[AC-AISVC-RES-01] Default pipeline config should have sensible defaults."""
        config = PipelineConfig()
        assert config.top_k == 5
        assert config.score_threshold == 0.01
        assert config.two_stage_enabled is True

    def test_retrieval_strategy_config_default(self):
        """[AC-AISVC-RES-01] Default strategy config should use default strategy."""
        config = RetrievalStrategyConfig()
        assert config.active_strategy == StrategyType.DEFAULT
        assert config.grayscale.enabled is False
        assert config.mode_router.runtime_mode == RuntimeMode.DIRECT

    def test_retrieval_strategy_config_is_enhanced_enabled(self):
        """[AC-AISVC-RES-02] Should check if enhanced is enabled."""
        # Case 1: enhanced is the globally active strategy.
        config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        assert config.is_enhanced_enabled("tenant_a") is True
        # Case 2: default strategy, but the tenant is on the grayscale allowlist.
        config = RetrievalStrategyConfig(
            active_strategy=StrategyType.DEFAULT,
            grayscale=GrayscaleConfig(enabled=True, allowlist=["tenant_a"]),
        )
        assert config.is_enhanced_enabled("tenant_a") is True
        assert config.is_enhanced_enabled("tenant_b") is False

    def test_retrieval_strategy_config_to_dict(self):
        """[AC-AISVC-RES-01] Should convert config to dictionary."""
        config = RetrievalStrategyConfig()
        d = config.to_dict()
        assert d["active_strategy"] == "default"
        assert "grayscale" in d
        assert "pipeline" in d
        assert "reranker" in d
        assert "mode_router" in d

    def test_global_config_functions(self):
        """[AC-AISVC-RES-01] Should get and set global config."""
        config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        try:
            set_strategy_config(config)
            retrieved = get_strategy_config()
            assert retrieved.active_strategy == StrategyType.ENHANCED
        finally:
            # Reset the process-wide config even when the assertion fails, so
            # an ENHANCED config cannot leak into tests that run afterwards.
            set_strategy_config(RetrievalStrategyConfig())
class TestPipelineBase:
    """[AC-AISVC-RES-01~02] Tests for pipeline base classes.

    Exercises the plain value objects shared by the pipelines: the metadata
    filter result, the pipeline context wrapper and the pipeline result
    wrapper.
    """

    def test_metadata_filter_result_default(self):
        """[AC-AISVC-RES-04] A freshly built filter result carries no filter."""
        empty_filter = MetadataFilterResult()

        assert empty_filter.filter_dict == {}
        assert empty_filter.filter_mode == FilterMode.NONE
        assert empty_filter.confidence is None

    def test_pipeline_context_properties(self):
        """[AC-AISVC-RES-01] Context proxies the wrapped retrieval context fields."""
        ctx = PipelineContext(
            retrieval_ctx=RetrievalContext(
                tenant_id="tenant_1",
                query="test query",
                session_id="session_1",
                kb_ids=["kb_1"],
            ),
        )

        assert ctx.tenant_id == "tenant_1"
        assert ctx.query == "test query"
        assert ctx.session_id == "session_1"
        assert ctx.kb_ids == ["kb_1"]

    def test_pipeline_result_properties(self):
        """[AC-AISVC-RES-01] Result proxies hits and reports its pipeline name."""
        first_hit = RetrievalHit(text="hit 1", score=0.9, source="test", metadata={})
        second_hit = RetrievalHit(text="hit 2", score=0.8, source="test", metadata={})
        expected_hits = [first_hit, second_hit]

        wrapped = PipelineResult(
            retrieval_result=RetrievalResult(hits=expected_hits),
            pipeline_name="test_pipeline",
        )

        assert wrapped.hits == expected_hits
        assert wrapped.is_empty is False
        assert wrapped.pipeline_name == "test_pipeline"

    def test_pipeline_result_is_empty(self):
        """[AC-AISVC-RES-01] A result without hits reports itself as empty."""
        no_hits = RetrievalResult(hits=[])
        wrapped = PipelineResult(retrieval_result=no_hits)

        assert wrapped.is_empty is True
class TestDefaultPipeline:
    """[AC-AISVC-RES-01] Tests for default pipeline.

    The pipeline is built around a mocked retriever, so these tests only
    check how the pipeline delegates to it and decorates the result.
    """

    @pytest.fixture
    def mock_retriever(self):
        """Build an async mock standing in for the optimized retriever."""
        canned_result = RetrievalResult(
            hits=[
                RetrievalHit(text="result 1", score=0.9, source="default", metadata={}),
            ],
            diagnostics={"test": True},
        )

        fake = AsyncMock()
        fake.retrieve = AsyncMock(return_value=canned_result)
        fake.health_check = AsyncMock(return_value=True)
        # Private flags set on the mock; presumably read by the pipeline
        # under test -- confirm against DefaultPipeline.
        fake._two_stage_enabled = True
        fake._hybrid_enabled = True
        return fake

    @pytest.fixture
    def pipeline(self, mock_retriever):
        """Build the default pipeline on top of the mocked retriever."""
        return DefaultPipeline(optimized_retriever=mock_retriever)

    def test_pipeline_name(self, pipeline):
        """[AC-AISVC-RES-01] The pipeline reports its fixed name."""
        assert pipeline.name == "default_pipeline"

    def test_pipeline_description(self, pipeline):
        """[AC-AISVC-RES-01] The description mentions the default strategy."""
        assert "默认" in pipeline.description

    @pytest.mark.asyncio
    async def test_retrieve(self, pipeline, mock_retriever):
        """[AC-AISVC-RES-01] Retrieval is delegated to the optimized retriever."""
        ctx = PipelineContext(
            retrieval_ctx=RetrievalContext(
                tenant_id="tenant_1",
                query="test query",
            ),
        )

        outcome = await pipeline.retrieve(ctx)

        assert outcome.pipeline_name == "default_pipeline"
        assert len(outcome.hits) == 1
        assert outcome.diagnostics["retriever"] == "OptimizedRetriever"
        mock_retriever.retrieve.assert_called_once()

    @pytest.mark.asyncio
    async def test_retrieve_with_metadata_filter(self, pipeline, mock_retriever):
        """[AC-AISVC-RES-04] The metadata filter is forwarded to the retriever."""
        hard_filter = MetadataFilterResult(
            filter_dict={"grade": "初一"},
            filter_mode=FilterMode.HARD,
        )
        base_ctx = RetrievalContext(
            tenant_id="tenant_1",
            query="test query",
        )
        ctx = PipelineContext(
            retrieval_ctx=base_ctx,
            metadata_filter=hard_filter,
        )

        outcome = await pipeline.retrieve(ctx)

        assert outcome.metadata_filter_applied is True
        # Inspect the first positional argument handed to the retriever.
        forwarded_ctx = mock_retriever.retrieve.call_args.args[0]
        assert forwarded_ctx.metadata_filter == {"grade": "初一"}

    @pytest.mark.asyncio
    async def test_health_check(self, pipeline, mock_retriever):
        """[AC-AISVC-RES-01] Health check is delegated to the retriever."""
        healthy = await pipeline.health_check()

        assert healthy is True
        mock_retriever.health_check.assert_called_once()
class TestEnhancedPipeline:
    """[AC-AISVC-RES-02] Tests for enhanced pipeline.

    Both the Qdrant client and the embedding provider are replaced by
    async mocks, so no external service is contacted.
    """

    @pytest.fixture
    def mock_qdrant_client(self):
        """Build an async mock standing in for the Qdrant client."""
        canned_points = [
            {"id": "1", "score": 0.9, "payload": {"text": "result 1"}},
        ]

        fake_client = AsyncMock()
        fake_client.search = AsyncMock(return_value=canned_points)
        fake_client.get_client = AsyncMock()
        return fake_client

    @pytest.fixture
    def mock_embedding_provider(self):
        """Build an async mock standing in for the embedding provider."""
        vector = [0.1] * 768

        fake_provider = AsyncMock()
        fake_provider.embed_query = AsyncMock(
            return_value=MagicMock(embedding_full=[0.1] * 768),
        )
        fake_provider.embed = AsyncMock(return_value=vector)
        return fake_provider

    @pytest.fixture
    def pipeline(self, mock_qdrant_client, mock_embedding_provider):
        """Build the enhanced pipeline wired to the two mocks."""
        enhanced = EnhancedPipeline(qdrant_client=mock_qdrant_client)
        # Inject the provider through the private attribute set by the
        # original test; the constructor is only given the client.
        enhanced._embedding_provider = mock_embedding_provider
        return enhanced

    def test_pipeline_name(self, pipeline):
        """[AC-AISVC-RES-02] The pipeline reports its fixed name."""
        assert pipeline.name == "enhanced_pipeline"

    def test_pipeline_description(self, pipeline):
        """[AC-AISVC-RES-02] The description mentions the enhanced strategy."""
        assert "增强" in pipeline.description

    @pytest.mark.asyncio
    async def test_retrieve_basic(self, pipeline):
        """[AC-AISVC-RES-02] A basic retrieval returns a named result with diagnostics."""
        ctx = PipelineContext(
            retrieval_ctx=RetrievalContext(
                tenant_id="tenant_1",
                query="test query",
            ),
        )

        outcome = await pipeline.retrieve(ctx)

        assert outcome.pipeline_name == "enhanced_pipeline"
        assert outcome.diagnostics is not None
class TestStrategyRouter:
    """[AC-AISVC-RES-01~03] Tests for strategy router.

    The routing tests are ``pytest.mark.asyncio`` coroutines that await
    ``router.route`` directly, matching the other async tests in this
    module. They intentionally do not call
    ``asyncio.get_event_loop().run_until_complete``: that call is deprecated
    when no loop is current and can pick up a loop that an earlier async
    test has already closed.
    """

    @pytest.fixture
    def mock_default_pipeline(self):
        """Create a mock default pipeline."""
        pipeline = AsyncMock(spec=DefaultPipeline)
        pipeline.name = "default_pipeline"
        pipeline.retrieve = AsyncMock(return_value=PipelineResult(
            retrieval_result=RetrievalResult(hits=[]),
            pipeline_name="default_pipeline",
        ))
        return pipeline

    @pytest.fixture
    def mock_enhanced_pipeline(self):
        """Create a mock enhanced pipeline."""
        pipeline = AsyncMock(spec=EnhancedPipeline)
        pipeline.name = "enhanced_pipeline"
        pipeline.retrieve = AsyncMock(return_value=PipelineResult(
            retrieval_result=RetrievalResult(hits=[]),
            pipeline_name="enhanced_pipeline",
        ))
        return pipeline

    @pytest.fixture
    def router(self, mock_default_pipeline, mock_enhanced_pipeline):
        """Create a strategy router with mock pipelines and a default config."""
        config = RetrievalStrategyConfig()
        return StrategyRouter(
            config=config,
            default_pipeline=mock_default_pipeline,
            enhanced_pipeline=mock_enhanced_pipeline,
        )

    @pytest.mark.asyncio
    async def test_route_default_strategy(self, router):
        """[AC-AISVC-RES-01] Should route to default strategy by default."""
        decision = await router.route("tenant_1")
        assert decision.strategy == StrategyType.DEFAULT
        assert decision.reason == "default_strategy"

    @pytest.mark.asyncio
    async def test_route_enhanced_strategy(self, mock_default_pipeline, mock_enhanced_pipeline):
        """[AC-AISVC-RES-02] Should route to enhanced strategy when configured."""
        config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        router = StrategyRouter(
            config=config,
            default_pipeline=mock_default_pipeline,
            enhanced_pipeline=mock_enhanced_pipeline,
        )
        decision = await router.route("tenant_1")
        assert decision.strategy == StrategyType.ENHANCED
        assert decision.reason == "active_strategy=enhanced"

    @pytest.mark.asyncio
    async def test_route_grayscale_allowlist(self, mock_default_pipeline, mock_enhanced_pipeline):
        """[AC-AISVC-RES-03] Should route to enhanced for allowlist tenants."""
        config = RetrievalStrategyConfig(
            active_strategy=StrategyType.DEFAULT,
            grayscale=GrayscaleConfig(enabled=True, allowlist=["tenant_a"]),
        )
        router = StrategyRouter(
            config=config,
            default_pipeline=mock_default_pipeline,
            enhanced_pipeline=mock_enhanced_pipeline,
        )
        # Allowlisted tenant: routed to enhanced and flagged as a grayscale hit.
        decision = await router.route("tenant_a")
        assert decision.strategy == StrategyType.ENHANCED
        assert decision.grayscale_hit is True
        # Any other tenant stays on the default strategy.
        decision = await router.route("tenant_b")
        assert decision.strategy == StrategyType.DEFAULT

    def test_update_config(self, router):
        """[AC-AISVC-RES-02] Should update config."""
        new_config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        router.update_config(new_config)
        assert router.get_config().active_strategy == StrategyType.ENHANCED
class TestModeRouter:
    """[AC-AISVC-RES-09~15] Tests for mode router.

    Checks the mode decision for each configured runtime mode and the
    conditions under which a direct result falls back to react.
    """

    @pytest.fixture
    def router(self):
        """Build a mode router that uses its default configuration."""
        return ModeRouter()

    def test_decide_react_mode(self):
        """[AC-AISVC-RES-10] A react-configured router decides react."""
        react_router = ModeRouter(ModeRouterConfig(runtime_mode=RuntimeMode.REACT))

        outcome = react_router.decide("any query")

        assert outcome.mode == RuntimeMode.REACT
        assert outcome.reason == "runtime_mode=react"

    def test_decide_direct_mode(self, router):
        """[AC-AISVC-RES-09] The default router decides direct."""
        outcome = router.decide("any query")

        assert outcome.mode == RuntimeMode.DIRECT
        assert outcome.reason == "runtime_mode=direct"

    def test_decide_auto_short_query_high_confidence(self):
        """[AC-AISVC-RES-12] Auto mode keeps a short, confident query on direct."""
        auto_router = ModeRouter(ModeRouterConfig(runtime_mode=RuntimeMode.AUTO))

        outcome = auto_router.decide("短问题", confidence=0.8)

        assert outcome.mode == RuntimeMode.DIRECT

    def test_decide_auto_low_confidence(self):
        """[AC-AISVC-RES-13] Auto mode sends a low-confidence query to react."""
        auto_router = ModeRouter(ModeRouterConfig(runtime_mode=RuntimeMode.AUTO))

        outcome = auto_router.decide("any query", confidence=0.3)

        assert outcome.mode == RuntimeMode.REACT

    def test_should_fallback_to_react_empty_results(self, router):
        """[AC-AISVC-RES-14] An empty result triggers the react fallback."""
        empty = PipelineResult(retrieval_result=RetrievalResult(hits=[]))

        assert router.should_fallback_to_react(empty) is True

    def test_should_fallback_to_react_low_score(self, router):
        """[AC-AISVC-RES-14] A single low-scoring hit triggers the react fallback."""
        weak_hit = RetrievalHit(text="test", score=0.1, source="test", metadata={})
        weak = PipelineResult(
            retrieval_result=RetrievalResult(hits=[weak_hit]),
        )

        assert router.should_fallback_to_react(weak) is True

    def test_should_not_fallback_to_react_disabled(self):
        """[AC-AISVC-RES-14] No fallback happens once the option is switched off."""
        no_fallback_router = ModeRouter(
            ModeRouterConfig(direct_fallback_on_low_confidence=False),
        )
        empty = PipelineResult(retrieval_result=RetrievalResult(hits=[]))

        assert no_fallback_router.should_fallback_to_react(empty) is False
class TestRollbackManager:
    """[AC-AISVC-RES-07] Tests for rollback manager.

    Covers manual rollback, metric-driven rollback and audit logging.
    """

    @pytest.fixture
    def manager(self):
        """Build a rollback manager with its default state."""
        return RollbackManager()

    @staticmethod
    def _switch_to_enhanced(manager):
        """Put *manager* on the enhanced strategy so a rollback is possible."""
        manager.update_config(
            RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED),
        )

    def test_rollback_from_enhanced(self, manager):
        """[AC-AISVC-RES-07] A manual rollback moves enhanced back to default."""
        self._switch_to_enhanced(manager)

        outcome = manager.rollback(
            trigger=RollbackTrigger.MANUAL,
            reason="Testing rollback",
        )

        assert outcome.success is True
        assert outcome.previous_strategy == StrategyType.ENHANCED
        assert outcome.current_strategy == StrategyType.DEFAULT
        assert outcome.audit_log is not None

    def test_rollback_already_default(self, manager):
        """[AC-AISVC-RES-07] Rolling back while on default is reported as a failure."""
        outcome = manager.rollback(
            trigger=RollbackTrigger.MANUAL,
            reason="Testing rollback",
        )

        assert outcome.success is False
        assert outcome.reason == "Already on default strategy"

    def test_check_and_rollback_latency(self, manager):
        """[AC-AISVC-RES-08] High latency produces a performance-triggered rollback."""
        self._switch_to_enhanced(manager)
        slow_metrics = {"latency_ms": 3000.0}

        outcome = manager.check_and_rollback(
            metrics=slow_metrics,
            tenant_id="tenant_1",
        )

        assert outcome is not None
        assert outcome.trigger == RollbackTrigger.PERFORMANCE

    def test_check_and_rollback_error_rate(self, manager):
        """[AC-AISVC-RES-08] A high error rate produces an error-triggered rollback."""
        self._switch_to_enhanced(manager)
        failing_metrics = {"error_rate": 0.1}

        outcome = manager.check_and_rollback(
            metrics=failing_metrics,
            tenant_id="tenant_1",
        )

        assert outcome is not None
        assert outcome.trigger == RollbackTrigger.ERROR

    def test_check_and_rollback_ok(self, manager):
        """[AC-AISVC-RES-08] Healthy metrics leave the strategy untouched."""
        self._switch_to_enhanced(manager)
        healthy_metrics = {"latency_ms": 100.0, "error_rate": 0.01}

        outcome = manager.check_and_rollback(
            metrics=healthy_metrics,
            tenant_id="tenant_1",
        )

        assert outcome is None

    def test_get_audit_logs(self, manager):
        """[AC-AISVC-RES-07] A rollback leaves exactly one audit entry."""
        self._switch_to_enhanced(manager)
        manager.rollback(trigger=RollbackTrigger.MANUAL, reason="Test")

        entries = manager.get_audit_logs()

        assert len(entries) == 1
        assert entries[0].action == "rollback"

    def test_record_audit(self, manager):
        """[AC-AISVC-RES-07] A recorded audit entry echoes its action and tenant."""
        entry = manager.record_audit(
            action="test_action",
            details={"reason": "Testing"},
            tenant_id="tenant_1",
        )

        assert entry.action == "test_action"
        assert entry.tenant_id == "tenant_1"
class TestSingletonInstances:
    """Tests for singleton instance getters.

    Each test clears the module-level singleton, checks that two consecutive
    getter calls return the same object, and then restores whatever instance
    was installed before the test so no state leaks into other tests.
    """

    def test_get_mode_router_singleton(self):
        """Should return same mode router instance."""
        import app.services.retrieval.strategy.mode_router as module

        # Reading the attribute also verifies the private global still exists.
        previous = module._mode_router
        module._mode_router = None
        try:
            router1 = get_mode_router()
            router2 = get_mode_router()
            assert router1 is router2
        finally:
            # Put the original singleton back, pass or fail.
            module._mode_router = previous

    def test_get_rollback_manager_singleton(self):
        """Should return same rollback manager instance."""
        import app.services.retrieval.strategy.rollback_manager as module

        # Reading the attribute also verifies the private global still exists.
        previous = module._rollback_manager
        module._rollback_manager = None
        try:
            manager1 = get_rollback_manager()
            manager2 = get_rollback_manager()
            assert manager1 is manager2
        finally:
            # Put the original singleton back, pass or fail.
            module._rollback_manager = previous