"""
Unit tests for Retrieval Strategy Service.

[AC-AISVC-RES-01~15] Tests for strategy management, switching, validation,
rollback, audit logging, and metrics.
"""
|
|
|
|
import pytest
|
|
from unittest.mock import MagicMock, patch
|
|
from datetime import datetime
|
|
|
|
from app.schemas.retrieval_strategy import (
|
|
ReactMode,
|
|
RolloutConfig,
|
|
RolloutMode,
|
|
StrategyType,
|
|
RetrievalStrategyStatus,
|
|
RetrievalStrategySwitchRequest,
|
|
RetrievalStrategyValidationRequest,
|
|
ValidationResult,
|
|
)
|
|
from app.services.retrieval.strategy_service import (
|
|
RetrievalStrategyService,
|
|
StrategyState,
|
|
get_strategy_service,
|
|
)
|
|
from app.services.retrieval.strategy_audit import (
|
|
StrategyAuditService,
|
|
get_audit_service,
|
|
)
|
|
from app.services.retrieval.strategy_metrics import (
|
|
StrategyMetricsService,
|
|
get_metrics_service,
|
|
)
|
|
|
|
|
|
class TestRetrievalStrategySchemas:
    """[AC-AISVC-RES-01~15] Schema-level checks for the retrieval strategy models."""

    def test_rollout_config_off_mode(self):
        """[AC-AISVC-RES-03] OFF rollout needs neither a percentage nor an allowlist."""
        cfg = RolloutConfig(mode=RolloutMode.OFF)

        assert cfg.mode == RolloutMode.OFF
        assert cfg.percentage is None
        assert cfg.allowlist is None

    def test_rollout_config_percentage_mode(self):
        """[AC-AISVC-RES-03] PERCENTAGE rollout keeps the supplied percentage."""
        pct = 50.0
        cfg = RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=pct)

        assert cfg.mode == RolloutMode.PERCENTAGE
        assert cfg.percentage == pct

    def test_rollout_config_percentage_mode_missing_value(self):
        """[AC-AISVC-RES-03] PERCENTAGE rollout without a percentage is rejected."""
        with pytest.raises(ValueError, match="percentage is required"):
            RolloutConfig(mode=RolloutMode.PERCENTAGE)

    def test_rollout_config_allowlist_mode(self):
        """[AC-AISVC-RES-03] ALLOWLIST rollout keeps the supplied tenant list."""
        tenants = ["tenant1", "tenant2"]
        cfg = RolloutConfig(mode=RolloutMode.ALLOWLIST, allowlist=tenants)

        assert cfg.mode == RolloutMode.ALLOWLIST
        assert cfg.allowlist == ["tenant1", "tenant2"]

    def test_rollout_config_allowlist_mode_missing_value(self):
        """[AC-AISVC-RES-03] ALLOWLIST rollout without an allowlist is rejected."""
        with pytest.raises(ValueError, match="allowlist is required"):
            RolloutConfig(mode=RolloutMode.ALLOWLIST)

    def test_retrieval_strategy_status(self):
        """[AC-AISVC-RES-01] Status model exposes strategy, react mode and rollout."""
        status = RetrievalStrategyStatus(
            active_strategy=StrategyType.DEFAULT,
            react_mode=ReactMode.NON_REACT,
            rollout=RolloutConfig(mode=RolloutMode.OFF),
        )

        assert status.active_strategy == StrategyType.DEFAULT
        assert status.react_mode == ReactMode.NON_REACT
        assert status.rollout.mode == RolloutMode.OFF

    def test_switch_request_minimal(self):
        """[AC-AISVC-RES-02] Only target_strategy is required; optional fields default to None."""
        req = RetrievalStrategySwitchRequest(target_strategy=StrategyType.ENHANCED)

        assert req.target_strategy == StrategyType.ENHANCED
        for optional_field in ("react_mode", "rollout", "reason"):
            assert getattr(req, optional_field) is None

    def test_switch_request_full(self):
        """[AC-AISVC-RES-02,03,05] Every optional switch-request field is accepted."""
        req = RetrievalStrategySwitchRequest(
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
            rollout=RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=30.0),
            reason="Testing enhanced strategy",
        )

        assert req.target_strategy == StrategyType.ENHANCED
        assert req.react_mode == ReactMode.REACT
        assert req.rollout.percentage == 30.0
        assert req.reason == "Testing enhanced strategy"
