ai-robot-core/ai-service/tests/test_integration_tenant.py

312 lines
10 KiB
Python

"""
Integration tests for multi-tenant isolation.
[AC-AISVC-10, AC-AISVC-11] Tests for concurrent multi-tenant requests with strict isolation.
"""
import asyncio
import json
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from fastapi.testclient import TestClient
from app.main import app
from app.models import ChatRequest, ChannelType
class TestMultiTenantIsolation:
    """
    [AC-AISVC-10, AC-AISVC-11, AC-AISVC-12] Integration tests for multi-tenant
    isolation and X-Tenant-Id header validation on POST /ai/chat.
    """

    @pytest.fixture
    def client(self):
        return TestClient(app)

    @staticmethod
    def _chat_body(session_id: str, message: str) -> dict:
        """Return the JSON body of a /ai/chat request on the wechat channel."""
        return {
            "sessionId": session_id,
            "currentMessage": message,
            "channelType": "wechat",
        }

    def _assert_tenant_rejected(self, client, headers=None):
        """
        POST an otherwise valid chat body with ``headers`` and assert the request
        is rejected with HTTP 400 and error code ``MISSING_TENANT_ID``.

        ``headers=None`` sends no custom headers at all (missing X-Tenant-Id).
        """
        response = client.post(
            "/ai/chat",
            json=self._chat_body("session_123", "Hello"),
            headers=headers,
        )
        assert response.status_code == 400
        data = response.json()
        assert data["code"] == "MISSING_TENANT_ID"

    def test_concurrent_requests_different_tenants(self, client):
        """
        [AC-AISVC-10] Test concurrent requests from different tenants are isolated.
        """
        import concurrent.futures

        def make_request(tenant_id: str):
            response = client.post(
                "/ai/chat",
                json=self._chat_body(f"session_{tenant_id}", f"Message from {tenant_id}"),
                headers={"X-Tenant-Id": tenant_id},
            )
            return tenant_id, response.status_code, response.json()

        with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
            futures = [executor.submit(make_request, f"tenant_{i}") for i in range(5)]
            results = [f.result() for f in concurrent.futures.as_completed(futures)]

        for tenant_id, status_code, data in results:
            assert status_code == 200, f"Tenant {tenant_id} failed"
            assert "reply" in data, f"Tenant {tenant_id} missing reply"
            assert "confidence" in data, f"Tenant {tenant_id} missing confidence"

    def test_sse_concurrent_requests_different_tenants(self, client):
        """
        [AC-AISVC-10] Test concurrent SSE requests from different tenants are isolated.
        """
        import concurrent.futures

        def make_sse_request(tenant_id: str):
            response = client.post(
                "/ai/chat",
                json=self._chat_body(f"session_{tenant_id}", f"SSE Message from {tenant_id}"),
                headers={
                    "X-Tenant-Id": tenant_id,
                    # Ask the endpoint for a server-sent-events response.
                    "Accept": "text/event-stream",
                },
            )
            return tenant_id, response.status_code, response.text

        with concurrent.futures.ThreadPoolExecutor(max_workers=3) as executor:
            futures = [executor.submit(make_sse_request, f"tenant_sse_{i}") for i in range(3)]
            results = [f.result() for f in concurrent.futures.as_completed(futures)]

        for tenant_id, status_code, content in results:
            assert status_code == 200, f"Tenant {tenant_id} SSE failed"
            # Accept both "event:final" and "event: final" SSE spellings.
            assert "event:final" in content or "event: final" in content, \
                f"Tenant {tenant_id} missing final event"

    def test_tenant_cannot_access_other_tenant_session(self, client):
        """
        [AC-AISVC-11] Test that tenant cannot access another tenant's session.

        Both tenants deliberately reuse the same session id; tenant B's reply
        must not contain tenant A's message.
        """
        session_id = "shared_session_id"
        message_a = "Message from tenant A"

        response_a = client.post(
            "/ai/chat",
            json=self._chat_body(session_id, message_a),
            headers={"X-Tenant-Id": "tenant_a"},
        )
        response_b = client.post(
            "/ai/chat",
            json=self._chat_body(session_id, "Message from tenant B"),
            headers={"X-Tenant-Id": "tenant_b"},
        )

        assert response_a.status_code == 200
        assert response_b.status_code == 200
        data_a = response_a.json()
        data_b = response_b.json()
        assert "reply" in data_a
        assert "reply" in data_b
        # If conversation history were keyed by session id alone, tenant A's
        # message could surface in tenant B's reply for the shared session id.
        assert message_a not in str(data_b["reply"]), \
            "Tenant A's message leaked into tenant B's reply"

    def test_missing_tenant_id_rejected(self, client):
        """
        [AC-AISVC-12] Test that missing X-Tenant-Id is rejected.
        """
        self._assert_tenant_rejected(client)

    def test_empty_tenant_id_rejected(self, client):
        """
        [AC-AISVC-12] Test that empty X-Tenant-Id is rejected.
        """
        self._assert_tenant_rejected(client, headers={"X-Tenant-Id": ""})

    def test_whitespace_tenant_id_rejected(self, client):
        """
        [AC-AISVC-12] Test that whitespace-only X-Tenant-Id is rejected.
        """
        self._assert_tenant_rejected(client, headers={"X-Tenant-Id": " "})
