304 lines
8.5 KiB
Python
304 lines
8.5 KiB
Python
"""
Knowledge Base CRUD service for AI Service.
[AC-AISVC-59~AC-AISVC-64] Multi-knowledge-base management with Qdrant Collection integration.
"""
|
|
|
|
import logging
|
|
import uuid
|
|
from collections.abc import Sequence
|
|
from datetime import datetime
|
|
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
from sqlmodel import col
|
|
|
|
from app.core.qdrant_client import get_qdrant_client
|
|
from app.models.entities import (
|
|
Document,
|
|
KBType,
|
|
KnowledgeBase,
|
|
KnowledgeBaseCreate,
|
|
KnowledgeBaseUpdate,
|
|
)
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class KnowledgeBaseService:
    """
    [AC-AISVC-59~AC-AISVC-64] Knowledge Base CRUD service.

    Handles KB creation with Qdrant Collection initialization,
    KB updates, deletion with Collection cleanup, and listing.

    Every method works through the ``AsyncSession`` passed to the
    constructor and only ever calls ``flush()`` on it; this class never
    commits or rolls back (presumably the caller owns the transaction).
    """

    def __init__(self, session: AsyncSession):
        """
        Args:
            session: Async SQLAlchemy session used for all database access.
        """
        self._session = session

    @staticmethod
    def _parse_kb_id(kb_id: str) -> uuid.UUID | None:
        """
        Convert a knowledge base ID into a ``uuid.UUID``.

        Args:
            kb_id: Knowledge base ID as a string (a ``uuid.UUID`` instance
                is accepted as-is).

        Returns:
            The parsed UUID, or None when ``kb_id`` is not a valid UUID.
        """
        if isinstance(kb_id, uuid.UUID):
            return kb_id
        try:
            return uuid.UUID(kb_id)
        except (ValueError, TypeError, AttributeError):
            # ValueError: malformed string. TypeError / AttributeError:
            # None or another non-string value. All mean "no such KB".
            return None

    async def create_knowledge_base(
        self,
        tenant_id: str,
        kb_create: KnowledgeBaseCreate,
    ) -> KnowledgeBase:
        """
        [AC-AISVC-59] Create a new knowledge base.
        Initializes corresponding Qdrant Collection.

        The new row is created enabled with a document count of zero.

        Args:
            tenant_id: Tenant identifier
            kb_create: Knowledge base creation data

        Returns:
            Created KnowledgeBase entity
        """
        kb = KnowledgeBase(
            tenant_id=tenant_id,
            name=kb_create.name,
            kb_type=kb_create.kb_type,
            description=kb_create.description,
            priority=kb_create.priority,
            is_enabled=True,
            doc_count=0,
        )
        self._session.add(kb)
        # Flush before touching Qdrant: kb.id is needed to name the collection.
        await self._session.flush()

        qdrant = await get_qdrant_client()
        await qdrant.ensure_kb_collection_exists(tenant_id, str(kb.id))

        logger.info(
            f"[AC-AISVC-59] Created knowledge base: tenant={tenant_id}, "
            f"kb_id={kb.id}, name={kb.name}, type={kb.kb_type}"
        )

        return kb

    async def get_knowledge_base(
        self,
        tenant_id: str,
        kb_id: str,
    ) -> KnowledgeBase | None:
        """
        Get a knowledge base by ID.

        Args:
            tenant_id: Tenant identifier
            kb_id: Knowledge base ID

        Returns:
            KnowledgeBase entity, or None when ``kb_id`` is not a valid UUID
            or no row matches both the tenant and the ID
        """
        # Only the ID parsing is guarded; errors raised by the query itself
        # must propagate instead of being reported as "not found".
        kb_uuid = self._parse_kb_id(kb_id)
        if kb_uuid is None:
            return None

        stmt = select(KnowledgeBase).where(
            KnowledgeBase.tenant_id == tenant_id,
            KnowledgeBase.id == kb_uuid,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def list_knowledge_bases(
        self,
        tenant_id: str,
        kb_type: str | None = None,
        is_enabled: bool | None = None,
    ) -> Sequence[KnowledgeBase]:
        """
        [AC-AISVC-60] List knowledge bases for a tenant.

        Results are ordered by priority (highest first), then by creation
        time (newest first).

        Args:
            tenant_id: Tenant identifier
            kb_type: Filter by knowledge base type (optional; an empty
                string is treated as "no filter")
            is_enabled: Filter by enabled status (optional)

        Returns:
            List of KnowledgeBase entities
        """
        stmt = select(KnowledgeBase).where(
            KnowledgeBase.tenant_id == tenant_id
        )

        if kb_type:
            stmt = stmt.where(KnowledgeBase.kb_type == kb_type)
        # Explicit None check: False is a meaningful filter value here.
        if is_enabled is not None:
            stmt = stmt.where(KnowledgeBase.is_enabled == is_enabled)

        stmt = stmt.order_by(
            col(KnowledgeBase.priority).desc(),
            col(KnowledgeBase.created_at).desc(),
        )

        result = await self._session.execute(stmt)
        return result.scalars().all()

    async def update_knowledge_base(
        self,
        tenant_id: str,
        kb_id: str,
        kb_update: KnowledgeBaseUpdate,
    ) -> KnowledgeBase | None:
        """
        [AC-AISVC-61] Update a knowledge base.

        Only fields explicitly set on ``kb_update`` are applied.

        Args:
            tenant_id: Tenant identifier
            kb_id: Knowledge base ID
            kb_update: Update data

        Returns:
            Updated KnowledgeBase entity or None if it does not exist
        """
        kb = await self.get_knowledge_base(tenant_id, kb_id)
        if not kb:
            return None

        # exclude_unset keeps fields the caller did not send untouched.
        update_data = kb_update.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(kb, key, value)

        kb.updated_at = datetime.utcnow()
        await self._session.flush()

        logger.info(
            f"[AC-AISVC-61] Updated knowledge base: tenant={tenant_id}, "
            f"kb_id={kb_id}, fields={list(update_data.keys())}"
        )

        return kb

    async def delete_knowledge_base(
        self,
        tenant_id: str,
        kb_id: str,
    ) -> bool:
        """
        [AC-AISVC-62] Delete a knowledge base.
        Also deletes associated documents and Qdrant Collection.

        Args:
            tenant_id: Tenant identifier
            kb_id: Knowledge base ID

        Returns:
            True if deleted successfully, False if the knowledge base
            does not exist
        """
        kb = await self.get_knowledge_base(tenant_id, kb_id)
        if not kb:
            return False

        # uuid.UUID() accepts several spellings of the same ID (upper case,
        # braces, no hyphens). Use the canonical form -- the one
        # create_knowledge_base passed to Qdrant -- so documents and the
        # collection are not left behind when the caller used another spelling.
        canonical_id = str(kb.id)

        doc_stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.kb_id == canonical_id,
        )
        doc_result = await self._session.execute(doc_stmt)
        documents = doc_result.scalars().all()

        for doc in documents:
            await self._session.delete(doc)

        await self._session.delete(kb)
        await self._session.flush()

        # Database rows are flushed first; the Qdrant collection goes last.
        qdrant = await get_qdrant_client()
        await qdrant.delete_kb_collection(tenant_id, canonical_id)

        logger.info(
            f"[AC-AISVC-62] Deleted knowledge base: tenant={tenant_id}, "
            f"kb_id={kb_id}, docs_deleted={len(documents)}"
        )

        return True

    async def update_doc_count(
        self,
        tenant_id: str,
        kb_id: str,
        delta: int = 1,
    ) -> bool:
        """
        Update document count for a knowledge base.

        Best-effort: failures are logged and reported as False rather
        than raised. The stored count never drops below zero.

        Args:
            tenant_id: Tenant identifier
            kb_id: Knowledge base ID
            delta: Change in document count (positive or negative)

        Returns:
            True if updated successfully, False if the knowledge base was
            not found or an error occurred
        """
        try:
            kb = await self.get_knowledge_base(tenant_id, kb_id)
            if kb:
                kb.doc_count = max(0, kb.doc_count + delta)
                kb.updated_at = datetime.utcnow()
                await self._session.flush()
                return True
        except Exception as e:
            # logger.exception keeps the same message but adds the traceback.
            logger.exception(f"Error updating doc count: {e}")
        return False

    async def recalculate_doc_counts(
        self,
        tenant_id: str,
    ) -> dict[str, int]:
        """
        Recalculate document counts for all knowledge bases of a tenant.

        Knowledge bases without any documents are reset to zero.

        Args:
            tenant_id: Tenant identifier

        Returns:
            Dictionary mapping kb_id to doc_count
        """
        count_stmt = (
            select(Document.kb_id, func.count(Document.id).label("doc_total"))
            .where(Document.tenant_id == tenant_id)
            .group_by(Document.kb_id)
        )
        result = await self._session.execute(count_stmt)
        # Unpack rows positionally. SQLAlchemy's Row is tuple-like and has its
        # own count()/index() methods, so reading a column labelled "count" as
        # `row.count` yields the method, not the value. Keys are normalised
        # with str() so they match the str(kb.id) lookups below.
        counts = {
            str(doc_kb_id): doc_total for doc_kb_id, doc_total in result.all()
        }

        kb_stmt = select(KnowledgeBase).where(KnowledgeBase.tenant_id == tenant_id)
        kb_result = await self._session.execute(kb_stmt)
        knowledge_bases = kb_result.scalars().all()

        for kb in knowledge_bases:
            kb.doc_count = counts.get(str(kb.id), 0)
            kb.updated_at = datetime.utcnow()

        await self._session.flush()

        return {str(kb.id): kb.doc_count for kb in knowledge_bases}

    async def get_or_create_default_kb(
        self,
        tenant_id: str,
    ) -> KnowledgeBase:
        """
        Get or create the default knowledge base for a tenant.
        This is used for backward compatibility with existing data.

        If the tenant already has knowledge bases, the oldest one (by
        ``created_at``) is returned; otherwise a new general-purpose
        "Default Knowledge Base" is created.

        Args:
            tenant_id: Tenant identifier

        Returns:
            Default KnowledgeBase entity
        """
        # Without ORDER BY, LIMIT 1 may return a different row on each call
        # when the tenant has several knowledge bases.
        stmt = (
            select(KnowledgeBase)
            .where(KnowledgeBase.tenant_id == tenant_id)
            .order_by(col(KnowledgeBase.created_at).asc())
            .limit(1)
        )
        result = await self._session.execute(stmt)
        existing_kb = result.scalar_one_or_none()

        if existing_kb:
            return existing_kb

        kb_create = KnowledgeBaseCreate(
            name="Default Knowledge Base",
            kb_type=KBType.GENERAL.value,
            description="Default knowledge base for backward compatibility",
            priority=0,
        )
        return await self.create_knowledge_base(tenant_id, kb_create)
|