# ai-robot-core/ai-service/tests/test_retrieval_strategy_int...
#
# 354 lines
# 9.8 KiB
# Python

"""
Integration tests for Retrieval Strategy API.
[AC-AISVC-RES-01~15] End-to-end tests for strategy management endpoints.
Tests the full API flow:
- GET /strategy/retrieval/current
- POST /strategy/retrieval/switch
- POST /strategy/retrieval/validate
- POST /strategy/retrieval/rollback
"""
import json
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
from fastapi.testclient import TestClient
from app.main import app
@pytest.fixture(autouse=True)
def mock_api_key_service():
    """
    Swap the API key service for a permissive mock in every test.

    The mock claims to be initialized, caches exactly one key
    ("test-api-key"), and answers every ``validate_key_with_context`` call
    with a result object whose ``ok`` is True and ``reason`` is None.
    The mock is yielded so individual tests can inspect or reconfigure it.

    NOTE(review): only ``app.services.api_key.get_api_key_service`` is
    patched. If the auth layer imported that function by name into its own
    module, it would keep the real one — confirm against the middleware.
    """
    service = MagicMock()
    service._initialized = True
    service._keys_cache = {"test-api-key": MagicMock()}

    accepted = MagicMock()
    accepted.ok, accepted.reason = True, None
    service.validate_key_with_context.return_value = accepted

    target = "app.services.api_key.get_api_key_service"
    with patch(target, return_value=service):
        yield service
@pytest.fixture(autouse=True)
def reset_strategy_state():
    """
    Put the strategy router back to a default configuration around each test.

    The reset runs once before the test and once after it, so a strategy
    switch made by one test cannot leak into the next.
    """
    from app.services.retrieval.strategy.strategy_router import (
        get_strategy_router,
        set_strategy_router,
    )
    from app.services.retrieval.strategy.config import RetrievalStrategyConfig

    def _restore_default_router():
        # Clear the stored router, fetch one from the getter (presumably a
        # freshly built instance once the slot is None), then apply a
        # default config to it explicitly.
        set_strategy_router(None)
        fresh = get_strategy_router()
        fresh.update_config(RetrievalStrategyConfig())

    _restore_default_router()
    yield
    _restore_default_router()
class TestRetrievalStrategyAPIIntegration:
    """
    [AC-AISVC-RES-01~15] Integration tests for retrieval strategy API.
    """

    @pytest.fixture
    def client(self):
        """Return a FastAPI TestClient bound to the application."""
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        """Return headers carrying a tenant id and the mocked API key."""
        return {
            "X-Tenant-Id": "test@ash@2026",
            "X-API-Key": "test-api-key",
        }

    def test_get_current_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-01] GET /current should return strategy status.
        """
        response = client.get(
            "/strategy/retrieval/current",
            headers=valid_headers,
        )
        assert response.status_code == 200
        data = response.json()
        assert "active_strategy" in data
        assert "grayscale" in data
        assert data["active_strategy"] in ["default", "enhanced"]

    def test_switch_strategy_to_enhanced(self, client, valid_headers):
        """
        [AC-AISVC-RES-02] POST /switch should switch to enhanced strategy.
        """
        response = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
            },
            headers=valid_headers,
        )
        assert response.status_code == 200
        data = response.json()
        assert "success" in data
        assert data["success"] is True
        assert data["current_strategy"] == "enhanced"

    def test_switch_strategy_with_grayscale_percentage(self, client, valid_headers):
        """
        [AC-AISVC-RES-03] POST /switch should accept grayscale percentage.
        """
        response = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
                "grayscale": {
                    "enabled": True,
                    "percentage": 30.0,
                },
            },
            headers=valid_headers,
        )
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True
        # A grayscale switch still reports the requested strategy; the
        # lifecycle test asserts the same for a 50% rollout.
        assert data["current_strategy"] == "enhanced"

    def test_switch_strategy_with_allowlist(self, client, valid_headers):
        """
        [AC-AISVC-RES-03] POST /switch should accept allowlist.
        """
        response = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
                "grayscale": {
                    "enabled": True,
                    "allowlist": ["tenant_a", "tenant_b"],
                },
            },
            headers=valid_headers,
        )
        assert response.status_code == 200
        data = response.json()
        assert data["success"] is True

    def test_validate_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-06] POST /validate should validate strategy.
        """
        response = client.post(
            "/strategy/retrieval/validate",
            json={
                "active_strategy": "enhanced",
            },
            headers=valid_headers,
        )
        assert response.status_code == 200
        data = response.json()
        assert "valid" in data
        assert "errors" in data
        assert isinstance(data["valid"], bool)

    def test_validate_default_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-06] Default strategy should pass validation.
        """
        response = client.post(
            "/strategy/retrieval/validate",
            json={
                "active_strategy": "default",
            },
            headers=valid_headers,
        )
        assert response.status_code == 200
        data = response.json()
        assert data["valid"] is True

    def test_rollback_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-07] POST /rollback should rollback to default.
        """
        switch = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
            },
            headers=valid_headers,
        )
        # Check the precondition. If the switch silently failed, the strategy
        # would already be "default" and the rollback assertion below would
        # pass without exercising rollback at all.
        assert switch.status_code == 200
        assert switch.json()["current_strategy"] == "enhanced"

        response = client.post(
            "/strategy/retrieval/rollback",
            headers=valid_headers,
        )
        assert response.status_code == 200
        data = response.json()
        assert "success" in data
        assert data["current_strategy"] == "default"
