# File: ai-robot-core/ai-service/app/services/kb.py
# 303 lines, 8.8 KiB, Python

"""
Knowledge Base service for AI Service.
[AC-ASA-01, AC-ASA-02, AC-ASA-08] KB management with document upload, indexing, and listing.
"""
import logging
import os
import uuid
from collections.abc import Sequence
from datetime import datetime
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col
from app.models.entities import (
Document,
DocumentStatus,
IndexJob,
IndexJobStatus,
KnowledgeBase,
)
logger = logging.getLogger(__name__)
class KBService:
    """
    [AC-ASA-01, AC-ASA-02, AC-ASA-08] Knowledge Base service.

    Handles document upload, indexing jobs, and document listing. Every query
    is filtered by ``tenant_id``. The service only flushes the session; it
    never commits, so the caller owns the transaction boundary.
    """

    def __init__(self, session: AsyncSession, upload_dir: str = "./uploads"):
        """
        Bind the service to a session and make sure the upload directory exists.

        Args:
            session: Async SQLAlchemy session used for every query.
            upload_dir: Directory uploaded files are written to. It is created
                (including parents) when missing.
        """
        self._session = session
        self._upload_dir = upload_dir
        os.makedirs(upload_dir, exist_ok=True)

    @staticmethod
    def _parse_uuid(value: str | uuid.UUID | None) -> uuid.UUID | None:
        """Return *value* as a UUID, or None when it is not a well-formed UUID."""
        if isinstance(value, uuid.UUID):
            return value
        try:
            return uuid.UUID(value)
        except (ValueError, TypeError, AttributeError):
            # ValueError: malformed string; TypeError/AttributeError: None or
            # a non-string object.
            return None

    @staticmethod
    def _safe_path_component(value: str) -> str:
        """Replace path separators so *value* stays a single path component."""
        return str(value).replace("/", "_").replace("\\", "_")

    @staticmethod
    def _remove_file_quietly(path: str) -> None:
        """Best-effort removal of *path*; never raises for filesystem errors."""
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError:
            logger.warning("Failed to remove orphaned upload: %s", path, exc_info=True)

    async def get_or_create_kb(
        self,
        tenant_id: str,
        kb_id: str | None = None,
        name: str = "Default KB",
    ) -> KnowledgeBase:
        """
        Get existing KB or create default one.

        Lookup order: the KB matching ``kb_id`` (when it is a well-formed UUID
        and belongs to the tenant), then the tenant's oldest KB, and finally a
        newly created KB named ``name``.

        Args:
            tenant_id: Tenant that owns the knowledge base.
            kb_id: Optional KB id; malformed or unknown ids fall through to the
                tenant's existing/default KB.
            name: Name given to the KB when one has to be created.

        Returns:
            The resolved or newly created (flushed, not committed) KB.
        """
        kb_uuid = self._parse_uuid(kb_id) if kb_id else None
        if kb_uuid is not None:
            stmt = select(KnowledgeBase).where(
                KnowledgeBase.tenant_id == tenant_id,
                KnowledgeBase.id == kb_uuid,
            )
            result = await self._session.execute(stmt)
            existing_kb = result.scalar_one_or_none()
            if existing_kb:
                return existing_kb

        # Explicit ordering makes the fallback deterministic: without it the
        # database may return a different row on each call.
        stmt = (
            select(KnowledgeBase)
            .where(KnowledgeBase.tenant_id == tenant_id)
            .order_by(col(KnowledgeBase.created_at).asc())
            .limit(1)
        )
        result = await self._session.execute(stmt)
        existing_kb = result.scalars().first()
        if existing_kb:
            return existing_kb

        new_kb = KnowledgeBase(
            tenant_id=tenant_id,
            name=name,
        )
        self._session.add(new_kb)
        await self._session.flush()
        logger.info(
            "[AC-ASA-01] Created knowledge base: tenant=%s, kb_id=%s",
            tenant_id,
            new_kb.id,
        )
        return new_kb

    async def upload_document(
        self,
        tenant_id: str,
        kb_id: str,
        file_name: str,
        file_content: bytes,
        file_type: str | None = None,
        metadata: dict | None = None,
    ) -> tuple[Document, IndexJob]:
        """
        [AC-ASA-01] Upload document and create indexing job.

        The content is written to ``<upload_dir>/<tenant_id>_<doc_id><ext>``
        and a pending ``Document`` plus a pending ``IndexJob`` are flushed.
        If writing or flushing fails, the file is removed again before the
        error is re-raised.

        Args:
            tenant_id: Tenant that owns the document.
            kb_id: Knowledge base the document is attached to.
            file_name: Original file name; stored in the database and used
                only for its extension on disk.
            file_content: Raw bytes of the uploaded file.
            file_type: Optional type string stored on the document.
            metadata: Optional metadata stored as ``doc_metadata``.

        Returns:
            ``(document, job)`` as flushed (not committed) entities.
        """
        doc_id = uuid.uuid4()
        # Safe filename handling: the stored name is built from the UUID and
        # the original file name is kept only in the database. Separators are
        # neutralised so neither tenant_id nor the extension can point outside
        # the upload directory.
        file_ext = os.path.splitext(file_name)[1] if file_name else ""
        safe_file_name = self._safe_path_component(f"{tenant_id}_{doc_id}{file_ext}")
        file_path = os.path.join(self._upload_dir, safe_file_name)

        document = Document(
            id=doc_id,
            tenant_id=tenant_id,
            kb_id=kb_id,
            file_name=file_name,
            file_path=file_path,
            file_size=len(file_content),
            file_type=file_type,
            status=DocumentStatus.PENDING.value,
            doc_metadata=metadata,
        )
        job = IndexJob(
            tenant_id=tenant_id,
            doc_id=doc_id,
            status=IndexJobStatus.PENDING.value,
            progress=0,
        )

        try:
            with open(file_path, "wb") as f:
                f.write(file_content)
            self._session.add(document)
            self._session.add(job)
            await self._session.flush()
        except Exception:
            # Do not leave a (possibly partial) file behind without a row.
            self._remove_file_quietly(file_path)
            raise

        logger.info(
            "[AC-ASA-01] Uploaded document: tenant=%s, doc_id=%s, "
            "file_name=%s, size=%s",
            tenant_id,
            doc_id,
            file_name,
            len(file_content),
        )
        return document, job

    async def list_documents(
        self,
        tenant_id: str,
        kb_id: str | None = None,
        status: str | None = None,
        page: int = 1,
        page_size: int = 20,
    ) -> tuple[Sequence[Document], int]:
        """
        [AC-ASA-08] List documents with filtering and pagination.

        Args:
            tenant_id: Tenant whose documents are listed.
            kb_id: Optional filter on the document's knowledge base.
            status: Optional filter on the document status.
            page: 1-based page number; values below 1 are treated as 1.
            page_size: Page length; negative values are treated as 0.

        Returns:
            ``(documents, total)`` where ``documents`` is the requested page,
            newest first, and ``total`` counts all rows matching the filters.
        """
        # A negative OFFSET/LIMIT would be rejected by the database.
        page = max(page, 1)
        page_size = max(page_size, 0)

        stmt = select(Document).where(Document.tenant_id == tenant_id)
        if kb_id:
            stmt = stmt.where(Document.kb_id == kb_id)
        if status:
            stmt = stmt.where(Document.status == status)

        # Count over the filtered (but not yet paginated) statement.
        count_stmt = select(func.count()).select_from(stmt.subquery())
        total_result = await self._session.execute(count_stmt)
        total = total_result.scalar() or 0

        stmt = stmt.order_by(col(Document.created_at).desc())
        stmt = stmt.offset((page - 1) * page_size).limit(page_size)
        result = await self._session.execute(stmt)
        documents = result.scalars().all()

        logger.info(
            "[AC-ASA-08] Listed documents: tenant=%s, kb_id=%s, status=%s, total=%s",
            tenant_id,
            kb_id,
            status,
            total,
        )
        return documents, total

    async def get_document(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> Document | None:
        """
        Get document by ID.

        Returns:
            The tenant's document, or None when ``doc_id`` is malformed or no
            such document exists for the tenant.
        """
        doc_uuid = self._parse_uuid(doc_id)
        if doc_uuid is None:
            return None
        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.id == doc_uuid,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_index_job(
        self,
        tenant_id: str,
        job_id: str,
    ) -> IndexJob | None:
        """
        [AC-ASA-02] Get index job status.

        Returns:
            The tenant's index job, or None when ``job_id`` is malformed or no
            such job exists for the tenant.
        """
        job_uuid = self._parse_uuid(job_id)
        if job_uuid is None:
            return None
        stmt = select(IndexJob).where(
            IndexJob.tenant_id == tenant_id,
            IndexJob.id == job_uuid,
        )
        result = await self._session.execute(stmt)
        job = result.scalar_one_or_none()
        if job:
            logger.info(
                "[AC-ASA-02] Got job status: tenant=%s, job_id=%s, status=%s, progress=%s",
                tenant_id,
                job_id,
                job.status,
                job.progress,
            )
        return job

    async def get_index_job_by_doc(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> IndexJob | None:
        """
        Get the most recently created index job of a document.

        Returns:
            The newest job for the document, or None when ``doc_id`` is
            malformed or the document has no job for this tenant.
        """
        doc_uuid = self._parse_uuid(doc_id)
        if doc_uuid is None:
            return None
        # LIMIT 1 + first(): a document may have several jobs, in which case
        # scalar_one_or_none() would raise MultipleResultsFound.
        stmt = (
            select(IndexJob)
            .where(
                IndexJob.tenant_id == tenant_id,
                IndexJob.doc_id == doc_uuid,
            )
            .order_by(col(IndexJob.created_at).desc())
            .limit(1)
        )
        result = await self._session.execute(stmt)
        return result.scalars().first()

    async def update_job_status(
        self,
        tenant_id: str,
        job_id: str,
        status: str,
        progress: int | None = None,
        error_msg: str | None = None,
    ) -> IndexJob | None:
        """
        Update index job status and mirror it onto the job's document.

        Args:
            tenant_id: Tenant that owns the job.
            job_id: Id of the job to update.
            status: New status, written to both the job and its document.
            progress: New progress value; left untouched when None.
            error_msg: New error message for the job and its document; left
                untouched when None.

        Returns:
            The updated job, or None when ``job_id`` is malformed or unknown.
        """
        job_uuid = self._parse_uuid(job_id)
        if job_uuid is None:
            return None
        stmt = select(IndexJob).where(
            IndexJob.tenant_id == tenant_id,
            IndexJob.id == job_uuid,
        )
        result = await self._session.execute(stmt)
        job = result.scalar_one_or_none()
        if job is None:
            return None

        # NOTE(review): datetime.utcnow() is deprecated since Python 3.12; it
        # is kept because the timezone handling of these columns is not
        # visible from this module.
        now = datetime.utcnow()
        job.status = status
        job.updated_at = now
        if progress is not None:
            job.progress = progress
        if error_msg is not None:
            job.error_msg = error_msg

        if job.doc_id:
            doc_stmt = select(Document).where(
                Document.tenant_id == tenant_id,
                Document.id == job.doc_id,
            )
            doc_result = await self._session.execute(doc_stmt)
            doc = doc_result.scalar_one_or_none()
            if doc:
                doc.status = status
                doc.updated_at = now
                # Same rule as for the job so both records stay in step.
                if error_msg is not None:
                    doc.error_msg = error_msg

        await self._session.flush()
        return job

    async def delete_document(
        self,
        tenant_id: str,
        doc_id: str,
    ) -> bool:
        """
        Delete document and associated files.

        The row deletion is flushed before the file is removed, so a database
        error leaves the stored file in place.

        Returns:
            True when the document was deleted, False when ``doc_id`` is
            malformed or no such document exists for the tenant.
        """
        doc_uuid = self._parse_uuid(doc_id)
        if doc_uuid is None:
            return False
        stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.id == doc_uuid,
        )
        result = await self._session.execute(stmt)
        document = result.scalar_one_or_none()
        if not document:
            return False

        file_path = document.file_path
        await self._session.delete(document)
        await self._session.flush()

        if file_path:
            try:
                os.remove(file_path)
            except FileNotFoundError:
                # Already gone; nothing left to clean up.
                pass

        logger.info("[AC-ASA-08] Deleted document: tenant=%s, doc_id=%s", tenant_id, doc_id)
        return True

    async def list_knowledge_bases(
        self,
        tenant_id: str,
    ) -> Sequence[KnowledgeBase]:
        """
        List all knowledge bases for a tenant, newest first.
        """
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.tenant_id == tenant_id
        ).order_by(col(KnowledgeBase.created_at).desc())
        result = await self._session.execute(stmt)
        return result.scalars().all()