# File: ai-robot-core/ai-service/tests/test_retrieval_strategy.py
# (Repository viewer metadata: 542 lines, 20 KiB, Python)

"""
Unit tests for Retrieval Strategy Service.
[AC-AISVC-RES-01~15] Tests for strategy management, switching, validation, and rollback.
"""
import pytest
from unittest.mock import MagicMock, patch
from datetime import datetime
from app.schemas.retrieval_strategy import (
ReactMode,
RolloutConfig,
RolloutMode,
StrategyType,
RetrievalStrategyStatus,
RetrievalStrategySwitchRequest,
RetrievalStrategyValidationRequest,
ValidationResult,
)
from app.services.retrieval.strategy_service import (
RetrievalStrategyService,
StrategyState,
get_strategy_service,
)
from app.services.retrieval.strategy_audit import (
StrategyAuditService,
get_audit_service,
)
from app.services.retrieval.strategy_metrics import (
StrategyMetricsService,
get_metrics_service,
)
class TestRetrievalStrategySchemas:
    """[AC-AISVC-RES-01~15] Construction and validation of the strategy schema models."""

    def test_rollout_config_off_mode(self):
        """[AC-AISVC-RES-03] OFF rollout is valid with no percentage and no allowlist."""
        cfg = RolloutConfig(mode=RolloutMode.OFF)

        assert cfg.mode == RolloutMode.OFF
        assert cfg.percentage is None
        assert cfg.allowlist is None

    def test_rollout_config_percentage_mode(self):
        """[AC-AISVC-RES-03] PERCENTAGE rollout stores the supplied percentage."""
        cfg = RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0)

        assert cfg.mode == RolloutMode.PERCENTAGE
        assert cfg.percentage == 50.0

    def test_rollout_config_percentage_mode_missing_value(self):
        """[AC-AISVC-RES-03] PERCENTAGE rollout without a percentage is rejected."""
        with pytest.raises(ValueError, match="percentage is required"):
            RolloutConfig(mode=RolloutMode.PERCENTAGE)

    def test_rollout_config_allowlist_mode(self):
        """[AC-AISVC-RES-03] ALLOWLIST rollout stores the supplied tenant list."""
        cfg = RolloutConfig(
            mode=RolloutMode.ALLOWLIST,
            allowlist=["tenant1", "tenant2"],
        )

        assert cfg.mode == RolloutMode.ALLOWLIST
        assert cfg.allowlist == ["tenant1", "tenant2"]

    def test_rollout_config_allowlist_mode_missing_value(self):
        """[AC-AISVC-RES-03] ALLOWLIST rollout without an allowlist is rejected."""
        with pytest.raises(ValueError, match="allowlist is required"):
            RolloutConfig(mode=RolloutMode.ALLOWLIST)

    def test_retrieval_strategy_status(self):
        """[AC-AISVC-RES-01] A status snapshot carries strategy, ReAct mode and rollout."""
        snapshot = RetrievalStrategyStatus(
            active_strategy=StrategyType.DEFAULT,
            react_mode=ReactMode.NON_REACT,
            rollout=RolloutConfig(mode=RolloutMode.OFF),
        )

        assert snapshot.active_strategy == StrategyType.DEFAULT
        assert snapshot.react_mode == ReactMode.NON_REACT
        assert snapshot.rollout.mode == RolloutMode.OFF

    def test_switch_request_minimal(self):
        """[AC-AISVC-RES-02] Only target_strategy is required; the rest default to None."""
        req = RetrievalStrategySwitchRequest(target_strategy=StrategyType.ENHANCED)

        assert req.target_strategy == StrategyType.ENHANCED
        assert req.react_mode is None
        assert req.rollout is None
        assert req.reason is None

    def test_switch_request_full(self):
        """[AC-AISVC-RES-02,03,05] Every optional switch field is accepted and retained."""
        req = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
            rollout=RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=30.0),
            reason="Testing enhanced strategy",
        )

        assert req.target_strategy == StrategyType.ENHANCED
        assert req.react_mode == ReactMode.REACT
        assert req.rollout.percentage == 30.0
        assert req.reason == "Testing enhanced strategy"