class TestRetrievalStrategyAPIValidation:
    """
    [AC-AISVC-RES-03] Tests for API request validation.
    """

    @pytest.fixture
    def client(self):
        """Return a FastAPI TestClient bound to the application."""
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        """Return headers carrying a tenant id and the mocked API key."""
        return {
            "X-Tenant-Id": "test@ash@2026",
            "X-API-Key": "test-api-key",
        }

    def test_switch_invalid_strategy(self, client, valid_headers):
        """
        [AC-AISVC-RES-03] Invalid strategy value should return a client error.

        A 5xx here would mean the server crashed on bad input instead of
        rejecting it, so only 400/422 are accepted. The percentage range
        check below uses the same rule.
        """
        response = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "invalid_strategy",
            },
            headers=valid_headers,
        )
        assert response.status_code in [400, 422]

    def test_switch_percentage_out_of_range(self, client, valid_headers):
        """
        [AC-AISVC-RES-03] Percentage > 100 should return validation error.
        """
        response = client.post(
            "/strategy/retrieval/switch",
            json={
                "active_strategy": "enhanced",
                "grayscale": {
                    "percentage": 150.0,
                },
            },
            headers=valid_headers,
        )
        assert response.status_code in [400, 422]
class TestRetrievalStrategyAPIFlow:
    """
    [AC-AISVC-RES-01~15] End-to-end scenarios chaining several strategy calls.
    """

    @pytest.fixture
    def client(self):
        """Provide a TestClient wrapping the FastAPI application."""
        return TestClient(app)

    @pytest.fixture
    def valid_headers(self):
        """Provide tenant and API-key headers accepted by the mocked auth."""
        return {
            "X-Tenant-Id": "test@ash@2026",
            "X-API-Key": "test-api-key",
        }

    def test_complete_strategy_lifecycle(self, client, valid_headers):
        """
        [AC-AISVC-RES-01~07] Walk a strategy through its whole lifecycle.

        Steps: read the current strategy (default), switch to enhanced with
        a 50% grayscale rollout, validate it, roll back, and confirm the
        service reports default again.
        """
        hdrs = valid_headers

        # Step 1: the fixture-reset state starts on the default strategy.
        initial = client.get("/strategy/retrieval/current", headers=hdrs)
        assert initial.status_code == 200
        assert initial.json()["active_strategy"] == "default"

        # Step 2: move to enhanced behind a half-traffic grayscale.
        switch_payload = {
            "active_strategy": "enhanced",
            "grayscale": {"enabled": True, "percentage": 50.0},
        }
        switched = client.post(
            "/strategy/retrieval/switch", json=switch_payload, headers=hdrs
        )
        assert switched.status_code == 200
        assert switched.json()["current_strategy"] == "enhanced"

        # Step 3: the enhanced strategy can be validated.
        validated = client.post(
            "/strategy/retrieval/validate",
            json={"active_strategy": "enhanced"},
            headers=hdrs,
        )
        assert validated.status_code == 200

        # Step 4: roll back to the default strategy.
        rolled_back = client.post("/strategy/retrieval/rollback", headers=hdrs)
        assert rolled_back.status_code == 200
        assert rolled_back.json()["current_strategy"] == "default"

        # Step 5: a fresh read agrees with the rollback response.
        after = client.get("/strategy/retrieval/current", headers=hdrs)
        assert after.status_code == 200
        assert after.json()["active_strategy"] == "default"
class TestRetrievalStrategyAPIMissingTenant:
    """
    Requests that omit the X-Tenant-Id header must be rejected with 400.
    """

    @pytest.fixture
    def client(self):
        """Provide a TestClient wrapping the FastAPI application."""
        return TestClient(app)

    @pytest.fixture
    def api_key_headers(self):
        """Provide only the API-key header, deliberately leaving out the tenant."""
        return {"X-API-Key": "test-api-key"}

    def test_current_without_tenant(self, client, api_key_headers):
        """
        GET /current without X-Tenant-Id is answered with 400.
        """
        resp = client.get("/strategy/retrieval/current", headers=api_key_headers)
        assert resp.status_code == 400

    def test_switch_without_tenant(self, client, api_key_headers):
        """
        POST /switch without X-Tenant-Id is answered with 400.
        """
        body = {"active_strategy": "enhanced"}
        resp = client.post(
            "/strategy/retrieval/switch", json=body, headers=api_key_headers
        )
        assert resp.status_code == 400