|
|
|
|
|
|
class TestRetrievalStrategyService:
    """[AC-AISVC-RES-01~15] Behavioural tests for RetrievalStrategyService."""

    @pytest.fixture
    def service(self):
        """Return a fresh service instance for every test."""
        return RetrievalStrategyService()

    @staticmethod
    def _switch(svc, **fields):
        """Build a switch request from ``fields`` and apply it to ``svc``."""
        return svc.switch_strategy(RetrievalStrategySwitchRequest(**fields))

    def test_get_current_status_default(self, service):
        """[AC-AISVC-RES-01] Initial status: default strategy, non_react mode, rollout off."""
        current = service.get_current_status()

        assert current.active_strategy == StrategyType.DEFAULT
        assert current.react_mode == ReactMode.NON_REACT
        assert current.rollout.mode == RolloutMode.OFF

    def test_switch_strategy_to_enhanced(self, service):
        """[AC-AISVC-RES-02] Switching reports both the previous and the new state."""
        result = self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )

        assert result.previous.active_strategy == StrategyType.DEFAULT
        assert result.current.active_strategy == StrategyType.ENHANCED
        assert result.current.react_mode == ReactMode.REACT

    def test_switch_strategy_with_grayscale_percentage(self, service):
        """[AC-AISVC-RES-03] A percentage-based grayscale rollout is applied on switch."""
        result = self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            rollout=RolloutConfig(mode=RolloutMode.PERCENTAGE, percentage=50.0),
        )

        assert result.current.active_strategy == StrategyType.ENHANCED
        rollout_now = result.current.rollout
        assert rollout_now.mode == RolloutMode.PERCENTAGE
        assert rollout_now.percentage == 50.0

    def test_switch_strategy_with_allowlist(self, service):
        """[AC-AISVC-RES-03] An allowlist-based grayscale rollout is applied on switch."""
        result = self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            rollout=RolloutConfig(
                mode=RolloutMode.ALLOWLIST,
                allowlist=["tenant_a", "tenant_b"],
            ),
        )

        assert result.current.rollout.mode == RolloutMode.ALLOWLIST
        assert "tenant_a" in result.current.rollout.allowlist

    def test_rollback_strategy(self, service):
        """[AC-AISVC-RES-07] Rolling back restores the configuration from before the switch."""
        self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )

        restored = service.rollback_strategy().rollback_to

        assert restored.active_strategy == StrategyType.DEFAULT
        assert restored.react_mode == ReactMode.NON_REACT

    def test_rollback_without_previous_returns_default(self, service):
        """[AC-AISVC-RES-07] Rolling back with no prior switch lands on the default strategy."""
        restored = service.rollback_strategy().rollback_to

        assert restored.active_strategy == StrategyType.DEFAULT

    def test_should_use_enhanced_strategy_default(self, service):
        """[AC-AISVC-RES-01] Under the default strategy a tenant does not get enhanced retrieval."""
        assert service.should_use_enhanced_strategy("tenant_a") is False

    def test_should_use_enhanced_strategy_with_allowlist(self, service):
        """[AC-AISVC-RES-03] Only allowlisted tenants get enhanced retrieval."""
        self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            rollout=RolloutConfig(mode=RolloutMode.ALLOWLIST, allowlist=["tenant_a"]),
        )

        assert service.should_use_enhanced_strategy("tenant_a") is True
        assert service.should_use_enhanced_strategy("tenant_b") is False

    def test_get_route_mode_react(self, service):
        """[AC-AISVC-RES-10] REACT mode sends queries down the react route."""
        self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            react_mode=ReactMode.REACT,
        )

        assert service.get_route_mode("test query") == "react"

    def test_get_route_mode_direct(self, service):
        """[AC-AISVC-RES-09] NON_REACT mode sends queries down the direct route."""
        self._switch(
            service,
            target_strategy=StrategyType.DEFAULT,
            react_mode=ReactMode.NON_REACT,
        )

        assert service.get_route_mode("test query") == "direct"

    def test_get_route_mode_auto_short_query(self, service):
        """[AC-AISVC-RES-12] A short, high-confidence query is routed directly."""
        # Clear the configured react mode before exercising the auto router.
        service._state.react_mode = None

        # "短问题" means "short question".
        assert service._auto_route("短问题", confidence=0.8) == "direct"

    def test_get_route_mode_auto_multiple_conditions(self, service):
        """[AC-AISVC-RES-13] A query combining several conditions is routed through react."""
        # Query text: "look up order status and logistics info" (two conditions joined by "和"/"and").
        assert service._auto_route("查询订单状态和物流信息") == "react"

    def test_get_route_mode_auto_low_confidence(self, service):
        """[AC-AISVC-RES-13] A low-confidence query is routed through react."""
        assert service._auto_route("test query", confidence=0.3) == "react"

    def test_get_switch_history(self, service):
        """A switch shows up in the switch history."""
        self._switch(
            service,
            target_strategy=StrategyType.ENHANCED,
            reason="Testing",
        )

        log = service.get_switch_history()

        assert len(log) == 1
        assert log[0]["to_strategy"] == "enhanced"