class TestRetrievalStrategyService:
    """[AC-AISVC-RES-01~15] Behaviour of the strategy service: switching, rollback, routing."""

    @pytest.fixture
    def service(self):
        """Provide an isolated service instance per test."""
        return RetrievalStrategyService()

    @staticmethod
    def _switch(svc, **fields):
        """Build a switch request from ``fields``, apply it to ``svc``, return the response."""
        return svc.switch_strategy(RetrievalStrategySwitchRequest(**fields))

    def test_get_current_status_default(self, service):
        """[AC-AISVC-RES-01] A new service reports default strategy, non-ReAct, rollout off."""
        current = service.get_current_status()

        assert current.active_strategy == StrategyType.DEFAULT
        assert current.react_mode == ReactMode.NON_REACT
        assert current.rollout.mode == RolloutMode.OFF

    def test_switch_strategy_to_enhanced(self, service):
        """[AC-AISVC-RES-02] Switching records the previous state and applies the new one."""
        resp = self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )

        assert resp.previous.active_strategy == StrategyType.DEFAULT
        assert resp.current.active_strategy == StrategyType.ENHANCED
        assert resp.current.react_mode == ReactMode.REACT

    def test_switch_strategy_with_grayscale_percentage(self, service):
        """[AC-AISVC-RES-03] A percentage rollout is carried into the current state."""
        resp = self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            rollout=RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0),
        )
        active = resp.current

        assert active.active_strategy == StrategyType.ENHANCED
        assert active.rollout.mode == RolloutMode.PERCENTAGE
        assert active.rollout.percentage == 50.0

    def test_switch_strategy_with_allowlist(self, service):
        """[AC-AISVC-RES-03] An allowlist rollout is carried into the current state."""
        resp = self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            rollout=RolloutConfig(
                mode=RolloutMode.ALLOWLIST,
                allowlist=["tenant_a", "tenant_b"],
            ),
        )
        active_rollout = resp.current.rollout

        assert active_rollout.mode == RolloutMode.ALLOWLIST
        assert "tenant_a" in active_rollout.allowlist

    def test_rollback_strategy(self, service):
        """[AC-AISVC-RES-07] Rolling back after a switch restores the prior strategy."""
        self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )

        restored = service.rollback_strategy().rollback_to

        assert restored.active_strategy == StrategyType.DEFAULT
        assert restored.react_mode == ReactMode.NON_REACT

    def test_rollback_without_previous_returns_default(self, service):
        """[AC-AISVC-RES-07] Rolling back with no prior switch lands on the default strategy."""
        restored = service.rollback_strategy().rollback_to

        assert restored.active_strategy == StrategyType.DEFAULT

    def test_should_use_enhanced_strategy_default(self, service):
        """[AC-AISVC-RES-01] Under the default strategy no tenant gets enhanced retrieval."""
        assert service.should_use_enhanced_strategy("tenant_a") is False

    def test_should_use_enhanced_strategy_with_allowlist(self, service):
        """[AC-AISVC-RES-03] Only allowlisted tenants get enhanced retrieval."""
        self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            rollout=RolloutConfig(
                mode=RolloutMode.ALLOWLIST,
                allowlist=["tenant_a"],
            ),
        )

        assert service.should_use_enhanced_strategy("tenant_a") is True
        assert service.should_use_enhanced_strategy("tenant_b") is False

    def test_get_route_mode_react(self, service):
        """[AC-AISVC-RES-10] With ReAct mode enabled, queries route to "react"."""
        self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )

        assert service.get_route_mode("test query") == "react"

    def test_get_route_mode_direct(self, service):
        """[AC-AISVC-RES-09] With non-ReAct mode, queries route to "direct"."""
        self._switch(
            service,
            target_strategy=StrategyType.DEFAULT,
            react_mode=ReactMode.NON_REACT,
        )

        assert service.get_route_mode("test query") == "direct"

    def test_get_route_mode_auto_short_query(self, service):
        """[AC-AISVC-RES-12] Auto routing sends a short, high-confidence query to "direct"."""
        service._state.react_mode = None

        assert service._auto_route("短问题", confidence=0.8) == "direct"

    def test_get_route_mode_auto_multiple_conditions(self, service):
        """[AC-AISVC-RES-13] Auto routing sends a multi-condition query to "react"."""
        assert service._auto_route("查询订单状态和物流信息") == "react"

    def test_get_route_mode_auto_low_confidence(self, service):
        """[AC-AISVC-RES-13] Auto routing sends a low-confidence query to "react"."""
        assert service._auto_route("test query", confidence=0.3) == "react"

    def test_get_switch_history(self, service):
        """A single switch produces exactly one history entry naming the target strategy."""
        self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            reason="Testing",
        )

        log = service.get_switch_history()

        assert len(log) == 1
        assert log[0]["to_strategy"] == "enhanced"


