"""
Unit tests for Memory service.

[AC-AISVC-10, AC-AISVC-11, AC-AISVC-13] Tests for multi-tenant session and message management.
"""

import pytest
from unittest.mock import AsyncMock, MagicMock, patch

from sqlalchemy.ext.asyncio import AsyncSession

from app.models.entities import ChatMessage, ChatSession
from app.services.memory import MemoryService


def _compile_last_statement(mock_session):
    """
    Compile the last statement the service passed to ``mock_session.execute``.

    The mocked session returns canned rows regardless of the query it receives,
    so asserting on the returned rows alone cannot show that the service scoped
    or limited its query. Compiling the captured statement exposes the SQL text
    and the values the service actually bound.

    Returns:
        A ``(sql_text, bound_params)`` tuple.
    """
    statement = mock_session.execute.call_args.args[0]
    compiled = statement.compile()
    return str(compiled), compiled.params


@pytest.fixture
def mock_session():
    """Create a mock AsyncSession."""
    session = AsyncMock(spec=AsyncSession)
    # add() is synchronous on AsyncSession, so override the AsyncMock default.
    session.add = MagicMock()
    session.flush = AsyncMock()
    session.delete = AsyncMock()
    return session


@pytest.fixture
def memory_service(mock_session):
    """Create MemoryService with mocked session."""
    return MemoryService(mock_session)


class TestMemoryServiceTenantIsolation:
    """
    [AC-AISVC-10, AC-AISVC-11] Tests for multi-tenant isolation in memory service.
    """

    @pytest.mark.asyncio
    async def test_get_or_create_session_tenant_isolation(self, memory_service, mock_session):
        """
        [AC-AISVC-11] Different tenants with same session_id should have separate sessions.
        """
        mock_result = MagicMock()
        # No existing session is found, so both calls take the "create" path.
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute = AsyncMock(return_value=mock_result)

        session1 = await memory_service.get_or_create_session(
            tenant_id="tenant_a",
            session_id="session_123",
        )
        session2 = await memory_service.get_or_create_session(
            tenant_id="tenant_b",
            session_id="session_123",
        )

        assert session1.tenant_id == "tenant_a"
        assert session2.tenant_id == "tenant_b"
        assert session1.session_id == "session_123"
        assert session2.session_id == "session_123"

    @pytest.mark.asyncio
    async def test_load_history_tenant_isolation(self, memory_service, mock_session):
        """
        [AC-AISVC-11] Loading history should only return messages for the specific tenant.
        """
        mock_result = MagicMock()
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = [
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content="Hello"),
        ]
        mock_result.scalars.return_value = mock_scalars
        mock_session.execute = AsyncMock(return_value=mock_result)

        messages = await memory_service.load_history(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        assert len(messages) == 1
        assert messages[0].tenant_id == "tenant_a"
        # The mock returns these rows unconditionally, so also check that the
        # query itself was bound to the requesting tenant.
        _, params = _compile_last_statement(mock_session)
        assert "tenant_a" in params.values()

    @pytest.mark.asyncio
    async def test_append_message_tenant_scoped(self, memory_service, mock_session):
        """
        [AC-AISVC-10, AC-AISVC-13] Appended messages should be scoped to tenant.
        """
        message = await memory_service.append_message(
            tenant_id="tenant_a",
            session_id="session_123",
            role="user",
            content="Test message",
        )

        assert message.tenant_id == "tenant_a"
        assert message.session_id == "session_123"
        assert message.role == "user"
        assert message.content == "Test message"


class TestMemoryServiceSessionManagement:
    """
    [AC-AISVC-13] Tests for session-based memory management.
    """

    @pytest.mark.asyncio
    async def test_get_existing_session(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should return existing session if it exists.
        """
        existing_session = ChatSession(
            tenant_id="tenant_a",
            session_id="session_123",
        )
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = existing_session
        mock_session.execute = AsyncMock(return_value=mock_result)

        session = await memory_service.get_or_create_session(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        assert session.tenant_id == "tenant_a"
        assert session.session_id == "session_123"

    @pytest.mark.asyncio
    async def test_create_new_session(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should create new session if it doesn't exist.
        """
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute = AsyncMock(return_value=mock_result)

        session = await memory_service.get_or_create_session(
            tenant_id="tenant_a",
            session_id="session_new",
            channel_type="wechat",
            metadata={"user_id": "user_123"},
        )

        assert session.tenant_id == "tenant_a"
        assert session.session_id == "session_new"
        assert session.channel_type == "wechat"

    @pytest.mark.asyncio
    async def test_append_multiple_messages(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should append multiple messages in batch.
        """
        messages_data = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
        ]

        messages = await memory_service.append_messages(
            tenant_id="tenant_a",
            session_id="session_123",
            messages=messages_data,
        )

        assert len(messages) == 2
        assert messages[0].role == "user"
        assert messages[1].role == "assistant"

    @pytest.mark.asyncio
    async def test_load_history_with_limit(self, memory_service, mock_session):
        """
        [AC-AISVC-13] Should limit the number of messages returned.

        The mocked session cannot apply a LIMIT itself, so it returns the rows a
        database would return for ``LIMIT 3``; the test then verifies that the
        service actually put that limit on the query it executed.
        """
        mock_result = MagicMock()
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = [
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content=f"Msg {i}")
            for i in range(3)
        ]
        mock_result.scalars.return_value = mock_scalars
        mock_session.execute = AsyncMock(return_value=mock_result)

        messages = await memory_service.load_history(
            tenant_id="tenant_a",
            session_id="session_123",
            limit=3,
        )

        assert len(messages) == 3
        sql, params = _compile_last_statement(mock_session)
        assert "LIMIT" in sql.upper()
        assert 3 in params.values()


class TestMemoryServiceClearHistory:
    """
    [AC-AISVC-13] Tests for clearing session history.
    """

    @pytest.mark.asyncio
    async def test_clear_history_tenant_scoped(self, memory_service, mock_session):
        """
        [AC-AISVC-11] Clearing history should only affect the specified tenant's messages.
        """
        mock_result = MagicMock()
        mock_scalars = MagicMock()
        mock_scalars.all.return_value = [
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="user", content="Msg 1"),
            ChatMessage(tenant_id="tenant_a", session_id="session_123", role="assistant", content="Msg 2"),
        ]
        mock_result.scalars.return_value = mock_scalars
        mock_session.execute = AsyncMock(return_value=mock_result)

        count = await memory_service.clear_history(
            tenant_id="tenant_a",
            session_id="session_123",
        )

        assert count == 2