|
|
|
|
|
|
class TestRetrievalStrategyValidation:
    """[AC-AISVC-RES-04,06,08] Tests for strategy validation and its individual checks."""

    @pytest.fixture
    def service(self):
        """Return a fresh RetrievalStrategyService for each test."""
        return RetrievalStrategyService()

    @staticmethod
    def _validate(svc, strategy, **extra):
        """Run ``svc.validate_strategy`` for ``strategy`` with optional extra request fields."""
        return svc.validate_strategy(
            RetrievalStrategyValidationRequest(strategy=strategy, **extra)
        )

    def test_validate_default_strategy(self, service):
        """[AC-AISVC-RES-06] The default strategy validates successfully."""
        outcome = self._validate(service, StrategyType.DEFAULT)

        assert outcome.passed is True

    def test_validate_enhanced_strategy(self, service):
        """[AC-AISVC-RES-06] Validating the enhanced strategy yields a verdict and results."""
        outcome = self._validate(service, StrategyType.ENHANCED)

        assert isinstance(outcome.passed, bool)
        assert len(outcome.results) > 0

    def test_validate_specific_checks(self, service):
        """[AC-AISVC-RES-06] Explicitly requested checks appear among the results."""
        outcome = self._validate(
            service,
            StrategyType.ENHANCED,
            checks=["metadata_consistency", "performance_budget"],
        )

        executed = [item.check for item in outcome.results]
        assert "metadata_consistency" in executed
        assert "performance_budget" in executed

    def test_check_metadata_consistency(self, service):
        """[AC-AISVC-RES-04] The metadata-consistency check passes for the default strategy."""
        verdict = service._check_metadata_consistency(StrategyType.DEFAULT)

        assert verdict.check == "metadata_consistency"
        assert verdict.passed is True

    def test_check_rrf_config(self, service):
        """[AC-AISVC-RES-02] The RRF config check returns a boolean verdict."""
        verdict = service._check_rrf_config(StrategyType.DEFAULT)

        assert verdict.check == "rrf_config"
        assert isinstance(verdict.passed, bool)

    def test_check_performance_budget(self, service):
        """[AC-AISVC-RES-08] The performance-budget check returns a boolean verdict."""
        verdict = service._check_performance_budget(StrategyType.ENHANCED, ReactMode.REACT)

        assert verdict.check == "performance_budget"
        assert isinstance(verdict.passed, bool)
