ai-robot-core/ai-service/app/services/knowledge_base_service.py

304 lines
8.5 KiB
Python

"""
Knowledge Base CRUD service for AI Service.
[AC-AISVC-59~AC-AISVC-64] Multi-knowledge-base management with Qdrant Collection integration.
"""
import logging
import uuid
from collections.abc import Sequence
from datetime import datetime, timezone

from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlmodel import col

from app.core.qdrant_client import get_qdrant_client
from app.models.entities import (
    Document,
    KBType,
    KnowledgeBase,
    KnowledgeBaseCreate,
    KnowledgeBaseUpdate,
)
# Module-level logger, named after this module so records can be filtered per module.
logger = logging.getLogger(name=__name__)
class KnowledgeBaseService:
    """
    [AC-AISVC-59~AC-AISVC-64] Knowledge Base CRUD service.

    Handles KB creation with Qdrant Collection initialization,
    KB updates, deletion with Collection cleanup, and listing.

    Every mutating method only calls ``flush()`` on the injected session;
    this class never commits, so transaction control stays with the caller.
    """

    def __init__(self, session: AsyncSession):
        """
        Args:
            session: Async SQLAlchemy session used for all queries and writes.
        """
        self._session = session

    @staticmethod
    def _utcnow() -> datetime:
        """
        Return the current UTC time as a naive datetime.

        Drop-in replacement for ``datetime.utcnow()`` (deprecated since
        Python 3.12). The tzinfo is stripped so values stay naive, matching
        what this service wrote previously.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _parse_kb_id(kb_id: str | uuid.UUID) -> uuid.UUID | None:
        """
        Parse a knowledge base id into a UUID.

        Args:
            kb_id: UUID string (any form ``uuid.UUID`` accepts) or a UUID.

        Returns:
            The parsed UUID, or None if ``kb_id`` is not a valid UUID string.
        """
        if isinstance(kb_id, uuid.UUID):
            return kb_id
        try:
            return uuid.UUID(kb_id)
        except ValueError:
            return None

    async def create_knowledge_base(
        self,
        tenant_id: str,
        kb_create: KnowledgeBaseCreate,
    ) -> KnowledgeBase:
        """
        [AC-AISVC-59] Create a new knowledge base.

        Initializes the corresponding Qdrant Collection, keyed by
        ``(tenant_id, str(kb.id))``.

        Args:
            tenant_id: Tenant identifier
            kb_create: Knowledge base creation data

        Returns:
            Created KnowledgeBase entity (enabled, with doc_count=0)
        """
        kb = KnowledgeBase(
            tenant_id=tenant_id,
            name=kb_create.name,
            kb_type=kb_create.kb_type,
            description=kb_create.description,
            priority=kb_create.priority,
            is_enabled=True,
            doc_count=0,
        )
        self._session.add(kb)
        # Flush so the row is sent to the database before kb.id is used to
        # name the Qdrant collection below.
        await self._session.flush()

        qdrant = await get_qdrant_client()
        await qdrant.ensure_kb_collection_exists(tenant_id, str(kb.id))

        logger.info(
            f"[AC-AISVC-59] Created knowledge base: tenant={tenant_id}, "
            f"kb_id={kb.id}, name={kb.name}, type={kb.kb_type}"
        )
        return kb

    async def get_knowledge_base(
        self,
        tenant_id: str,
        kb_id: str,
    ) -> KnowledgeBase | None:
        """
        Get a knowledge base by ID, scoped to the tenant.

        Args:
            tenant_id: Tenant identifier
            kb_id: Knowledge base ID (UUID string; a uuid.UUID is also accepted)

        Returns:
            KnowledgeBase entity, or None if kb_id is not a valid UUID or no
            matching row exists for this tenant.
        """
        kb_uuid = self._parse_kb_id(kb_id)
        if kb_uuid is None:
            return None

        stmt = select(KnowledgeBase).where(
            KnowledgeBase.tenant_id == tenant_id,
            KnowledgeBase.id == kb_uuid,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def list_knowledge_bases(
        self,
        tenant_id: str,
        kb_type: str | None = None,
        is_enabled: bool | None = None,
    ) -> Sequence[KnowledgeBase]:
        """
        [AC-AISVC-60] List knowledge bases for a tenant.

        Results are ordered by priority (highest first), then by creation
        time (newest first).

        Args:
            tenant_id: Tenant identifier
            kb_type: Filter by knowledge base type (optional; falsy disables)
            is_enabled: Filter by enabled status (optional)

        Returns:
            List of KnowledgeBase entities
        """
        stmt = select(KnowledgeBase).where(KnowledgeBase.tenant_id == tenant_id)
        if kb_type:
            stmt = stmt.where(KnowledgeBase.kb_type == kb_type)
        if is_enabled is not None:
            stmt = stmt.where(KnowledgeBase.is_enabled == is_enabled)

        stmt = stmt.order_by(
            col(KnowledgeBase.priority).desc(),
            col(KnowledgeBase.created_at).desc(),
        )
        result = await self._session.execute(stmt)
        return result.scalars().all()

    async def update_knowledge_base(
        self,
        tenant_id: str,
        kb_id: str,
        kb_update: KnowledgeBaseUpdate,
    ) -> KnowledgeBase | None:
        """
        [AC-AISVC-61] Update a knowledge base.

        Only fields explicitly set on ``kb_update`` are copied onto the
        entity (``exclude_unset=True``). A field explicitly set to None is
        therefore written as None.

        Args:
            tenant_id: Tenant identifier
            kb_id: Knowledge base ID
            kb_update: Update data

        Returns:
            Updated KnowledgeBase entity, or None if it was not found
        """
        kb = await self.get_knowledge_base(tenant_id, kb_id)
        if kb is None:
            return None

        update_data = kb_update.model_dump(exclude_unset=True)
        for key, value in update_data.items():
            setattr(kb, key, value)

        kb.updated_at = self._utcnow()
        await self._session.flush()

        logger.info(
            f"[AC-AISVC-61] Updated knowledge base: tenant={tenant_id}, "
            f"kb_id={kb_id}, fields={list(update_data.keys())}"
        )
        return kb

    async def delete_knowledge_base(
        self,
        tenant_id: str,
        kb_id: str,
    ) -> bool:
        """
        [AC-AISVC-62] Delete a knowledge base.

        Also deletes its Document rows and its Qdrant Collection.

        Args:
            tenant_id: Tenant identifier
            kb_id: Knowledge base ID

        Returns:
            True if deleted, False if the knowledge base was not found
        """
        kb = await self.get_knowledge_base(tenant_id, kb_id)
        if kb is None:
            return False

        # Use the canonical id string rather than the caller's raw kb_id.
        # uuid.UUID() also accepts upper-case, brace-wrapped or hyphen-less
        # spellings; those would find the KB but miss its documents and point
        # at a different collection than the one create_knowledge_base made
        # (which used str(kb.id)).
        canonical_kb_id = str(kb.id)

        doc_stmt = select(Document).where(
            Document.tenant_id == tenant_id,
            Document.kb_id == canonical_kb_id,
        )
        doc_result = await self._session.execute(doc_stmt)
        documents = doc_result.scalars().all()

        for doc in documents:
            await self._session.delete(doc)
        await self._session.delete(kb)
        await self._session.flush()

        qdrant = await get_qdrant_client()
        await qdrant.delete_kb_collection(tenant_id, canonical_kb_id)

        logger.info(
            f"[AC-AISVC-62] Deleted knowledge base: tenant={tenant_id}, "
            f"kb_id={canonical_kb_id}, docs_deleted={len(documents)}"
        )
        return True

    async def update_doc_count(
        self,
        tenant_id: str,
        kb_id: str,
        delta: int = 1,
    ) -> bool:
        """
        Adjust the document count of a knowledge base by ``delta``.

        The count is clamped at 0. This is best-effort: any error is logged
        with its traceback and reported as False instead of being raised.

        Args:
            tenant_id: Tenant identifier
            kb_id: Knowledge base ID
            delta: Change in document count (positive or negative)

        Returns:
            True if updated, False if not found or an error occurred
        """
        try:
            kb = await self.get_knowledge_base(tenant_id, kb_id)
            if kb is None:
                return False
            kb.doc_count = max(0, kb.doc_count + delta)
            kb.updated_at = self._utcnow()
            await self._session.flush()
        except Exception:
            # Deliberately best-effort, but keep the traceback and context.
            logger.exception(
                "Error updating doc count: tenant=%s, kb_id=%s, delta=%s",
                tenant_id,
                kb_id,
                delta,
            )
            return False
        return True

    async def recalculate_doc_counts(
        self,
        tenant_id: str,
    ) -> dict[str, int]:
        """
        Recalculate document counts for all knowledge bases of a tenant.

        Args:
            tenant_id: Tenant identifier

        Returns:
            Dictionary mapping str(kb.id) to the recomputed doc_count
        """
        count_stmt = (
            select(Document.kb_id, func.count(Document.id).label("doc_count"))
            .where(Document.tenant_id == tenant_id)
            .group_by(Document.kb_id)
        )
        result = await self._session.execute(count_stmt)
        # Unpack rows positionally. SQLAlchemy's Row is Sequence-like, so
        # attribute access such as `row.count` resolves to the Sequence.count()
        # method instead of a column with that label.
        counts = {str(doc_kb_id): int(n) for doc_kb_id, n in result}

        kb_stmt = select(KnowledgeBase).where(KnowledgeBase.tenant_id == tenant_id)
        kb_result = await self._session.execute(kb_stmt)
        knowledge_bases = kb_result.scalars().all()

        now = self._utcnow()
        for kb in knowledge_bases:
            kb.doc_count = counts.get(str(kb.id), 0)
            kb.updated_at = now

        await self._session.flush()

        return {str(kb.id): kb.doc_count for kb in knowledge_bases}

    async def get_or_create_default_kb(
        self,
        tenant_id: str,
    ) -> KnowledgeBase:
        """
        Get or create the default knowledge base for a tenant.

        This is used for backward compatibility with existing data. If the
        tenant already has any knowledge base, the oldest one is returned.
        Ordering by creation time makes the choice deterministic instead of
        database-order dependent. Otherwise a new general-purpose KB is
        created.

        Note: two concurrent calls for a tenant with no KB can both create
        one; nothing here prevents that.

        Args:
            tenant_id: Tenant identifier

        Returns:
            Default KnowledgeBase entity
        """
        stmt = (
            select(KnowledgeBase)
            .where(KnowledgeBase.tenant_id == tenant_id)
            .order_by(col(KnowledgeBase.created_at).asc())
            .limit(1)
        )
        result = await self._session.execute(stmt)
        existing_kb = result.scalar_one_or_none()
        if existing_kb is not None:
            return existing_kb

        kb_create = KnowledgeBaseCreate(
            name="Default Knowledge Base",
            kb_type=KBType.GENERAL.value,
            description="Default knowledge base for backward compatibility",
            priority=0,
        )
        return await self.create_knowledge_base(tenant_id, kb_create)