171 lines
4.9 KiB
Python
171 lines
4.9 KiB
Python
"""
|
|
Memory service for AI Service.
|
|
[AC-AISVC-13] Session-based memory management with tenant isolation.
|
|
"""
|
|
|
|
import logging
|
|
from collections.abc import Sequence
|
|
|
|
from sqlalchemy import select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlmodel import col
|
|
|
|
from app.models.entities import ChatMessage, ChatSession
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class MemoryService:
    """
    [AC-AISVC-13] Memory service for session-based conversation history.

    Wraps an ``AsyncSession`` and reads/writes ``ChatSession`` and
    ``ChatMessage`` rows. Every query and insert is keyed by
    ``(tenant_id, session_id)`` for multi-tenant isolation.

    Write methods call ``flush()`` but never ``commit()``: the transaction
    is not committed by this class.
    """

    def __init__(self, session: AsyncSession):
        """Store the async database session used by all operations."""
        self._session = session

    async def get_or_create_session(
        self,
        tenant_id: str,
        session_id: str,
        channel_type: str | None = None,
        metadata: dict | None = None,
    ) -> ChatSession:
        """
        [AC-AISVC-13] Get existing session or create a new one.

        The lookup filters on both ``tenant_id`` and ``session_id``, so a
        session id belonging to another tenant is never returned.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session identifier within the tenant.
            channel_type: Stored on a newly created session only; ignored
                when the session already exists.
            metadata: Stored as ``metadata_`` on a newly created session
                only; ignored when the session already exists.

        Returns:
            The existing ``ChatSession`` or the newly added (flushed, not
            committed) one.

        Raises:
            sqlalchemy.exc.MultipleResultsFound: If more than one row
                matches ``(tenant_id, session_id)``.
        """
        stmt = select(ChatSession).where(
            ChatSession.tenant_id == tenant_id,
            ChatSession.session_id == session_id,
        )
        result = await self._session.execute(stmt)
        existing_session = result.scalar_one_or_none()

        # Explicit None check: do not depend on the model's truthiness.
        if existing_session is not None:
            logger.info(
                "[AC-AISVC-13] Found existing session: tenant=%s, session=%s",
                tenant_id,
                session_id,
            )
            return existing_session

        # NOTE(review): select-then-insert is not atomic. Two concurrent
        # callers can both reach this point; whether the second insert fails
        # depends on a unique constraint that is not visible here - confirm.
        new_session = ChatSession(
            tenant_id=tenant_id,
            session_id=session_id,
            channel_type=channel_type,
            metadata_=metadata,
        )
        self._session.add(new_session)
        await self._session.flush()

        logger.info(
            "[AC-AISVC-13] Created new session: tenant=%s, session=%s",
            tenant_id,
            session_id,
        )
        return new_session

    async def load_history(
        self,
        tenant_id: str,
        session_id: str,
        limit: int | None = None,
    ) -> Sequence[ChatMessage]:
        """
        [AC-AISVC-13] Load conversation history for a session.

        All queries are filtered by ``tenant_id`` to ensure isolation.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session whose messages are loaded.
            limit: Maximum number of messages to return. ``None`` (the
                default) returns the full history. When set, the *most
                recent* ``limit`` messages are returned; ``0`` yields an
                empty result.

        Returns:
            Messages ordered by ``created_at`` ascending (oldest first).
        """
        base_stmt = select(ChatMessage).where(
            ChatMessage.tenant_id == tenant_id,
            ChatMessage.session_id == session_id,
        )
        created_at = col(ChatMessage.created_at)

        if limit is None:
            result = await self._session.execute(
                base_stmt.order_by(created_at.asc())
            )
            messages = result.scalars().all()
        else:
            # Applying LIMIT to an ascending query would keep the oldest
            # rows. Select newest-first so the limit keeps the latest
            # messages, then restore chronological order for the caller.
            result = await self._session.execute(
                base_stmt.order_by(created_at.desc()).limit(limit)
            )
            messages = list(reversed(result.scalars().all()))

        logger.info(
            "[AC-AISVC-13] Loaded %d messages for tenant=%s, session=%s",
            len(messages),
            tenant_id,
            session_id,
        )
        return messages

    async def append_message(
        self,
        tenant_id: str,
        session_id: str,
        role: str,
        content: str,
    ) -> ChatMessage:
        """
        [AC-AISVC-13] Append a message to the session history.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session the message belongs to.
            role: Stored as the message's ``role``.
            content: Stored as the message's ``content``.

        Returns:
            The added ``ChatMessage`` (flushed, not committed).
        """
        message = ChatMessage(
            tenant_id=tenant_id,
            session_id=session_id,
            role=role,
            content=content,
        )
        self._session.add(message)
        await self._session.flush()

        logger.info(
            "[AC-AISVC-13] Appended message: tenant=%s, session=%s, role=%s",
            tenant_id,
            session_id,
            role,
        )
        return message

    async def append_messages(
        self,
        tenant_id: str,
        session_id: str,
        messages: list[dict[str, str]],
    ) -> list[ChatMessage]:
        """
        [AC-AISVC-13] Append multiple messages to the session history.

        Used for batch insertion of conversation turns; all messages are
        added to the session and flushed once.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session the messages belong to.
            messages: Dicts that must each contain ``"role"`` and
                ``"content"`` keys.

        Returns:
            The added ``ChatMessage`` objects, in input order.

        Raises:
            KeyError: If an item lacks ``"role"`` or ``"content"``. This is
                raised before anything is added to the session.
        """
        chat_messages = [
            ChatMessage(
                tenant_id=tenant_id,
                session_id=session_id,
                role=msg["role"],
                content=msg["content"],
            )
            for msg in messages
        ]
        self._session.add_all(chat_messages)
        await self._session.flush()

        logger.info(
            "[AC-AISVC-13] Appended %d messages for tenant=%s, session=%s",
            len(chat_messages),
            tenant_id,
            session_id,
        )
        return chat_messages

    async def clear_history(self, tenant_id: str, session_id: str) -> int:
        """
        [AC-AISVC-13] Clear all messages for a session.

        Only messages matching both ``tenant_id`` and ``session_id`` are
        affected. Rows are loaded and deleted through the session one at a
        time rather than with a bulk DELETE statement.

        Args:
            tenant_id: Tenant that owns the session.
            session_id: Session whose messages are removed.

        Returns:
            Number of messages marked for deletion and flushed.
        """
        stmt = select(ChatMessage).where(
            ChatMessage.tenant_id == tenant_id,
            ChatMessage.session_id == session_id,
        )
        result = await self._session.execute(stmt)
        messages = result.scalars().all()

        for message in messages:
            await self._session.delete(message)
        await self._session.flush()

        count = len(messages)
        logger.info(
            "[AC-AISVC-13] Cleared %d messages for tenant=%s, session=%s",
            count,
            tenant_id,
            session_id,
        )
        return count
|