279 lines
7.9 KiB
Python
279 lines
7.9 KiB
Python
"""
|
|
Knowledge Base service for AI Service.
|
|
[AC-ASA-01, AC-ASA-02, AC-ASA-08] KB management with document upload, indexing, and listing.
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
import uuid
|
|
from datetime import datetime
|
|
from typing import Sequence
|
|
|
|
from sqlalchemy import select, func
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlmodel import col
|
|
|
|
from app.models.entities import (
|
|
Document,
|
|
DocumentStatus,
|
|
IndexJob,
|
|
IndexJobStatus,
|
|
KnowledgeBase,
|
|
)
|
|
|
|
# Module-level logger named after this module; KBService writes its log messages through it.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class KBService:
    """
    [AC-ASA-01, AC-ASA-02, AC-ASA-08] Knowledge Base service.

    Handles document upload, indexing jobs, and document listing. Every query
    issued here is filtered by ``tenant_id``. Methods only ``flush()`` the
    session and never commit, so the caller owns the transaction boundary.
    """

    def __init__(self, session: AsyncSession, upload_dir: str = "./uploads"):
        """
        Bind the service to a session and make sure the upload directory exists.

        Args:
            session: Async SQLAlchemy session used for all queries.
            upload_dir: Directory uploaded files are written to; created
                (including parents) when missing.
        """
        self._session = session
        self._upload_dir = upload_dir
        os.makedirs(upload_dir, exist_ok=True)

    @staticmethod
    def _safe_file_name(file_name: str) -> str:
        """
        Reduce a client-supplied file name to its final path component.

        Backslashes are normalised to ``/`` first so Windows-style paths are
        stripped on POSIX hosts as well. This keeps ``../`` components in an
        upload name from steering the write outside the upload directory.
        """
        return os.path.basename(file_name.replace("\\", "/"))

    @staticmethod
    def _page_window(page: int, page_size: int) -> tuple[int, int]:
        """
        Translate 1-based ``page``/``page_size`` into ``(offset, limit)``.

        ``page`` below 1 is treated as the first page and a negative
        ``page_size`` as 0, so the resulting OFFSET/LIMIT are never negative.
        """
        limit = max(page_size, 0)
        offset = (max(page, 1) - 1) * limit
        return offset, limit

    async def get_or_create_kb(
        self,
        tenant_id: str,
        kb_id: str | None = None,
        name: str = "Default KB",
    ) -> KnowledgeBase:
        """
        Get existing KB or create default one.

        Lookup order:
            1. If ``kb_id`` is given, the tenant's KB with that id.
            2. Otherwise (or when that id is not found) any one KB of the
               tenant; the query has no ORDER BY, so which one is up to the
               database.
            3. Otherwise a new KB named ``name`` is added and flushed.

        Args:
            tenant_id: Tenant the KB must belong to.
            kb_id: Optional KB id as a UUID string.
            name: Name used only when a new KB has to be created.

        Returns:
            The existing or newly created ``KnowledgeBase``.

        Raises:
            ValueError: If ``kb_id`` is not a valid UUID string.
        """
        if kb_id:
            stmt = select(KnowledgeBase).where(
                KnowledgeBase.tenant_id == tenant_id,
                KnowledgeBase.id == uuid.UUID(kb_id),
            )
            result = await self._session.execute(stmt)
            existing_kb = result.scalar_one_or_none()
            if existing_kb:
                return existing_kb

        # Fall back to any KB of the tenant; LIMIT 1 keeps
        # scalar_one_or_none() from raising when the tenant has several.
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.tenant_id == tenant_id,
        ).limit(1)
        result = await self._session.execute(stmt)
        existing_kb = result.scalar_one_or_none()

        if existing_kb:
            return existing_kb

        new_kb = KnowledgeBase(
            tenant_id=tenant_id,
            name=name,
        )
        self._session.add(new_kb)
        await self._session.flush()

        logger.info(f"[AC-ASA-01] Created knowledge base: tenant={tenant_id}, kb_id={new_kb.id}")
        return new_kb

    async def upload_document(
        self,
        tenant_id: str,
        kb_id: str,
        file_name: str,
        file_content: bytes,
        file_type: str | None = None,
    ) -> tuple[Document, IndexJob]:
        """
        [AC-ASA-01] Upload document and create indexing job.

        The content is written to
        ``<upload_dir>/<tenant_id>_<doc_id>_<base name of file_name>`` and a
        ``Document`` plus an ``IndexJob`` (both in PENDING status) are added
        to the session and flushed.

        Args:
            tenant_id: Owning tenant.
            kb_id: Knowledge base the document is attached to.
            file_name: Client-supplied name. Stored unchanged on the
                ``Document``; only its final path component is used on disk.
            file_content: Raw bytes of the upload.
            file_type: Optional type label stored on the document.

        Returns:
            ``(document, job)`` as added to the session.

        Raises:
            OSError: If the file cannot be written.
            Exception: Whatever ``flush()`` raises is re-raised after the
                file written for this upload has been removed again.
        """
        doc_id = uuid.uuid4()
        # file_name is untrusted input: strip directory components so values
        # such as "../../x" cannot escape the upload directory.
        stored_name = f"{tenant_id}_{doc_id}_{self._safe_file_name(file_name)}"
        file_path = os.path.join(self._upload_dir, stored_name)

        # NOTE(review): this is a synchronous write inside a coroutine; for
        # large uploads consider moving it off the event loop.
        with open(file_path, "wb") as f:
            f.write(file_content)

        document = Document(
            id=doc_id,
            tenant_id=tenant_id,
            kb_id=kb_id,
            file_name=file_name,
            file_path=file_path,
            file_size=len(file_content),
            file_type=file_type,
            status=DocumentStatus.PENDING.value,
        )
        self._session.add(document)

        job = IndexJob(
            tenant_id=tenant_id,
            doc_id=doc_id,
            status=IndexJobStatus.PENDING.value,
            progress=0,
        )
        self._session.add(job)

        try:
            await self._session.flush()
        except Exception:
            # The rows were not persisted, so do not leave an orphaned file.
            try:
                os.remove(file_path)
            except OSError:
                logger.warning(
                    f"[AC-ASA-01] Could not remove file after failed upload: {file_path}"
                )
            raise

        logger.info(
            f"[AC-ASA-01] Uploaded document: tenant={tenant_id}, doc_id={doc_id}, "
            f"file_name={file_name}, size={len(file_content)}"
        )

        return document, job

    async def list_documents(
        self,
        tenant_id: str,
        kb_id: str | None = None,
        status: str | None = None,
        page: int = 1,
        page_size: int = 20,
    ) -> tuple[Sequence[Document], int]:
        """
        [AC-ASA-08] List documents with filtering and pagination.

        Args:
            tenant_id: Tenant whose documents are listed.
            kb_id: Optional KB filter; applied only when truthy.
            status: Optional status filter; applied only when truthy.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Rows per page; negative values are treated as 0.

        Returns:
            ``(documents, total)`` where ``documents`` is the requested page
            ordered by ``created_at`` descending and ``total`` is the number
            of rows matching the filters before pagination.
        """
        stmt = select(Document).where(Document.tenant_id == tenant_id)

        if kb_id:
            stmt = stmt.where(Document.kb_id == kb_id)
        if status:
            stmt = stmt.where(Document.status == status)

        # Count over the filtered (not yet paginated) statement.
        count_stmt = select(func.count()).select_from(stmt.subquery())
        total_result = await self._session.execute(count_stmt)
        total = total_result.scalar() or 0

        offset, limit = self._page_window(page, page_size)
        stmt = stmt.order_by(col(Document.created_at).desc())
        stmt = stmt.offset(offset).limit(limit)

        result = await self._session.execute(stmt)
        documents = result.scalars().all()

        logger.info(
            f"[AC-ASA-08] Listed documents: tenant={tenant_id}, "
            f"kb_id={kb_id}, status={status}, total={total}"
        )

        return documents, total

    async def get_document(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> Document | None:
        """
        Get document by ID.

        Args:
            tenant_id: Tenant the document must belong to.
            doc_id: Document id as a UUID string.

        Returns:
            The matching ``Document`` or ``None``.

        Raises:
            ValueError: If ``doc_id`` is not a valid UUID string.
        """
        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.id == uuid.UUID(doc_id),
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_index_job(
        self,
        tenant_id: str,
        job_id: str,
    ) -> IndexJob | None:
        """
        [AC-ASA-02] Get index job status.

        Args:
            tenant_id: Tenant the job must belong to.
            job_id: Job id as a UUID string.

        Returns:
            The matching ``IndexJob`` or ``None``.

        Raises:
            ValueError: If ``job_id`` is not a valid UUID string.
        """
        stmt = select(IndexJob).where(
            IndexJob.tenant_id == tenant_id,
            IndexJob.id == uuid.UUID(job_id),
        )
        result = await self._session.execute(stmt)
        job = result.scalar_one_or_none()

        if job:
            logger.info(
                f"[AC-ASA-02] Got job status: tenant={tenant_id}, "
                f"job_id={job_id}, status={job.status}, progress={job.progress}"
            )

        return job

    async def get_index_job_by_doc(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> IndexJob | None:
        """
        Get the most recently created index job of a document.

        Args:
            tenant_id: Tenant the job must belong to.
            doc_id: Document id as a UUID string.

        Returns:
            The newest ``IndexJob`` (by ``created_at``) or ``None``.

        Raises:
            ValueError: If ``doc_id`` is not a valid UUID string.
        """
        stmt = (
            select(IndexJob)
            .where(
                IndexJob.tenant_id == tenant_id,
                IndexJob.doc_id == uuid.UUID(doc_id),
            )
            .order_by(col(IndexJob.created_at).desc())
            # Without LIMIT 1, scalar_one_or_none() raises
            # MultipleResultsFound as soon as a document has several jobs.
            .limit(1)
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def update_job_status(
        self,
        tenant_id: str,
        job_id: str,
        status: str,
        progress: int | None = None,
        error_msg: str | None = None,
    ) -> IndexJob | None:
        """
        Update index job status and mirror it onto the job's document.

        Args:
            tenant_id: Tenant the job must belong to.
            job_id: Job id as a UUID string.
            status: New status, written to the job and to its document.
            progress: New progress; left unchanged when ``None``.
            error_msg: New error message for job and document; left
                unchanged when ``None`` (an empty string overwrites both).

        Returns:
            The updated ``IndexJob``, or ``None`` when no such job exists.

        Raises:
            ValueError: If ``job_id`` is not a valid UUID string.
        """
        stmt = select(IndexJob).where(
            IndexJob.tenant_id == tenant_id,
            IndexJob.id == uuid.UUID(job_id),
        )
        result = await self._session.execute(stmt)
        job = result.scalar_one_or_none()

        if not job:
            return None

        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12; it
        # is kept because it yields a naive value and the column types are
        # not visible here -- confirm before switching to an aware datetime.
        job.status = status
        job.updated_at = datetime.utcnow()
        if progress is not None:
            job.progress = progress
        if error_msg is not None:
            job.error_msg = error_msg
        await self._session.flush()

        if job.doc_id:
            doc_stmt = select(Document).where(
                Document.tenant_id == tenant_id,
                Document.id == job.doc_id,
            )
            doc_result = await self._session.execute(doc_stmt)
            doc = doc_result.scalar_one_or_none()
            if doc:
                doc.status = status
                doc.updated_at = datetime.utcnow()
                # Same "is not None" rule as for the job above, so job and
                # document never disagree about the stored message.
                if error_msg is not None:
                    doc.error_msg = error_msg
                await self._session.flush()

        return job

    async def delete_document(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> bool:
        """
        Delete document and associated files.

        The row is deleted and flushed first; the file on disk is removed
        only afterwards, so a failing flush does not leave a row pointing at
        a file that is already gone.

        Args:
            tenant_id: Tenant the document must belong to.
            doc_id: Document id as a UUID string.

        Returns:
            ``True`` when the document existed and was deleted, ``False``
            when no matching document was found.

        Raises:
            ValueError: If ``doc_id`` is not a valid UUID string.
            OSError: If the file exists but cannot be removed.
        """
        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.id == uuid.UUID(doc_id),
        )
        result = await self._session.execute(stmt)
        document = result.scalar_one_or_none()

        if not document:
            return False

        # Read the path before the instance is deleted.
        file_path = document.file_path

        await self._session.delete(document)
        await self._session.flush()

        if file_path:
            try:
                os.remove(file_path)
            except FileNotFoundError:
                # Already gone: same outcome the original exists() check
                # allowed, without the check-then-remove race.
                pass

        logger.info(f"[AC-ASA-08] Deleted document: tenant={tenant_id}, doc_id={doc_id}")
        return True
|