class TestRetrievalStrategyValidation:
    """[AC-AISVC-RES-04,06,08] Strategy validation and its individual checks."""

    @pytest.fixture
    def service(self):
        """Provide an isolated service instance per test."""
        return RetrievalStrategyService()

    def test_validate_default_strategy(self, service):
        """[AC-AISVC-RES-06] The default strategy validates successfully."""
        resp = service.validate_strategy(
            RetrievalStrategyValidationRequest(strategy=StrategyType.DEFAULT)
        )

        assert resp.passed is True

    def test_validate_enhanced_strategy(self, service):
        """[AC-AISVC-RES-06] Validating the enhanced strategy yields a verdict and results."""
        resp = service.validate_strategy(
            RetrievalStrategyValidationRequest(strategy=StrategyType.ENHANCED)
        )

        assert isinstance(resp.passed, bool)
        assert len(resp.results) > 0

    def test_validate_specific_checks(self, service):
        """[AC-AISVC-RES-06] Each explicitly requested check appears in the results."""
        req = RetrievalStrategyValidationRequest(
            strategy=StrategyType.ENHANCED,
            checks=["metadata_consistency", "performance_budget"],
        )

        reported = {item.check for item in service.validate_strategy(req).results}

        assert "metadata_consistency" in reported
        assert "performance_budget" in reported

    def test_check_metadata_consistency(self, service):
        """[AC-AISVC-RES-04] The metadata consistency check passes for the default strategy."""
        outcome = service._check_metadata_consistency(StrategyType.DEFAULT)

        assert outcome.check == "metadata_consistency"
        assert outcome.passed is True

    def test_check_rrf_config(self, service):
        """[AC-AISVC-RES-02] The RRF config check reports under its name with a boolean verdict."""
        outcome = service._check_rrf_config(StrategyType.DEFAULT)

        assert outcome.check == "rrf_config"
        assert isinstance(outcome.passed, bool)

    def test_check_performance_budget(self, service):
        """[AC-AISVC-RES-08] The performance budget check reports a boolean verdict."""
        outcome = service._check_performance_budget(
            StrategyType.ENHANCED,
            ReactMode.REACT,
        )

        assert outcome.check == "performance_budget"
        assert isinstance(outcome.passed, bool)


