211 lines
6.9 KiB
Python
211 lines
6.9 KiB
Python
"""
|
|
Unit tests for Memory service.
|
|
[AC-AISVC-10, AC-AISVC-11, AC-AISVC-13] Tests for multi-tenant session and message management.
|
|
"""
|
|
|
|
import pytest
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.entities import ChatMessage, ChatSession
|
|
from app.services.memory import MemoryService
|
|
|
|
|
|
@pytest.fixture
def mock_session():
    """Build an ``AsyncSession`` test double.

    The double is an ``AsyncMock`` spec'd against ``AsyncSession``. ``add`` is
    overridden with a plain ``MagicMock`` (called without ``await``), while
    ``flush`` and ``delete`` are explicit ``AsyncMock`` instances so they can
    be awaited and inspected by individual tests.
    """
    fake = AsyncMock(spec=AsyncSession)
    # add() is used synchronously, so it must not return a coroutine.
    fake.add = MagicMock()
    fake.flush = AsyncMock()
    fake.delete = AsyncMock()
    return fake
|
|
|
|
|
|
@pytest.fixture
def memory_service(mock_session):
    """Provide a ``MemoryService`` wired to the mocked ``AsyncSession`` fixture."""
    service = MemoryService(mock_session)
    return service
|
|
|
|
|
|
class TestMemoryServiceTenantIsolation:
    """
    [AC-AISVC-10, AC-AISVC-11] Tests for multi-tenant isolation in memory service.

    ``AsyncSession.execute`` is mocked, so any rows it returns are exactly the
    ones the test scripted. Reading ``tenant_id`` back off those rows cannot
    detect a missing tenant filter, so the tests also inspect the values bound
    into the statements the service actually passed to ``execute``.
    """

    @staticmethod
    def _bound_values(mock_session):
        """Return, per awaited ``execute`` call, the list of bound parameter values.

        Assumes the service passes a SQLAlchemy statement as the first
        positional argument to ``execute`` (TODO confirm if that changes).
        """
        return [
            list(call.args[0].compile().params.values())
            for call in mock_session.execute.await_args_list
        ]

    @pytest.mark.asyncio
    async def test_get_or_create_session_tenant_isolation(self, memory_service, mock_session):
        """
        [AC-AISVC-11] Different tenants with same session_id should have separate sessions.
        """
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute = AsyncMock(return_value=mock_result)

        session1 = await memory_service.get_or_create_session(
            tenant_id="tenant_a",
            session_id="session_123",
        )
        session2 = await memory_service.get_or_create_session(
            tenant_id="tenant_b",
            session_id="session_123",
        )

        assert session1.tenant_id == "tenant_a"
        assert session2.tenant_id == "tenant_b"
        assert session1.session_id == "session_123"
        assert session2.session_id == "session_123"

        # Each lookup must be keyed on (tenant_id, session_id); a lookup on
        # session_id alone would let the two tenants collide.
        bound = self._bound_values(mock_session)
        assert any("tenant_a" in vals and "session_123" in vals for vals in bound)
        assert any("tenant_b" in vals and "session_123" in vals for vals in bound)

    @pytest.mark.asyncio
    async def test_load_history_tenant_isolation(self, memory_service, mock_session):
        """
        [AC-AISVC-11] Loading history should only return messages for the specific tenant.
        """
        mock_result = MagicMock()
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = [
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content="Hello"),
        ]
        mock_result.scalars.return_value = mock_scalars
        mock_session.execute = AsyncMock(return_value=mock_result)

        messages = await memory_service.load_history(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        assert len(messages) == 1
        assert messages[0].tenant_id == "tenant_a"

        # The rows above come straight from the mock, so verify the query
        # itself was filtered by the requesting tenant.
        bound = self._bound_values(mock_session)
        assert any("tenant_a" in vals and "session_123" in vals for vals in bound)

    @pytest.mark.asyncio
    async def test_append_message_tenant_scoped(self, memory_service, mock_session):
        """
        [AC-AISVC-10, AC-AISVC-13] Appended messages should be scoped to tenant.
        """
        message = await memory_service.append_message(
            tenant_id="tenant_a",
            session_id="session_123",
            role="user",
            content="Test message",
        )

        assert message.tenant_id == "tenant_a"
        assert message.session_id == "session_123"
        assert message.role == "user"
        assert message.content == "Test message"
|
|
|
|
|
class TestMemoryServiceSessionManagement:
    """
    [AC-AISVC-13] Tests for session-based memory management.
    """

    @pytest.mark.asyncio
    async def test_get_existing_session(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should return existing session if it exists.
        """
        existing_session = ChatSession(
            tenant_id="tenant_a",
            session_id="session_123",
        )
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = existing_session
        mock_session.execute = AsyncMock(return_value=mock_result)

        session = await memory_service.get_or_create_session(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        # Identity, not just equal fields: a freshly created session with the
        # same ids would otherwise also satisfy the attribute checks.
        assert session is existing_session
        assert session.tenant_id == "tenant_a"
        assert session.session_id == "session_123"

    @pytest.mark.asyncio
    async def test_create_new_session(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should create new session if it doesn't exist.
        """
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute = AsyncMock(return_value=mock_result)

        session = await memory_service.get_or_create_session(
            tenant_id="tenant_a",
            session_id="session_new",
            channel_type="wechat",
            metadata={"user_id": "user_123"},
        )

        assert session.tenant_id == "tenant_a"
        assert session.session_id == "session_new"
        assert session.channel_type == "wechat"

    @pytest.mark.asyncio
    async def test_append_multiple_messages(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should append multiple messages in batch.
        """
        messages_data = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]

        messages = await memory_service.append_messages(
            tenant_id="tenant_a",
            session_id="session_123",
            messages=messages_data,
        )

        assert len(messages) == 2
        assert messages[0].role == "user"
        assert messages[1].role == "assistant"
        assert [m.content for m in messages] == ["Hello", "Hi there!"]
        assert all(m.tenant_id == "tenant_a" for m in messages)
        assert all(m.session_id == "session_123" for m in messages)

    @pytest.mark.asyncio
    async def test_load_history_with_limit(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should limit the number of messages returned.

        Assumes ``limit`` is applied in the SQL query rather than by slicing
        rows in Python: the mock returns what a database honouring ``LIMIT 3``
        would return, and the test checks that the issued statement carries
        that limit.
        """
        mock_result = MagicMock()
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = [
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content=f"Msg {i}")
            for i in range(3)
        ]
        mock_result.scalars.return_value = mock_scalars
        mock_session.execute = AsyncMock(return_value=mock_result)

        messages = await memory_service.load_history(
            tenant_id="tenant_a",
            session_id="session_123",
            limit=3,
        )

        assert len(messages) == 3

        # The old assertion (5 rows back for limit=3) could never catch a
        # service that ignores ``limit``; check the compiled query instead.
        compiled = [
            call.args[0].compile()
            for call in mock_session.execute.await_args_list
        ]
        assert any(
            "LIMIT" in str(stmt).upper() and 3 in stmt.params.values()
            for stmt in compiled
        )
|
|
|
|
|
|
class TestMemoryServiceClearHistory:
    """
    [AC-AISVC-13] Tests for clearing session history.
    """

    @pytest.mark.asyncio
    async def test_clear_history_tenant_scoped(self, memory_service, mock_session):
        """
        [AC-AISVC-11] Clearing history should only affect the specified tenant's messages.
        """
        mock_result = MagicMock()
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = [
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content="Msg 1"),
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="assistant", content="Msg 2"),
        ]
        mock_result.scalars.return_value = mock_scalars
        mock_session.execute = AsyncMock(return_value=mock_result)

        count = await memory_service.clear_history(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        assert count == 2

        # The mock returns only tenant_a rows regardless of the query, so
        # scoping must be checked on the statement the service issued.
        # Assumes the statement is the first positional arg to execute().
        bound = [
            list(call.args[0].compile().params.values())
            for call in mock_session.execute.await_args_list
        ]
        assert any("tenant_a" in vals and "session_123" in vals for vals in bound)
|