"""
Unit tests for Retrieval Strategy Module.

[AC-AISVC-RES-01~15] Tests for strategy config, pipelines, routers, and rollback.
"""

import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from dataclasses import asdict
|
|
|
|
from app.services.retrieval.strategy.config import (
|
|
FilterMode,
|
|
GrayscaleConfig,
|
|
HybridRetrievalConfig,
|
|
MetadataInferenceConfig,
|
|
ModeRouterConfig,
|
|
PipelineConfig,
|
|
RerankerConfig,
|
|
RetrievalStrategyConfig,
|
|
RuntimeMode,
|
|
StrategyType,
|
|
get_strategy_config,
|
|
set_strategy_config,
|
|
)
|
|
from app.services.retrieval.strategy.pipeline_base import (
|
|
BasePipeline,
|
|
MetadataFilterResult,
|
|
PipelineContext,
|
|
PipelineResult,
|
|
)
|
|
from app.services.retrieval.strategy.default_pipeline import DefaultPipeline
|
|
from app.services.retrieval.strategy.enhanced_pipeline import EnhancedPipeline
|
|
from app.services.retrieval.strategy.strategy_router import (
|
|
RoutingDecision,
|
|
StrategyRouter,
|
|
get_strategy_router,
|
|
)
|
|
from app.services.retrieval.strategy.mode_router import (
|
|
ModeDecision,
|
|
ModeRouter,
|
|
get_mode_router,
|
|
)
|
|
from app.services.retrieval.strategy.rollback_manager import (
|
|
AuditLog,
|
|
RollbackManager,
|
|
RollbackResult,
|
|
RollbackTrigger,
|
|
get_rollback_manager,
|
|
)
|
|
from app.services.retrieval.base import RetrievalContext, RetrievalHit, RetrievalResult
|
|
|
|
|
|
class TestStrategyConfig:
    """[AC-AISVC-RES-01~15] Tests for strategy configuration models."""

    def test_strategy_type_enum(self):
        """[AC-AISVC-RES-01] Strategy type should have default and enhanced values."""
        assert StrategyType.DEFAULT.value == "default"
        assert StrategyType.ENHANCED.value == "enhanced"

    def test_runtime_mode_enum(self):
        """[AC-AISVC-RES-09] Runtime mode should have direct, react, and auto values."""
        assert RuntimeMode.DIRECT.value == "direct"
        assert RuntimeMode.REACT.value == "react"
        assert RuntimeMode.AUTO.value == "auto"

    def test_filter_mode_enum(self):
        """[AC-AISVC-RES-04] Filter mode should have hard, soft, and none values."""
        assert FilterMode.HARD.value == "hard"
        assert FilterMode.SOFT.value == "soft"
        assert FilterMode.NONE.value == "none"

    def test_grayscale_config_default(self):
        """[AC-AISVC-RES-03] Default grayscale config should be disabled."""
        config = GrayscaleConfig()
        assert config.enabled is False
        assert config.percentage == 0.0
        assert config.allowlist == []

    def test_grayscale_config_should_use_enhanced_disabled(self):
        """[AC-AISVC-RES-03] Should not use enhanced when grayscale disabled."""
        config = GrayscaleConfig(enabled=False, percentage=50.0)
        assert config.should_use_enhanced("tenant_a") is False

    def test_grayscale_config_should_use_enhanced_allowlist(self):
        """[AC-AISVC-RES-03] Should use enhanced for tenants in allowlist."""
        config = GrayscaleConfig(enabled=True, allowlist=["tenant_a", "tenant_b"])
        assert config.should_use_enhanced("tenant_a") is True
        assert config.should_use_enhanced("tenant_b") is True
        assert config.should_use_enhanced("tenant_c") is False

    def test_grayscale_config_should_use_enhanced_percentage(self):
        """[AC-AISVC-RES-03] Should use enhanced based on percentage."""
        config = GrayscaleConfig(enabled=True, percentage=100.0)
        assert config.should_use_enhanced("any_tenant") is True

        config = GrayscaleConfig(enabled=True, percentage=0.0)
        assert config.should_use_enhanced("any_tenant") is False

    def test_reranker_config_default(self):
        """[AC-AISVC-RES-08] Default reranker config should be disabled."""
        config = RerankerConfig()
        assert config.enabled is False
        assert config.model == "cross-encoder"
        assert config.top_k_after_rerank == 5

    def test_mode_router_config_default(self):
        """[AC-AISVC-RES-09] Default mode router config should be direct."""
        config = ModeRouterConfig()
        assert config.runtime_mode == RuntimeMode.DIRECT
        assert config.react_trigger_confidence_threshold == 0.6
        assert config.react_max_steps == 5

    def test_mode_router_config_should_use_react_always(self):
        """[AC-AISVC-RES-10] React mode should always use react."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.REACT)
        assert config.should_use_react("any query") is True

    def test_mode_router_config_should_use_react_never(self):
        """[AC-AISVC-RES-09] Direct mode should never use react."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.DIRECT)
        assert config.should_use_react("any query") is False

    def test_mode_router_config_auto_short_query_high_confidence(self):
        """[AC-AISVC-RES-12] Auto mode with short query and high confidence should use direct."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
        # The query literal is Chinese for "short question".
        assert config.should_use_react("短问题", confidence=0.8) is False

    def test_mode_router_config_auto_low_confidence(self):
        """[AC-AISVC-RES-13] Auto mode with low confidence should use react."""
        config = ModeRouterConfig(runtime_mode=RuntimeMode.AUTO)
        assert config.should_use_react("any query", confidence=0.3) is True

    def test_metadata_inference_config_determine_filter_mode(self):
        """[AC-AISVC-RES-04] Should determine filter mode based on confidence."""
        config = MetadataInferenceConfig()

        assert config.determine_filter_mode(0.9) == FilterMode.HARD
        assert config.determine_filter_mode(0.6) == FilterMode.SOFT
        assert config.determine_filter_mode(0.3) == FilterMode.NONE
        assert config.determine_filter_mode(None) == FilterMode.NONE

    def test_pipeline_config_default(self):
        """[AC-AISVC-RES-01] Default pipeline config should have sensible defaults."""
        config = PipelineConfig()
        assert config.top_k == 5
        assert config.score_threshold == 0.01
        assert config.two_stage_enabled is True

    def test_retrieval_strategy_config_default(self):
        """[AC-AISVC-RES-01] Default strategy config should use default strategy."""
        config = RetrievalStrategyConfig()
        assert config.active_strategy == StrategyType.DEFAULT
        assert config.grayscale.enabled is False
        assert config.mode_router.runtime_mode == RuntimeMode.DIRECT

    def test_retrieval_strategy_config_is_enhanced_enabled(self):
        """[AC-AISVC-RES-02] Should check if enhanced is enabled."""
        config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        assert config.is_enhanced_enabled("tenant_a") is True

        config = RetrievalStrategyConfig(
            active_strategy=StrategyType.DEFAULT,
            grayscale=GrayscaleConfig(enabled=True, allowlist=["tenant_a"]),
        )
        assert config.is_enhanced_enabled("tenant_a") is True
        assert config.is_enhanced_enabled("tenant_b") is False

    def test_retrieval_strategy_config_to_dict(self):
        """[AC-AISVC-RES-01] Should convert config to dictionary."""
        config = RetrievalStrategyConfig()
        d = config.to_dict()

        assert d["active_strategy"] == "default"
        assert "grayscale" in d
        assert "pipeline" in d
        assert "reranker" in d
        assert "mode_router" in d

    def test_global_config_functions(self):
        """[AC-AISVC-RES-01] Should get and set global config.

        The global config is process-wide. The previous value is restored in a
        ``finally`` block so a failing assertion here cannot leak the ENHANCED
        strategy into tests that run afterwards.
        """
        original = get_strategy_config()
        try:
            config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
            set_strategy_config(config)

            retrieved = get_strategy_config()
            assert retrieved.active_strategy == StrategyType.ENHANCED
        finally:
            set_strategy_config(original)


class TestPipelineBase:
    """[AC-AISVC-RES-01~02] Checks for the shared pipeline data holders."""

    def test_metadata_filter_result_default(self):
        """[AC-AISVC-RES-04] A metadata filter result built with no arguments applies no filter."""
        empty_filter = MetadataFilterResult()

        assert empty_filter.filter_dict == {}
        assert empty_filter.filter_mode == FilterMode.NONE
        assert empty_filter.confidence is None

    def test_pipeline_context_properties(self):
        """[AC-AISVC-RES-01] PipelineContext exposes the fields of the RetrievalContext it wraps."""
        ctx = PipelineContext(
            retrieval_ctx=RetrievalContext(
                tenant_id="tenant_1",
                query="test query",
                session_id="session_1",
                kb_ids=["kb_1"],
            )
        )

        assert ctx.tenant_id == "tenant_1"
        assert ctx.query == "test query"
        assert ctx.session_id == "session_1"
        assert ctx.kb_ids == ["kb_1"]

    def test_pipeline_result_properties(self):
        """[AC-AISVC-RES-01] PipelineResult exposes the hits of the RetrievalResult it wraps."""
        expected_hits = [
            RetrievalHit(text=f"hit {rank}", score=score, source="test", metadata={})
            for rank, score in ((1, 0.9), (2, 0.8))
        ]
        wrapped = PipelineResult(
            retrieval_result=RetrievalResult(hits=expected_hits),
            pipeline_name="test_pipeline",
        )

        assert wrapped.hits == expected_hits
        assert wrapped.is_empty is False
        assert wrapped.pipeline_name == "test_pipeline"

    def test_pipeline_result_is_empty(self):
        """[AC-AISVC-RES-01] A PipelineResult with no hits reports itself as empty."""
        wrapped = PipelineResult(retrieval_result=RetrievalResult(hits=[]))

        assert wrapped.is_empty is True


class TestDefaultPipeline:
    """[AC-AISVC-RES-01] Behaviour of the default (optimized-retriever) pipeline."""

    @pytest.fixture
    def mock_retriever(self):
        """Async stand-in for the optimized retriever, returning one canned hit."""
        canned = RetrievalResult(
            hits=[RetrievalHit(text="result 1", score=0.9, source="default", metadata={})],
            diagnostics={"test": True},
        )

        fake = AsyncMock()
        fake.retrieve = AsyncMock(return_value=canned)
        fake.health_check = AsyncMock(return_value=True)
        fake._two_stage_enabled = True
        fake._hybrid_enabled = True
        return fake

    @pytest.fixture
    def pipeline(self, mock_retriever):
        """DefaultPipeline wired to the mocked retriever."""
        return DefaultPipeline(optimized_retriever=mock_retriever)

    def test_pipeline_name(self, pipeline):
        """[AC-AISVC-RES-01] The pipeline reports its canonical name."""
        assert pipeline.name == "default_pipeline"

    def test_pipeline_description(self, pipeline):
        """[AC-AISVC-RES-01] The description contains the Chinese word for "default"."""
        assert "默认" in pipeline.description

    @pytest.mark.asyncio
    async def test_retrieve(self, pipeline, mock_retriever):
        """[AC-AISVC-RES-01] Retrieval is delegated to the optimized retriever."""
        ctx = PipelineContext(
            retrieval_ctx=RetrievalContext(tenant_id="tenant_1", query="test query")
        )

        outcome = await pipeline.retrieve(ctx)

        assert outcome.pipeline_name == "default_pipeline"
        assert len(outcome.hits) == 1
        assert outcome.diagnostics["retriever"] == "OptimizedRetriever"
        mock_retriever.retrieve.assert_called_once()

    @pytest.mark.asyncio
    async def test_retrieve_with_metadata_filter(self, pipeline, mock_retriever):
        """[AC-AISVC-RES-04] A hard metadata filter is forwarded to the retriever."""
        # The filter value is Chinese for "first year of junior high".
        ctx = PipelineContext(
            retrieval_ctx=RetrievalContext(tenant_id="tenant_1", query="test query"),
            metadata_filter=MetadataFilterResult(
                filter_dict={"grade": "初一"},
                filter_mode=FilterMode.HARD,
            ),
        )

        outcome = await pipeline.retrieve(ctx)

        assert outcome.metadata_filter_applied is True
        forwarded_ctx = mock_retriever.retrieve.call_args.args[0]
        assert forwarded_ctx.metadata_filter == {"grade": "初一"}

    @pytest.mark.asyncio
    async def test_health_check(self, pipeline, mock_retriever):
        """[AC-AISVC-RES-01] Health check is delegated to the retriever."""
        assert await pipeline.health_check() is True
        mock_retriever.health_check.assert_called_once()


class TestEnhancedPipeline:
    """[AC-AISVC-RES-02] Behaviour of the enhanced pipeline."""

    @pytest.fixture
    def mock_qdrant_client(self):
        """Async Qdrant stand-in whose search yields a single scored point."""
        qdrant = AsyncMock()
        qdrant.search = AsyncMock(
            return_value=[{"id": "1", "score": 0.9, "payload": {"text": "result 1"}}]
        )
        qdrant.get_client = AsyncMock()
        return qdrant

    @pytest.fixture
    def mock_embedding_provider(self):
        """Async embedding stand-in that always returns constant 768-length vectors."""
        embedder = AsyncMock()
        embedder.embed_query = AsyncMock(
            return_value=MagicMock(embedding_full=[0.1] * 768)
        )
        embedder.embed = AsyncMock(return_value=[0.1] * 768)
        return embedder

    @pytest.fixture
    def pipeline(self, mock_qdrant_client, mock_embedding_provider):
        """EnhancedPipeline using the fake Qdrant client and fake embedder."""
        enhanced = EnhancedPipeline(qdrant_client=mock_qdrant_client)
        # The embedder is injected by overwriting the private attribute directly.
        enhanced._embedding_provider = mock_embedding_provider
        return enhanced

    def test_pipeline_name(self, pipeline):
        """[AC-AISVC-RES-02] The pipeline reports its canonical name."""
        assert pipeline.name == "enhanced_pipeline"

    def test_pipeline_description(self, pipeline):
        """[AC-AISVC-RES-02] The description contains the Chinese word for "enhanced"."""
        assert "增强" in pipeline.description

    @pytest.mark.asyncio
    async def test_retrieve_basic(self, pipeline):
        """[AC-AISVC-RES-02] A plain query runs through the enhanced pipeline."""
        ctx = PipelineContext(
            retrieval_ctx=RetrievalContext(tenant_id="tenant_1", query="test query")
        )

        outcome = await pipeline.retrieve(ctx)

        assert outcome.pipeline_name == "enhanced_pipeline"
        assert outcome.diagnostics is not None


class TestStrategyRouter:
    """[AC-AISVC-RES-01~03] Tests for strategy router.

    ``StrategyRouter.route`` returns an awaitable, so the routing tests are
    native ``@pytest.mark.asyncio`` coroutines, matching the pipeline tests in
    this module. They previously drove the coroutine through
    ``asyncio.get_event_loop().run_until_complete``. That pattern is deprecated
    when no event loop is running and becomes an error in newer Python versions.
    """

    @staticmethod
    def _make_router(config, default_pipeline, enhanced_pipeline):
        """Build a StrategyRouter over *config* and the given (mock) pipelines."""
        return StrategyRouter(
            config=config,
            default_pipeline=default_pipeline,
            enhanced_pipeline=enhanced_pipeline,
        )

    @pytest.fixture
    def mock_default_pipeline(self):
        """Create a mock default pipeline."""
        pipeline = AsyncMock(spec=DefaultPipeline)
        pipeline.name = "default_pipeline"
        pipeline.retrieve = AsyncMock(return_value=PipelineResult(
            retrieval_result=RetrievalResult(hits=[]),
            pipeline_name="default_pipeline",
        ))
        return pipeline

    @pytest.fixture
    def mock_enhanced_pipeline(self):
        """Create a mock enhanced pipeline."""
        pipeline = AsyncMock(spec=EnhancedPipeline)
        pipeline.name = "enhanced_pipeline"
        pipeline.retrieve = AsyncMock(return_value=PipelineResult(
            retrieval_result=RetrievalResult(hits=[]),
            pipeline_name="enhanced_pipeline",
        ))
        return pipeline

    @pytest.fixture
    def router(self, mock_default_pipeline, mock_enhanced_pipeline):
        """Create a strategy router with mock pipelines and the default config."""
        return self._make_router(
            RetrievalStrategyConfig(),
            mock_default_pipeline,
            mock_enhanced_pipeline,
        )

    @pytest.mark.asyncio
    async def test_route_default_strategy(self, router):
        """[AC-AISVC-RES-01] Should route to default strategy by default."""
        decision = await router.route("tenant_1")

        assert decision.strategy == StrategyType.DEFAULT
        assert decision.reason == "default_strategy"

    @pytest.mark.asyncio
    async def test_route_enhanced_strategy(self, mock_default_pipeline, mock_enhanced_pipeline):
        """[AC-AISVC-RES-02] Should route to enhanced strategy when configured."""
        router = self._make_router(
            RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED),
            mock_default_pipeline,
            mock_enhanced_pipeline,
        )

        decision = await router.route("tenant_1")

        assert decision.strategy == StrategyType.ENHANCED
        assert decision.reason == "active_strategy=enhanced"

    @pytest.mark.asyncio
    async def test_route_grayscale_allowlist(self, mock_default_pipeline, mock_enhanced_pipeline):
        """[AC-AISVC-RES-03] Should route to enhanced for allowlist tenants."""
        config = RetrievalStrategyConfig(
            active_strategy=StrategyType.DEFAULT,
            grayscale=GrayscaleConfig(enabled=True, allowlist=["tenant_a"]),
        )
        router = self._make_router(config, mock_default_pipeline, mock_enhanced_pipeline)

        decision = await router.route("tenant_a")
        assert decision.strategy == StrategyType.ENHANCED
        assert decision.grayscale_hit is True

        decision = await router.route("tenant_b")
        assert decision.strategy == StrategyType.DEFAULT

    def test_update_config(self, router):
        """[AC-AISVC-RES-02] Should update config."""
        new_config = RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        router.update_config(new_config)

        assert router.get_config().active_strategy == StrategyType.ENHANCED


class TestModeRouter:
    """[AC-AISVC-RES-09~15] Decisions made by the direct/react mode router."""

    @pytest.fixture
    def router(self):
        """Mode router built with its default configuration."""
        return ModeRouter()

    def test_decide_react_mode(self):
        """[AC-AISVC-RES-10] A router configured for REACT always picks react."""
        react_router = ModeRouter(ModeRouterConfig(runtime_mode=RuntimeMode.REACT))

        decision = react_router.decide("any query")

        assert decision.mode == RuntimeMode.REACT
        assert decision.reason == "runtime_mode=react"

    def test_decide_direct_mode(self, router):
        """[AC-AISVC-RES-09] The default router picks direct."""
        decision = router.decide("any query")

        assert decision.mode == RuntimeMode.DIRECT
        assert decision.reason == "runtime_mode=direct"

    def test_decide_auto_short_query_high_confidence(self):
        """[AC-AISVC-RES-12] AUTO picks direct for a short, high-confidence query."""
        auto_router = ModeRouter(ModeRouterConfig(runtime_mode=RuntimeMode.AUTO))

        # The query literal is Chinese for "short question".
        assert auto_router.decide("短问题", confidence=0.8).mode == RuntimeMode.DIRECT

    def test_decide_auto_low_confidence(self):
        """[AC-AISVC-RES-13] AUTO picks react when confidence is low."""
        auto_router = ModeRouter(ModeRouterConfig(runtime_mode=RuntimeMode.AUTO))

        assert auto_router.decide("any query", confidence=0.3).mode == RuntimeMode.REACT

    def test_should_fallback_to_react_empty_results(self, router):
        """[AC-AISVC-RES-14] An empty direct result triggers the react fallback."""
        empty = PipelineResult(retrieval_result=RetrievalResult(hits=[]))

        assert router.should_fallback_to_react(empty) is True

    def test_should_fallback_to_react_low_score(self, router):
        """[AC-AISVC-RES-14] A low-scoring direct result triggers the react fallback."""
        weak_hit = RetrievalHit(text="test", score=0.1, source="test", metadata={})
        weak = PipelineResult(retrieval_result=RetrievalResult(hits=[weak_hit]))

        assert router.should_fallback_to_react(weak) is True

    def test_should_not_fallback_to_react_disabled(self):
        """[AC-AISVC-RES-14] No fallback happens when it is switched off in config."""
        no_fallback_router = ModeRouter(
            ModeRouterConfig(direct_fallback_on_low_confidence=False)
        )
        empty = PipelineResult(retrieval_result=RetrievalResult(hits=[]))

        assert no_fallback_router.should_fallback_to_react(empty) is False


class TestRollbackManager:
    """[AC-AISVC-RES-07] Rollback and audit behaviour of RollbackManager."""

    @pytest.fixture
    def manager(self):
        """A fresh rollback manager for every test."""
        return RollbackManager()

    @staticmethod
    def _activate_enhanced(manager):
        """Switch *manager* to the enhanced strategy so a rollback has something to undo."""
        manager.update_config(
            RetrievalStrategyConfig(active_strategy=StrategyType.ENHANCED)
        )

    def test_rollback_from_enhanced(self, manager):
        """[AC-AISVC-RES-07] Rolling back from enhanced lands on default and is audited."""
        self._activate_enhanced(manager)

        outcome = manager.rollback(
            trigger=RollbackTrigger.MANUAL,
            reason="Testing rollback",
        )

        assert outcome.success is True
        assert outcome.previous_strategy == StrategyType.ENHANCED
        assert outcome.current_strategy == StrategyType.DEFAULT
        assert outcome.audit_log is not None

    def test_rollback_already_default(self, manager):
        """[AC-AISVC-RES-07] Rollback is refused while already on the default strategy."""
        outcome = manager.rollback(
            trigger=RollbackTrigger.MANUAL,
            reason="Testing rollback",
        )

        assert outcome.success is False
        assert outcome.reason == "Already on default strategy"

    def test_check_and_rollback_latency(self, manager):
        """[AC-AISVC-RES-08] Excessive latency triggers a performance rollback."""
        self._activate_enhanced(manager)

        outcome = manager.check_and_rollback(
            metrics={"latency_ms": 3000.0},
            tenant_id="tenant_1",
        )

        assert outcome is not None
        assert outcome.trigger == RollbackTrigger.PERFORMANCE

    def test_check_and_rollback_error_rate(self, manager):
        """[AC-AISVC-RES-08] A high error rate triggers an error rollback."""
        self._activate_enhanced(manager)

        outcome = manager.check_and_rollback(
            metrics={"error_rate": 0.1},
            tenant_id="tenant_1",
        )

        assert outcome is not None
        assert outcome.trigger == RollbackTrigger.ERROR

    def test_check_and_rollback_ok(self, manager):
        """[AC-AISVC-RES-08] Healthy metrics leave the strategy untouched."""
        self._activate_enhanced(manager)

        outcome = manager.check_and_rollback(
            metrics={"latency_ms": 100.0, "error_rate": 0.01},
            tenant_id="tenant_1",
        )

        assert outcome is None

    def test_get_audit_logs(self, manager):
        """[AC-AISVC-RES-07] A rollback shows up as a single audit entry."""
        self._activate_enhanced(manager)
        manager.rollback(trigger=RollbackTrigger.MANUAL, reason="Test")

        entries = manager.get_audit_logs()

        assert len(entries) == 1
        assert entries[0].action == "rollback"

    def test_record_audit(self, manager):
        """[AC-AISVC-RES-07] record_audit returns the entry it stored."""
        entry = manager.record_audit(
            action="test_action",
            details={"reason": "Testing"},
            tenant_id="tenant_1",
        )

        assert entry.action == "test_action"
        assert entry.tenant_id == "tenant_1"


class TestSingletonInstances:
    """Tests for singleton instance getters.

    Each test clears the module-level cached instance through pytest's
    ``monkeypatch`` fixture, so the getter must build a fresh instance.
    ``monkeypatch`` restores the previously cached instance at teardown.
    The reset therefore does not leak into other tests, which a plain
    ``module._attr = None`` assignment would do.
    """

    def test_get_mode_router_singleton(self, monkeypatch):
        """Should return same mode router instance."""
        import app.services.retrieval.strategy.mode_router as mode_router_module

        # Drop any cached router so the first call below has to create one.
        monkeypatch.setattr(mode_router_module, "_mode_router", None)

        router1 = get_mode_router()
        router2 = get_mode_router()

        assert router1 is router2

    def test_get_rollback_manager_singleton(self, monkeypatch):
        """Should return same rollback manager instance."""
        import app.services.retrieval.strategy.rollback_manager as rollback_manager_module

        # Drop any cached manager so the first call below has to create one.
        monkeypatch.setattr(rollback_manager_module, "_rollback_manager", None)

        manager1 = get_rollback_manager()
        manager2 = get_rollback_manager()

        assert manager1 is manager2