|
|
|
|
|
|
class TestStrategyAuditService:
    """[AC-AISVC-RES-07] Tests for StrategyAuditService."""

    @pytest.fixture
    def audit_service(self):
        """Return an audit service constructed with ``max_entries=100``."""
        return StrategyAuditService(max_entries=100)

    def test_log_switch_operation(self, audit_service):
        """[AC-AISVC-RES-07] A switch is recorded with its before/after strategies."""
        audit_service.log(
            operation="switch",
            previous_strategy="default",
            new_strategy="enhanced",
            reason="Testing",
            operator="admin",
        )

        log = audit_service.get_audit_log()
        assert len(log) == 1

        entry = log[0]
        assert entry.operation == "switch"
        assert entry.previous_strategy == "default"
        assert entry.new_strategy == "enhanced"

    def test_log_rollback_operation(self, audit_service):
        """[AC-AISVC-RES-07] A rollback can be retrieved by filtering on its operation."""
        audit_service.log_rollback(
            previous_strategy="enhanced",
            new_strategy="default",
            reason="Performance issue",
            operator="admin",
        )

        rollbacks = audit_service.get_audit_log(operation="rollback")
        assert len(rollbacks) == 1
        assert rollbacks[0].operation == "rollback"

    def test_log_validation_operation(self, audit_service):
        """[AC-AISVC-RES-06] A validation run is recorded under the "validate" operation."""
        audit_service.log_validation(
            strategy="enhanced",
            checks=["metadata_consistency"],
            passed=True,
        )

        validations = audit_service.get_audit_log(operation="validate")
        assert len(validations) == 1
        assert validations[0].operation == "validate"

    def test_get_audit_log_with_limit(self, audit_service):
        """``limit`` caps how many entries are returned."""
        for n in range(10):
            audit_service.log(operation="switch", new_strategy=f"strategy_{n}")

        assert len(audit_service.get_audit_log(limit=5)) == 5

    def test_get_audit_stats(self, audit_service):
        """Stats report the total entry count and a per-operation breakdown."""
        for op, target in (("switch", "enhanced"), ("rollback", "default")):
            audit_service.log(operation=op, new_strategy=target)

        stats = audit_service.get_audit_stats()

        assert stats["total_entries"] == 2
        per_op = stats["operation_counts"]
        assert per_op["switch"] == 1
        assert per_op["rollback"] == 1

    def test_clear_audit_log(self, audit_service):
        """Clearing returns the number of removed entries and leaves the log empty."""
        audit_service.log(operation="switch", new_strategy="enhanced")
        assert len(audit_service.get_audit_log()) == 1

        removed = audit_service.clear_audit_log()

        assert removed == 1
        assert len(audit_service.get_audit_log()) == 0