class TestTenantContextPropagation:
    """
    [AC-AISVC-10] Check that the orchestrator completes a request while a
    tenant context is active, for both blocking and streaming generation.

    Each test sets the tenant context first and always clears it afterwards,
    so context never leaks into other tests.
    """

    @pytest.mark.asyncio
    async def test_tenant_context_in_orchestrator(self):
        """
        [AC-AISVC-10] Test that tenant_id is properly propagated to orchestrator.
        """
        from app.services.orchestrator import OrchestratorService
        from app.core.tenant import set_tenant_context, clear_tenant_context

        tenant = "test_tenant_123"
        set_tenant_context(tenant)
        try:
            service = OrchestratorService()
            chat_request = ChatRequest(
                session_id="session_123",
                current_message="Test",
                channel_type=ChannelType.WECHAT,
            )
            result = await service.generate(tenant, chat_request)
            assert result is not None
            assert result.reply is not None
        finally:
            clear_tenant_context()

    @pytest.mark.asyncio
    async def test_tenant_context_in_streaming(self):
        """
        [AC-AISVC-10] Test that tenant_id is properly propagated during streaming.
        """
        from app.services.orchestrator import OrchestratorService
        from app.core.tenant import set_tenant_context, clear_tenant_context

        tenant = "test_tenant_stream"
        set_tenant_context(tenant)
        try:
            service = OrchestratorService()
            chat_request = ChatRequest(
                session_id="session_stream",
                current_message="Test streaming",
                channel_type=ChannelType.WECHAT,
            )
            # Drain the whole stream before inspecting it.
            collected = [evt async for evt in service.generate_stream(tenant, chat_request)]
            assert collected
            assert "final" in {evt.event for evt in collected}
        finally:
            clear_tenant_context()
class TestTenantIsolationWithMockedStorage:
    """
    [AC-AISVC-11] Tests for tenant isolation with mocked storage layers.

    The storage mocks return canned data regardless of the query they receive,
    so returned results alone cannot prove isolation. Where feasible, the tests
    also inspect what the service handed to the storage layer.
    """

    @pytest.mark.asyncio
    async def test_memory_isolation_between_tenants(self):
        """
        [AC-AISVC-11] Test that memory service isolates data by tenant.
        """
        from app.services.memory import MemoryService
        from app.models.entities import ChatMessage

        # Mock the async DB session so that
        # `(await session.execute(...)).scalars().all()` yields one tenant_a row.
        mock_session = AsyncMock()
        mock_result = MagicMock()
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = [
            ChatMessage(tenant_id="tenant_a", session_id="session_1", role="user", content="A's message"),
        ]
        mock_result.scalars.return_value = mock_scalars
        mock_session.execute = AsyncMock(return_value=mock_result)

        memory_service = MemoryService(mock_session)
        messages_a = await memory_service.load_history("tenant_a", "session_1")

        # The history must actually come from the (mocked) database query.
        mock_session.execute.assert_awaited()
        assert len(messages_a) == 1
        assert messages_a[0].tenant_id == "tenant_a"
        # NOTE(review): the SQL statement passed to execute() is not inspected, so
        # this does not prove load_history filters by tenant_id; consider asserting
        # on the compiled statement once its shape is pinned down.

    @pytest.mark.asyncio
    async def test_retrieval_isolation_between_tenants(self):
        """
        [AC-AISVC-11] Test that retrieval service isolates by tenant.
        """
        from app.services.retrieval.vector_retriever import VectorRetriever
        from app.services.retrieval.base import RetrievalContext

        # The mock answers strictly in call order: first call -> A's doc, second -> B's.
        mock_qdrant = AsyncMock()
        mock_qdrant.search.side_effect = [
            [{"id": "1", "score": 0.9, "payload": {"text": "Tenant A doc"}}],
            [{"id": "2", "score": 0.8, "payload": {"text": "Tenant B doc"}}],
        ]
        retriever = VectorRetriever(qdrant_client=mock_qdrant)

        with patch.object(retriever, "_get_embedding", return_value=[0.1] * 1536):
            ctx_a = RetrievalContext(tenant_id="tenant_a", query="query")
            ctx_b = RetrievalContext(tenant_id="tenant_b", query="query")
            result_a = await retriever.retrieve(ctx_a)
            result_b = await retriever.retrieve(ctx_b)

        assert result_a.hits[0].text == "Tenant A doc"
        assert result_b.hits[0].text == "Tenant B doc"

        # Because the mock ignores its arguments, the hits above only show that
        # results are passed through. Verify that each vector search was scoped to
        # the requesting tenant (by collection name, filter, or argument).
        search_calls = mock_qdrant.search.call_args_list
        assert len(search_calls) == 2
        assert "tenant_a" in str(search_calls[0]), "search for tenant_a was not tenant-scoped"
        assert "tenant_b" in str(search_calls[1]), "search for tenant_b was not tenant-scoped"
class TestTenantHealthCheckBypass:
    """
    The health endpoint must be reachable without tenant identification.
    """

    @pytest.fixture
    def client(self):
        return TestClient(app)

    def test_health_check_no_tenant_required(self, client):
        """
        GET /ai/health with no X-Tenant-Id header reports a healthy service.
        """
        resp = client.get("/ai/health")
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["status"] == "healthy"