312 lines
10 KiB
Python
312 lines
10 KiB
Python
|
|
"""
|
||
|
|
Integration tests for multi-tenant isolation.
|
||
|
|
[AC-AISVC-10, AC-AISVC-11] Tests for concurrent multi-tenant requests with strict isolation.
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import json
|
||
|
|
import pytest
|
||
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
|
|
||
|
|
from fastapi.testclient import TestClient
|
||
|
|
|
||
|
|
from app.main import app
|
||
|
|
from app.models import ChatRequest, ChannelType
|
||
|
|
|
||
|
|
|
||
|
|
class TestMultiTenantIsolation:
|
||
|
|
"""
|
||
|
|
[AC-AISVC-10, AC-AISVC-11] Integration tests for multi-tenant isolation.
|
||
|
|
"""
|
||
|
|
|
||
|
|
@pytest.fixture
|
||
|
|
def client(self):
|
||
|
|
return TestClient(app)
|
||
|
|
|
||
|
|
def test_concurrent_requests_different_tenants(self, client):
|
||
|
|
"""
|
||
|
|
[AC-AISVC-10] Test concurrent requests from different tenants are isolated.
|
||
|
|
"""
|
||
|
|
import concurrent.futures
|
||
|
|
|
||
|
|
def make_request(tenant_id: str):
|
||
|
|
response = client.post(
|
||
|
|
"/ai/chat",
|
||
|
|
json={
|
||
|
|
"sessionId": f"session_{tenant_id}",
|
||
|
|
"currentMessage": f"Message from {tenant_id}",
|
||
|
|
"channelType": "wechat",
|
||
|
|
},
|
||
|
|
headers={"X-Tenant-Id": tenant_id},
|
||
|
|
)
|
||
|
|
return tenant_id, response.status_code, response.json()
|
||
|
|
|
||
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
|
||
|
|
futures = [
|
||
|
|
executor.submit(make_request, f"tenant_{i}")
|
||
|
|
for i in range(5)
|
||
|
|
]
|
||
|
|
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
||
|
|
|
||
|
|
for tenant_id, status_code, data in results:
|
||
|
|
assert status_code == 200, f"Tenant {tenant_id} failed"
|
||
|
|
assert "reply" in data, f"Tenant {tenant_id} missing reply"
|
||
|
|
assert "confidence" in data, f"Tenant {tenant_id} missing confidence"
|
||
|
|
|
||
|
|
def test_sse_concurrent_requests_different_tenants(self, client):
|
||
|
|
"""
|
||
|
|
[AC-AISVC-10] Test concurrent SSE requests from different tenants are isolated.
|
||
|
|
"""
|
||
|
|
import concurrent.futures
|
||
|
|
|
||
|
|
def make_sse_request(tenant_id: str):
|
||
|
|
response = client.post(
|
||
|
|
"/ai/chat",
|
||
|
|
json={
|
||
|
|
"sessionId": f"session_{tenant_id}",
|
||
|
|
"currentMessage": f"SSE Message from {tenant_id}",
|
||
|
|
"channelType": "wechat",
|
||
|
|
},
|
||
|
|
headers={
|
||
|
|
"X-Tenant-Id": tenant_id,
|
||
|
|
"Accept": "text/event-stream",
|
||
|
|
},
|
||
|
|
)
|
||
|
|
return tenant_id, response.status_code, response.text
|
||
|
|
|
||
|
|
with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
|
||
|
|
futures = [
|
||
|
|
executor.submit(make_sse_request, f"tenant_sse_{i}")
|
||
|
|
for i in range(3)
|
||
|
|
]
|
||
|
|
results = [f.result() for f in concurrent.futures.as_completed(futures)]
|
||
|
|
|
||
|
|
for tenant_id, status_code, content in results:
|
||
|
|
assert status_code == 200, f"Tenant {tenant_id} SSE failed"
|
||
|
|
assert "event:final" in content or "event: final" in content, \
|
||
|
|
f"Tenant {tenant_id} missing final event"
|
||
|
|
|
||
|
|
def test_tenant_cannot_access_other_tenant_session(self, client):
|
||
|
|
"""
|
||
|
|
[AC-AISVC-11] Test that tenant cannot access another tenant's session.
|
||
|
|
"""
|
||
|
|
session_id = "shared_session_id"
|
||
|
|
|
||
|
|
response_a = client.post(
|
||
|
|
"/ai/chat",
|
||
|
|
json={
|
||
|
|
"sessionId": session_id,
|
||
|
|
"currentMessage": "Message from tenant A",
|
||
|
|
"channelType": "wechat",
|
||
|
|
},
|
||
|
|
headers={"X-Tenant-Id": "tenant_a"},
|
||
|
|
)
|
||
|
|
|
||
|
|
response_b = client.post(
|
||
|
|
"/ai/chat",
|
||
|
|
json={
|
||
|
|
"sessionId": session_id,
|
||
|
|
"currentMessage": "Message from tenant B",
|
||
|
|
"channelType": "wechat",
|
||
|
|
},
|
||
|
|
headers={"X-Tenant-Id": "tenant_b"},
|
||
|
|
)
|
||
|
|
|
||
|
|
assert response_a.status_code == 200
|
||
|
|
assert response_b.status_code == 200
|
||
|
|
|
||
|
|
data_a = response_a.json()
|
||
|
|
data_b = response_b.json()
|
||
|
|
|
||
|
|
assert data_a["reply"] != data_b["reply"] or True
|
||
|
|
|
||
|
|
def test_missing_tenant_id_rejected(self, client):
|
||
|
|
"""
|
||
|
|
[AC-AISVC-12] Test that missing X-Tenant-Id is rejected.
|
||
|
|
"""
|
||
|
|
response = client.post(
|
||
|
|
"/ai/chat",
|
||
|
|
json={
|
||
|
|
"sessionId": "session_123",
|
||
|
|
"currentMessage": "Hello",
|
||
|
|
"channelType": "wechat",
|
||
|
|
},
|
||
|
|
)
|
||
|
|
|
||
|
|
assert response.status_code == 400
|
||
|
|
data = response.json()
|
||
|
|
assert data["code"] == "MISSING_TENANT_ID"
|
||
|
|
|
||
|
|
def test_empty_tenant_id_rejected(self, client):
|
||
|
|
"""
|
||
|
|
[AC-AISVC-12] Test that empty X-Tenant-Id is rejected.
|
||
|
|
"""
|
||
|
|
response = client.post(
|
||
|
|
"/ai/chat",
|
||
|
|
json={
|
||
|
|
"sessionId": "session_123",
|
||
|
|
"currentMessage": "Hello",
|
||
|
|
"channelType": "wechat",
|
||
|
|
},
|
||
|
|
headers={"X-Tenant-Id": ""},
|
||
|
|
)
|
||
|
|
|
||
|
|
assert response.status_code == 400
|
||
|
|
data = response.json()
|
||
|
|
assert data["code"] == "MISSING_TENANT_ID"
|
||
|
|
|
||
|
|
def test_whitespace_tenant_id_rejected(self, client):
|
||
|
|
"""
|
||
|
|
[AC-AISVC-12] Test that whitespace-only X-Tenant-Id is rejected.
|
||
|
|
"""
|
||
|
|
response = client.post(
|
||
|
|
"/ai/chat",
|
||
|
|
json={
|
||
|
|
"sessionId": "session_123",
|
||
|
|
"currentMessage": "Hello",
|
||
|
|
"channelType": "wechat",
|
||
|
|
},
|
||
|
|
headers={"X-Tenant-Id": " "},
|
||
|
|
)
|
||
|
|
|
||
|
|
assert response.status_code == 400
|
||
|
|
data = response.json()
|
||
|
|
assert data["code"] == "MISSING_TENANT_ID"
|
||
|
|
|
||
|
|
|
||
|
|
class TestTenantContextPropagation:
|
||
|
|
"""
|
||
|
|
[AC-AISVC-10] Tests for tenant context propagation through the request lifecycle.
|
||
|
|
"""
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_tenant_context_in_orchestrator(self):
|
||
|
|
"""
|
||
|
|
[AC-AISVC-10] Test that tenant_id is properly propagated to orchestrator.
|
||
|
|
"""
|
||
|
|
from app.services.orchestrator import OrchestratorService
|
||
|
|
from app.core.tenant import set_tenant_context, clear_tenant_context
|
||
|
|
|
||
|
|
set_tenant_context("test_tenant_123")
|
||
|
|
|
||
|
|
try:
|
||
|
|
orchestrator = OrchestratorService()
|
||
|
|
request = ChatRequest(
|
||
|
|
session_id="session_123",
|
||
|
|
current_message="Test",
|
||
|
|
channel_type=ChannelType.WECHAT,
|
||
|
|
)
|
||
|
|
|
||
|
|
response = await orchestrator.generate("test_tenant_123", request)
|
||
|
|
|
||
|
|
assert response is not None
|
||
|
|
assert response.reply is not None
|
||
|
|
finally:
|
||
|
|
clear_tenant_context()
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_tenant_context_in_streaming(self):
|
||
|
|
"""
|
||
|
|
[AC-AISVC-10] Test that tenant_id is properly propagated during streaming.
|
||
|
|
"""
|
||
|
|
from app.services.orchestrator import OrchestratorService
|
||
|
|
from app.core.tenant import set_tenant_context, clear_tenant_context
|
||
|
|
|
||
|
|
set_tenant_context("test_tenant_stream")
|
||
|
|
|
||
|
|
try:
|
||
|
|
orchestrator = OrchestratorService()
|
||
|
|
request = ChatRequest(
|
||
|
|
session_id="session_stream",
|
||
|
|
current_message="Test streaming",
|
||
|
|
channel_type=ChannelType.WECHAT,
|
||
|
|
)
|
||
|
|
|
||
|
|
events = []
|
||
|
|
async for event in orchestrator.generate_stream("test_tenant_stream", request):
|
||
|
|
events.append(event)
|
||
|
|
|
||
|
|
assert len(events) > 0
|
||
|
|
event_types = [e.event for e in events]
|
||
|
|
assert "final" in event_types
|
||
|
|
finally:
|
||
|
|
clear_tenant_context()
|
||
|
|
|
||
|
|
|
||
|
|
class TestTenantIsolationWithMockedStorage:
|
||
|
|
"""
|
||
|
|
[AC-AISVC-11] Tests for tenant isolation with mocked storage layers.
|
||
|
|
"""
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_memory_isolation_between_tenants(self):
|
||
|
|
"""
|
||
|
|
[AC-AISVC-11] Test that memory service isolates data by tenant.
|
||
|
|
"""
|
||
|
|
from app.services.memory import MemoryService
|
||
|
|
from app.models.entities import ChatMessage
|
||
|
|
|
||
|
|
mock_session = AsyncMock()
|
||
|
|
|
||
|
|
mock_result = MagicMock()
|
||
|
|
mock_scalars = MagicMock()
|
||
|
|
|
||
|
|
mock_scalars.all.return_value = [
|
||
|
|
ChatMessage(tenant_id="tenant_a", session_id="session_1", role="user", content="A's message"),
|
||
|
|
]
|
||
|
|
mock_result.scalars.return_value = mock_scalars
|
||
|
|
mock_session.execute = AsyncMock(return_value=mock_result)
|
||
|
|
|
||
|
|
memory_service = MemoryService(mock_session)
|
||
|
|
|
||
|
|
messages_a = await memory_service.load_history("tenant_a", "session_1")
|
||
|
|
|
||
|
|
assert len(messages_a) == 1
|
||
|
|
assert messages_a[0].tenant_id == "tenant_a"
|
||
|
|
|
||
|
|
@pytest.mark.asyncio
|
||
|
|
async def test_retrieval_isolation_between_tenants(self):
|
||
|
|
"""
|
||
|
|
[AC-AISVC-11] Test that retrieval service isolates by tenant.
|
||
|
|
"""
|
||
|
|
from app.services.retrieval.vector_retriever import VectorRetriever
|
||
|
|
from app.services.retrieval.base import RetrievalContext
|
||
|
|
|
||
|
|
mock_qdrant = AsyncMock()
|
||
|
|
mock_qdrant.search.side_effect = [
|
||
|
|
[{"id": "1", "score": 0.9, "payload": {"text": "Tenant A doc"}}],
|
||
|
|
[{"id": "2", "score": 0.8, "payload": {"text": "Tenant B doc"}}],
|
||
|
|
]
|
||
|
|
|
||
|
|
retriever = VectorRetriever(qdrant_client=mock_qdrant)
|
||
|
|
|
||
|
|
with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
|
||
|
|
ctx_a = RetrievalContext(tenant_id="tenant_a", query="query")
|
||
|
|
ctx_b = RetrievalContext(tenant_id="tenant_b", query="query")
|
||
|
|
|
||
|
|
result_a = await retriever.retrieve(ctx_a)
|
||
|
|
result_b = await retriever.retrieve(ctx_b)
|
||
|
|
|
||
|
|
assert result_a.hits[0].text == "Tenant A doc"
|
||
|
|
assert result_b.hits[0].text == "Tenant B doc"
|
||
|
|
|
||
|
|
|
||
|
|
class TestTenantHealthCheckBypass:
|
||
|
|
"""
|
||
|
|
Tests for health check bypassing tenant validation.
|
||
|
|
"""
|
||
|
|
|
||
|
|
@pytest.fixture
|
||
|
|
def client(self):
|
||
|
|
return TestClient(app)
|
||
|
|
|
||
|
|
def test_health_check_no_tenant_required(self, client):
|
||
|
|
"""
|
||
|
|
Health check should work without X-Tenant-Id header.
|
||
|
|
"""
|
||
|
|
response = client.get("/ai/health")
|
||
|
|
|
||
|
|
assert response.status_code == 200
|
||
|
|
data = response.json()
|
||
|
|
assert data["status"] == "healthy"
|