|
|
|
|
|
|
class TestStrategyMetricsService:
    """[AC-AISVC-RES-03,08] Tests for metrics service."""

    @pytest.fixture
    def metrics_service(self):
        """Return a fresh StrategyMetricsService for each test."""
        return StrategyMetricsService()

    def test_record_request(self, metrics_service):
        """[AC-AISVC-RES-08] Should record request metrics."""
        metrics_service.record_request(
            latency_ms=100.0,
            success=True,
            route_mode="direct",
        )

        metrics = metrics_service.get_metrics()
        assert metrics.total_requests == 1
        assert metrics.successful_requests == 1
        assert metrics.avg_latency_ms == 100.0

    def test_record_failed_request(self, metrics_service):
        """[AC-AISVC-RES-08] Should record failed request."""
        metrics_service.record_request(latency_ms=50.0, success=False)

        metrics = metrics_service.get_metrics()
        assert metrics.failed_requests == 1

    def test_record_fallback(self, metrics_service):
        """[AC-AISVC-RES-08] Should record fallback count."""
        metrics_service.record_request(
            latency_ms=100.0,
            success=True,
            fallback=True,
        )

        metrics = metrics_service.get_metrics()
        assert metrics.fallback_count == 1

    def test_record_route_metrics(self, metrics_service):
        """[AC-AISVC-RES-08] Should track route mode metrics."""
        metrics_service.record_request(latency_ms=100.0, success=True, route_mode="react")
        metrics_service.record_request(latency_ms=50.0, success=True, route_mode="direct")

        route_metrics = metrics_service.get_route_metrics()
        assert "react" in route_metrics
        assert "direct" in route_metrics

    def test_get_all_metrics(self, metrics_service):
        """Should get metrics for all strategies."""
        metrics_service.set_current_strategy(StrategyType.ENHANCED, ReactMode.REACT)
        metrics_service.record_request(latency_ms=100.0, success=True)

        all_metrics = metrics_service.get_all_metrics()

        assert StrategyType.DEFAULT.value in all_metrics
        assert StrategyType.ENHANCED.value in all_metrics

    def test_get_performance_summary(self, metrics_service):
        """[AC-AISVC-RES-08] Should get performance summary."""
        metrics_service.record_request(latency_ms=100.0, success=True)
        metrics_service.record_request(latency_ms=200.0, success=True)
        metrics_service.record_request(latency_ms=50.0, success=False)

        summary = metrics_service.get_performance_summary()

        assert summary["total_requests"] == 3
        assert summary["successful_requests"] == 2
        assert summary["failed_requests"] == 1
        # 2 of the 3 requests succeeded. The relative tolerance keeps the check
        # valid if the service reports a rounded rate (e.g. 0.6667).
        assert summary["success_rate"] == pytest.approx(2 / 3, rel=0.01)

    def test_check_performance_threshold_ok(self, metrics_service):
        """[AC-AISVC-RES-08] Should pass performance threshold check."""
        metrics_service.record_request(latency_ms=100.0, success=True)

        result = metrics_service.check_performance_threshold(
            strategy=StrategyType.DEFAULT,
            max_latency_ms=5000.0,
            max_error_rate=0.1,
        )

        assert result["latency_ok"] is True
        assert result["error_rate_ok"] is True
        assert result["overall_ok"] is True

    def test_check_performance_threshold_exceeded(self, metrics_service):
        """[AC-AISVC-RES-08] Should fail when threshold exceeded."""
        # One request exceeds the 5000 ms latency budget and one of two requests
        # fails (50% error rate against a 10% limit).
        metrics_service.record_request(latency_ms=6000.0, success=True)
        metrics_service.record_request(latency_ms=100.0, success=False)

        result = metrics_service.check_performance_threshold(
            strategy=StrategyType.DEFAULT,
            max_latency_ms=5000.0,
            max_error_rate=0.1,
        )

        # At least one individual check must fail ...
        assert result["latency_ok"] is False or result["error_rate_ok"] is False
        # ... and the aggregate verdict must reflect that. Checking only the
        # individual flags would miss an overall_ok that ignores them.
        assert result["overall_ok"] is False

    def test_reset_metrics(self, metrics_service):
        """Should reset metrics."""
        metrics_service.record_request(latency_ms=100.0, success=True)
        metrics_service.reset_metrics()

        metrics = metrics_service.get_metrics()
        assert metrics.total_requests == 0
|
|
|
|
|
|
class TestSingletonInstances:
    """Tests for the module-level singleton getters.

    Each test clears the cached singleton with ``monkeypatch.setattr`` so the
    getter is exercised from a cold start. pytest restores the previously
    cached value on teardown, so these tests do not leak a replaced singleton
    into other tests.
    """

    def test_get_strategy_service_singleton(self, monkeypatch):
        """Should return same strategy service instance."""
        import app.services.retrieval.strategy_service as strategy_module

        # setattr fails if the attribute does not exist, which also guards
        # against the cache variable being renamed.
        monkeypatch.setattr(strategy_module, "_strategy_service", None)

        first = get_strategy_service()
        second = get_strategy_service()

        # Without the isinstance check, a getter returning None twice would pass.
        assert isinstance(first, RetrievalStrategyService)
        assert first is second

    def test_get_audit_service_singleton(self, monkeypatch):
        """Should return same audit service instance."""
        import app.services.retrieval.strategy_audit as audit_module

        monkeypatch.setattr(audit_module, "_audit_service", None)

        first = get_audit_service()
        second = get_audit_service()

        assert isinstance(first, StrategyAuditService)
        assert first is second

    def test_get_metrics_service_singleton(self, monkeypatch):
        """Should return same metrics service instance."""
        import app.services.retrieval.strategy_metrics as metrics_module

        monkeypatch.setattr(metrics_module, "_metrics_service", None)

        first = get_metrics_service()
        second = get_metrics_service()

        assert isinstance(first, StrategyMetricsService)
        assert first is second
|