# ai-robot-core/ai-service/app/services/memory.py

"""
Memory service for AI Service.
[AC-AISVC-13] Session-based memory management with tenant isolation.
"""
import logging
from typing import Sequence
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col
from app.models.entities import ChatMessage, ChatMessageCreate, ChatSession, ChatSessionCreate
logger = logging.getLogger(__name__)


class MemoryService:
    """
    [AC-AISVC-13] Memory service for session-based conversation history.
    All operations are scoped by (tenant_id, session_id) for multi-tenant isolation.

    The service never commits: writes are followed by ``flush()`` only, so the
    transaction boundary is owned by whoever created the ``AsyncSession``.
    """

    def __init__(self, session: AsyncSession):
        """
        Args:
            session: Async SQLAlchemy session used for every query and write.
        """
        self._session = session

    async def get_or_create_session(
        self,
        tenant_id: str,
        session_id: str,
        channel_type: str | None = None,
        metadata: dict | None = None,
    ) -> ChatSession:
        """
        [AC-AISVC-13] Get existing session or create a new one.
        Ensures tenant isolation by querying with tenant_id.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session identifier; expected to match at most one row
                per tenant (``scalar_one_or_none`` raises if several match).
            channel_type: Stored on the session only when it is newly created.
            metadata: Stored as ``metadata_`` only when the session is newly created.

        Returns:
            The existing ChatSession, or a new one that has been added and flushed.
        """
        stmt = select(ChatSession).where(
            ChatSession.tenant_id == tenant_id,
            ChatSession.session_id == session_id,
        )
        result = await self._session.execute(stmt)
        existing_session = result.scalar_one_or_none()
        if existing_session is not None:
            logger.info(
                "[AC-AISVC-13] Found existing session: tenant=%s, session=%s",
                tenant_id,
                session_id,
            )
            return existing_session

        # NOTE(review): check-then-insert is not atomic. Two concurrent first
        # requests for the same (tenant_id, session_id) can both reach this
        # point; if the table enforces uniqueness, one flush will fail.
        # Confirm whether callers need to retry.
        new_session = ChatSession(
            tenant_id=tenant_id,
            session_id=session_id,
            channel_type=channel_type,
            metadata_=metadata,
        )
        self._session.add(new_session)
        # Emit the INSERT inside the current transaction (no commit here).
        await self._session.flush()
        logger.info(
            "[AC-AISVC-13] Created new session: tenant=%s, session=%s",
            tenant_id,
            session_id,
        )
        return new_session

    async def load_history(
        self,
        tenant_id: str,
        session_id: str,
        limit: int | None = None,
    ) -> Sequence[ChatMessage]:
        """
        [AC-AISVC-13] Load conversation history for a session.
        All queries are filtered by tenant_id to ensure isolation.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session whose messages are loaded.
            limit: When truthy, only the most recent ``limit`` messages are
                returned. ``None`` or ``0`` means "no limit".

        Returns:
            Messages in chronological (oldest-first) order by ``created_at``.
        """
        stmt = select(ChatMessage).where(
            ChatMessage.tenant_id == tenant_id,
            ChatMessage.session_id == session_id,
        )
        # NOTE(review): ordering is by created_at only. Rows sharing a timestamp
        # (e.g. inserted in one batch) have no guaranteed relative order;
        # consider a secondary key such as the primary key if the entity has one.
        if limit:
            # Take the newest `limit` rows (DESC + LIMIT) and flip them back to
            # chronological order below. Ordering ASC before LIMIT would keep
            # the *oldest* turns and drop the latest context.
            stmt = stmt.order_by(col(ChatMessage.created_at).desc()).limit(limit)
        else:
            stmt = stmt.order_by(col(ChatMessage.created_at).asc())
        result = await self._session.execute(stmt)
        messages = list(result.scalars().all())
        if limit:
            messages.reverse()
        logger.info(
            "[AC-AISVC-13] Loaded %s messages for tenant=%s, session=%s",
            len(messages),
            tenant_id,
            session_id,
        )
        return messages

    async def append_message(
        self,
        tenant_id: str,
        session_id: str,
        role: str,
        content: str,
    ) -> ChatMessage:
        """
        [AC-AISVC-13] Append a message to the session history.
        Message is scoped by tenant_id for isolation.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session the message belongs to.
            role: Message role, stored as given.
            content: Message text.

        Returns:
            The added and flushed ChatMessage.
        """
        message = ChatMessage(
            tenant_id=tenant_id,
            session_id=session_id,
            role=role,
            content=content,
        )
        self._session.add(message)
        await self._session.flush()
        logger.info(
            "[AC-AISVC-13] Appended message: tenant=%s, session=%s, role=%s",
            tenant_id,
            session_id,
            role,
        )
        return message

    async def append_messages(
        self,
        tenant_id: str,
        session_id: str,
        messages: list[dict[str, str]],
    ) -> list[ChatMessage]:
        """
        [AC-AISVC-13] Append multiple messages to the session history.
        Used for batch insertion of conversation turns.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session the messages belong to.
            messages: Dicts that must each contain ``"role"`` and ``"content"``.

        Returns:
            The added ChatMessage objects, in input order, after a single flush.

        Raises:
            KeyError: If any dict lacks ``"role"`` or ``"content"``. Because all
                objects are built before anything is added, no partial batch
                is left in the session in that case.
        """
        chat_messages = [
            ChatMessage(
                tenant_id=tenant_id,
                session_id=session_id,
                role=msg["role"],
                content=msg["content"],
            )
            for msg in messages
        ]
        self._session.add_all(chat_messages)
        await self._session.flush()
        logger.info(
            "[AC-AISVC-13] Appended %s messages for tenant=%s, session=%s",
            len(chat_messages),
            tenant_id,
            session_id,
        )
        return chat_messages

    async def clear_history(self, tenant_id: str, session_id: str) -> int:
        """
        [AC-AISVC-13] Clear all messages for a session.
        Only affects messages within the tenant's scope.

        Messages are loaded and deleted through the ORM (not a bulk DELETE),
        so any ORM-level delete handling configured on ChatMessage still runs.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session whose messages are deleted.

        Returns:
            Number of messages deleted.
        """
        stmt = select(ChatMessage).where(
            ChatMessage.tenant_id == tenant_id,
            ChatMessage.session_id == session_id,
        )
        result = await self._session.execute(stmt)
        messages = result.scalars().all()
        for message in messages:
            await self._session.delete(message)
        await self._session.flush()
        count = len(messages)
        logger.info(
            "[AC-AISVC-13] Cleared %s messages for tenant=%s, session=%s",
            count,
            tenant_id,
            session_id,
        )
        return count