class TestStrategyAuditService:
    """[AC-AISVC-RES-07] Recording, filtering, summarising and clearing audit entries."""

    @pytest.fixture
    def audit_service(self):
        """Provide a fresh audit service capped at 100 entries."""
        return StrategyAuditService(max_entries=100)

    def test_log_switch_operation(self, audit_service):
        """[AC-AISVC-RES-07] A logged switch is stored with its strategies."""
        audit_service.log(
            operation="switch",
            previous_strategy="default",
            new_strategy="enhanced",
            reason="Testing",
            operator="admin",
        )

        log = audit_service.get_audit_log()
        assert len(log) == 1

        first = log[0]
        assert first.operation == "switch"
        assert first.previous_strategy == "default"
        assert first.new_strategy == "enhanced"

    def test_log_rollback_operation(self, audit_service):
        """[AC-AISVC-RES-07] log_rollback stores an entry tagged "rollback"."""
        audit_service.log_rollback(
            previous_strategy="enhanced",
            new_strategy="default",
            reason="Performance issue",
            operator="admin",
        )

        rollbacks = audit_service.get_audit_log(operation="rollback")

        assert len(rollbacks) == 1
        assert rollbacks[0].operation == "rollback"

    def test_log_validation_operation(self, audit_service):
        """[AC-AISVC-RES-06] log_validation stores an entry tagged "validate"."""
        audit_service.log_validation(
            strategy="enhanced",
            checks=["metadata_consistency"],
            passed=True,
        )

        validations = audit_service.get_audit_log(operation="validate")

        assert len(validations) == 1
        assert validations[0].operation == "validate"

    def test_get_audit_log_with_limit(self, audit_service):
        """The limit argument caps how many entries are returned."""
        for i in range(10):
            audit_service.log(operation="switch", new_strategy=f"strategy_{i}")

        assert len(audit_service.get_audit_log(limit=5)) == 5

    def test_get_audit_stats(self, audit_service):
        """Stats report the total entry count and a per-operation breakdown."""
        audit_service.log(operation="switch", new_strategy="enhanced")
        audit_service.log(operation="rollback", new_strategy="default")

        stats = audit_service.get_audit_stats()
        assert stats["total_entries"] == 2

        per_operation = stats["operation_counts"]
        assert per_operation["switch"] == 1
        assert per_operation["rollback"] == 1

    def test_clear_audit_log(self, audit_service):
        """Clearing empties the log and reports how many entries were removed."""
        audit_service.log(operation="switch", new_strategy="enhanced")
        assert len(audit_service.get_audit_log()) == 1

        removed = audit_service.clear_audit_log()

        assert removed == 1
        assert len(audit_service.get_audit_log()) == 0


