"""
Knowledge Base service for AI Service.
[AC-ASA-01, AC-ASA-02, AC-ASA-08] KB management with document upload, indexing, and listing.
"""

import contextlib
import logging
import os
import uuid
from datetime import datetime
from typing import Sequence

from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col

from app.models.entities import (
    Document,
    DocumentStatus,
    IndexJob,
    IndexJobStatus,
    KnowledgeBase,
)

logger = logging.getLogger(__name__)


class KBService:
    """
    [AC-ASA-01, AC-ASA-02, AC-ASA-08] Knowledge Base service.

    Handles document upload, indexing jobs, and document listing.
    Every query is scoped by ``tenant_id``. Methods only ``flush`` the
    session; committing the transaction is left to the caller.
    """

    def __init__(self, session: AsyncSession, upload_dir: str = "./uploads"):
        """
        Bind the service to a session and make sure the upload directory exists.

        Args:
            session: Async SQLAlchemy session used for all queries.
            upload_dir: Directory where uploaded files are written.
        """
        self._session = session
        self._upload_dir = upload_dir
        os.makedirs(upload_dir, exist_ok=True)

    @staticmethod
    def _safe_path_component(value: str) -> str:
        """
        Reduce ``value`` to text that is safe to embed in a single file name.

        Directory parts (both ``/`` and ``\\`` styles) and NUL bytes are
        dropped so caller-supplied text cannot redirect a write outside the
        upload directory. Values without separators are returned unchanged.
        """
        normalized = str(value).replace("\x00", "").replace("\\", "/")
        return os.path.basename(normalized)

    async def get_or_create_kb(
        self,
        tenant_id: str,
        kb_id: str | None = None,
        name: str = "Default KB",
    ) -> KnowledgeBase:
        """
        Get existing KB or create default one.

        Lookup order:
            1. The KB with ``kb_id`` for this tenant, when ``kb_id`` is a
               valid UUID string and such a KB exists.
            2. Otherwise the tenant's earliest-created KB.
            3. Otherwise a new KB named ``name`` is added and flushed.

        Args:
            tenant_id: Tenant that owns the knowledge base.
            kb_id: Optional KB id as a UUID string; a malformed value is
                ignored rather than raising.
            name: Name given to a newly created KB.

        Returns:
            The existing or newly created knowledge base.
        """
        kb_uuid: uuid.UUID | None = None
        if kb_id:
            # Only the parse is guarded, so errors from the query itself
            # are never silently swallowed.
            try:
                kb_uuid = uuid.UUID(kb_id)
            except ValueError:
                kb_uuid = None

        if kb_uuid is not None:
            by_id_stmt = select(KnowledgeBase).where(
                KnowledgeBase.tenant_id == tenant_id,
                KnowledgeBase.id == kb_uuid,
            )
            by_id_result = await self._session.execute(by_id_stmt)
            requested_kb = by_id_result.scalar_one_or_none()
            if requested_kb is not None:
                return requested_kb

        # Order before LIMIT so the same "default" KB is chosen every time
        # when a tenant has several.
        fallback_stmt = (
            select(KnowledgeBase)
            .where(KnowledgeBase.tenant_id == tenant_id)
            .order_by(col(KnowledgeBase.created_at).asc())
            .limit(1)
        )
        fallback_result = await self._session.execute(fallback_stmt)
        default_kb = fallback_result.scalar_one_or_none()
        if default_kb is not None:
            return default_kb

        new_kb = KnowledgeBase(
            tenant_id=tenant_id,
            name=name,
        )
        self._session.add(new_kb)
        await self._session.flush()

        logger.info(
            "[AC-ASA-01] Created knowledge base: tenant=%s, kb_id=%s",
            tenant_id,
            new_kb.id,
        )
        return new_kb

    async def upload_document(
        self,
        tenant_id: str,
        kb_id: str,
        file_name: str,
        file_content: bytes,
        file_type: str | None = None,
    ) -> tuple[Document, IndexJob]:
        """
        [AC-ASA-01] Upload document and create indexing job.

        The content is written to ``<upload_dir>/<tenant>_<doc_id>_<file_name>``
        and a pending ``Document`` plus a pending ``IndexJob`` are added and
        flushed.

        Args:
            tenant_id: Tenant that owns the document.
            kb_id: Knowledge base the document belongs to.
            file_name: Original file name; stored as-is on the record, but
                only its final path component is used on disk.
            file_content: Raw file bytes.
            file_type: Optional type label stored on the record.

        Returns:
            The created ``(document, job)`` pair.
        """
        doc_id = uuid.uuid4()

        # file_name / tenant_id come from the caller: strip directory parts
        # so e.g. "../../x" cannot escape the upload directory.
        storage_name = "_".join(
            (
                self._safe_path_component(tenant_id),
                str(doc_id),
                self._safe_path_component(file_name),
            )
        )
        file_path = os.path.join(self._upload_dir, storage_name)

        with open(file_path, "wb") as f:
            f.write(file_content)

        document = Document(
            id=doc_id,
            tenant_id=tenant_id,
            kb_id=kb_id,
            file_name=file_name,
            file_path=file_path,
            file_size=len(file_content),
            file_type=file_type,
            status=DocumentStatus.PENDING.value,
        )
        job = IndexJob(
            tenant_id=tenant_id,
            doc_id=doc_id,
            status=IndexJobStatus.PENDING.value,
            progress=0,
        )
        self._session.add(document)
        self._session.add(job)

        try:
            await self._session.flush()
        except Exception:
            # The rows were not persisted; do not leave an orphaned file.
            with contextlib.suppress(OSError):
                os.remove(file_path)
            raise

        logger.info(
            "[AC-ASA-01] Uploaded document: tenant=%s, doc_id=%s, "
            "file_name=%s, size=%s",
            tenant_id,
            doc_id,
            file_name,
            len(file_content),
        )
        return document, job

    async def list_documents(
        self,
        tenant_id: str,
        kb_id: str | None = None,
        status: str | None = None,
        page: int = 1,
        page_size: int = 20,
    ) -> tuple[Sequence[Document], int]:
        """
        [AC-ASA-08] List documents with filtering and pagination.

        Args:
            tenant_id: Tenant whose documents are listed.
            kb_id: Optional knowledge-base filter.
            status: Optional status filter.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; negative values are treated as 0.

        Returns:
            ``(documents, total)`` where ``documents`` is the requested page,
            newest first, and ``total`` counts all rows matching the filters.
        """
        # A negative OFFSET/LIMIT is a database error; clamp instead.
        page = max(page, 1)
        page_size = max(page_size, 0)

        filtered = select(Document).where(Document.tenant_id == tenant_id)
        if kb_id:
            filtered = filtered.where(Document.kb_id == kb_id)
        if status:
            filtered = filtered.where(Document.status == status)

        count_stmt = select(func.count()).select_from(filtered.subquery())
        total_result = await self._session.execute(count_stmt)
        total = total_result.scalar() or 0

        page_stmt = (
            filtered.order_by(col(Document.created_at).desc())
            .offset((page - 1) * page_size)
            .limit(page_size)
        )
        page_result = await self._session.execute(page_stmt)
        documents = page_result.scalars().all()

        logger.info(
            "[AC-ASA-08] Listed documents: tenant=%s, "
            "kb_id=%s, status=%s, total=%s",
            tenant_id,
            kb_id,
            status,
            total,
        )
        return documents, total

    async def get_document(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> Document | None:
        """
        Get document by ID.

        Args:
            tenant_id: Tenant that owns the document.
            doc_id: Document id as a UUID string.

        Returns:
            The document, or ``None`` when no row matches.

        Raises:
            ValueError: If ``doc_id`` is not a valid UUID string.
        """
        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.id == uuid.UUID(doc_id),
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_index_job(
        self,
        tenant_id: str,
        job_id: str,
    ) -> IndexJob | None:
        """
        [AC-ASA-02] Get index job status.

        Args:
            tenant_id: Tenant that owns the job.
            job_id: Job id as a UUID string.

        Returns:
            The job, or ``None`` when no row matches.

        Raises:
            ValueError: If ``job_id`` is not a valid UUID string.
        """
        stmt = select(IndexJob).where(
            IndexJob.tenant_id == tenant_id,
            IndexJob.id == uuid.UUID(job_id),
        )
        result = await self._session.execute(stmt)
        job = result.scalar_one_or_none()
        if job:
            logger.info(
                "[AC-ASA-02] Got job status: tenant=%s, "
                "job_id=%s, status=%s, progress=%s",
                tenant_id,
                job_id,
                job.status,
                job.progress,
            )
        return job

    async def get_index_job_by_doc(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> IndexJob | None:
        """
        Get index job by document ID.

        When a document has several jobs, the most recently created one is
        returned.

        Args:
            tenant_id: Tenant that owns the job.
            doc_id: Document id as a UUID string.

        Returns:
            The latest job for the document, or ``None`` when there is none.

        Raises:
            ValueError: If ``doc_id`` is not a valid UUID string.
        """
        stmt = (
            select(IndexJob)
            .where(
                IndexJob.tenant_id == tenant_id,
                IndexJob.doc_id == uuid.UUID(doc_id),
            )
            .order_by(col(IndexJob.created_at).desc())
            # Without LIMIT, scalar_one_or_none() raises MultipleResultsFound
            # as soon as a document has more than one job.
            .limit(1)
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def update_job_status(
        self,
        tenant_id: str,
        job_id: str,
        status: str,
        progress: int | None = None,
        error_msg: str | None = None,
    ) -> IndexJob | None:
        """
        Update index job status.

        The job's status is also copied onto the document it references,
        when that document exists.

        Args:
            tenant_id: Tenant that owns the job.
            job_id: Job id as a UUID string.
            status: New status value for the job and its document.
            progress: New progress value; left untouched when ``None``.
            error_msg: Error message; left untouched on the job when
                ``None`` and on the document when empty.

        Returns:
            The updated job, or ``None`` when no job matches.

        Raises:
            ValueError: If ``job_id`` is not a valid UUID string.
        """
        stmt = select(IndexJob).where(
            IndexJob.tenant_id == tenant_id,
            IndexJob.id == uuid.UUID(job_id),
        )
        result = await self._session.execute(stmt)
        job = result.scalar_one_or_none()
        if not job:
            return job

        job.status = status
        job.updated_at = datetime.utcnow()
        if progress is not None:
            job.progress = progress
        if error_msg is not None:
            job.error_msg = error_msg
        await self._session.flush()

        if job.doc_id:
            doc_stmt = select(Document).where(
                Document.tenant_id == tenant_id,
                Document.id == job.doc_id,
            )
            doc_result = await self._session.execute(doc_stmt)
            doc = doc_result.scalar_one_or_none()
            if doc:
                doc.status = status
                doc.updated_at = datetime.utcnow()
                if error_msg:
                    doc.error_msg = error_msg
                await self._session.flush()

        return job

    async def delete_document(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> bool:
        """
        Delete document and associated files.

        Args:
            tenant_id: Tenant that owns the document.
            doc_id: Document id as a UUID string.

        Returns:
            ``True`` when a document was deleted, ``False`` when none matched.

        Raises:
            ValueError: If ``doc_id`` is not a valid UUID string.
        """
        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.id == uuid.UUID(doc_id),
        )
        result = await self._session.execute(stmt)
        document = result.scalar_one_or_none()
        if not document:
            return False

        file_path = document.file_path

        # Remove the row first: if the flush fails, the file is still there
        # and the record does not point at a missing file.
        await self._session.delete(document)
        await self._session.flush()

        if file_path:
            # EAFP: the file may already be gone; that is not an error.
            with contextlib.suppress(FileNotFoundError):
                os.remove(file_path)

        logger.info(
            "[AC-ASA-08] Deleted document: tenant=%s, doc_id=%s",
            tenant_id,
            doc_id,
        )
        return True

    async def list_knowledge_bases(
        self,
        tenant_id: str,
    ) -> Sequence[KnowledgeBase]:
        """
        List all knowledge bases for a tenant.

        Args:
            tenant_id: Tenant whose knowledge bases are listed.

        Returns:
            The tenant's knowledge bases, newest first.
        """
        stmt = (
            select(KnowledgeBase)
            .where(KnowledgeBase.tenant_id == tenant_id)
            .order_by(col(KnowledgeBase.created_at).desc())
        )
        result = await self._session.execute(stmt)
        return result.scalars().all()