class TestStrategyMetricsService:
    """[AC-AISVC-RES-03,08] Tests for the strategy metrics service."""

    @pytest.fixture
    def metrics_service(self):
        """Provide a fresh metrics service per test (function-scoped fixture)."""
        return StrategyMetricsService()

    def test_record_request(self, metrics_service):
        """[AC-AISVC-RES-08] A single successful request is counted and its latency averaged."""
        metrics_service.record_request(
            latency_ms=100.0,
            success=True,
            route_mode="direct",
        )

        metrics = metrics_service.get_metrics()
        assert metrics.total_requests == 1
        assert metrics.successful_requests == 1
        assert metrics.avg_latency_ms == 100.0

    def test_record_failed_request(self, metrics_service):
        """[AC-AISVC-RES-08] A failed request increments the failure counter."""
        metrics_service.record_request(latency_ms=50.0, success=False)

        assert metrics_service.get_metrics().failed_requests == 1

    def test_record_fallback(self, metrics_service):
        """[AC-AISVC-RES-08] A request flagged as a fallback increments the fallback counter."""
        metrics_service.record_request(
            latency_ms=100.0,
            success=True,
            fallback=True,
        )

        assert metrics_service.get_metrics().fallback_count == 1

    def test_record_route_metrics(self, metrics_service):
        """[AC-AISVC-RES-08] Route metrics are tracked separately per route mode."""
        metrics_service.record_request(latency_ms=100.0, success=True, route_mode="react")
        metrics_service.record_request(latency_ms=50.0, success=True, route_mode="direct")

        route_metrics = metrics_service.get_route_metrics()
        assert "react" in route_metrics
        assert "direct" in route_metrics

    def test_get_all_metrics(self, metrics_service):
        """get_all_metrics includes an entry for both the default and enhanced strategies."""
        metrics_service.set_current_strategy(StrategyType.ENHANCED, ReactMode.REACT)
        metrics_service.record_request(latency_ms=100.0, success=True)

        all_metrics = metrics_service.get_all_metrics()
        assert StrategyType.DEFAULT.value in all_metrics
        assert StrategyType.ENHANCED.value in all_metrics

    def test_get_performance_summary(self, metrics_service):
        """[AC-AISVC-RES-08] The summary reports totals and the success rate (2 of 3)."""
        metrics_service.record_request(latency_ms=100.0, success=True)
        metrics_service.record_request(latency_ms=200.0, success=True)
        metrics_service.record_request(latency_ms=50.0, success=False)

        summary = metrics_service.get_performance_summary()
        assert summary["total_requests"] == 3
        assert summary["successful_requests"] == 2
        assert summary["failed_requests"] == 1
        # Loose tolerance: the service may round the rate (e.g. to 0.6667).
        assert summary["success_rate"] == pytest.approx(0.6667, rel=0.01)

    def test_check_performance_threshold_ok(self, metrics_service):
        """[AC-AISVC-RES-08] A fast, successful request passes every threshold check."""
        metrics_service.record_request(latency_ms=100.0, success=True)

        result = metrics_service.check_performance_threshold(
            strategy=StrategyType.DEFAULT,
            max_latency_ms=5000.0,
            max_error_rate=0.1,
        )
        assert result["latency_ok"] is True
        assert result["error_rate_ok"] is True
        assert result["overall_ok"] is True

    def test_check_performance_threshold_exceeded(self, metrics_service):
        """[AC-AISVC-RES-08] Exceeding a threshold fails its check and the overall verdict."""
        # One request over the 5000 ms latency budget, and one failure, which gives
        # a 1-in-2 error rate against a 0.1 budget.
        metrics_service.record_request(latency_ms=6000.0, success=True)
        metrics_service.record_request(latency_ms=100.0, success=False)

        result = metrics_service.check_performance_threshold(
            strategy=StrategyType.DEFAULT,
            max_latency_ms=5000.0,
            max_error_rate=0.1,
        )
        # Whether the latency check trips depends on how the service aggregates
        # latency, so only require that at least one individual check fails...
        assert result["latency_ok"] is False or result["error_rate_ok"] is False
        # ...and that the combined verdict reflects that failure.
        assert result["overall_ok"] is False

    def test_reset_metrics(self, metrics_service):
        """Resetting clears previously recorded requests."""
        metrics_service.record_request(latency_ms=100.0, success=True)

        metrics_service.reset_metrics()

        assert metrics_service.get_metrics().total_requests == 0


class TestSingletonInstances:
    """Tests for the module-level singleton getters.

    Each test clears the cached singleton through pytest's ``monkeypatch`` so the
    getter must build an instance. ``monkeypatch`` restores the original module
    attribute when the test finishes, so the reset does not leak global state
    into other tests. ``setattr`` raises if the attribute is missing, so a renamed
    global is still detected.
    """

    def test_get_strategy_service_singleton(self, monkeypatch):
        """Repeated calls to get_strategy_service() return the same instance."""
        import app.services.retrieval.strategy_service as module

        monkeypatch.setattr(module, "_strategy_service", None)

        service1 = get_strategy_service()
        service2 = get_strategy_service()
        assert service1 is service2

    def test_get_audit_service_singleton(self, monkeypatch):
        """Repeated calls to get_audit_service() return the same instance."""
        import app.services.retrieval.strategy_audit as module

        monkeypatch.setattr(module, "_audit_service", None)

        service1 = get_audit_service()
        service2 = get_audit_service()
        assert service1 is service2

    def test_get_metrics_service_singleton(self, monkeypatch):
        """Repeated calls to get_metrics_service() return the same instance."""
        import app.services.retrieval.strategy_metrics as module

        monkeypatch.setattr(module, "_metrics_service", None)

        service1 = get_metrics_service()
        service2 = get_metrics_service()
        assert service1 is service2