# ai-robot-core/ai-service/app/api/admin/kb.py
"""
Knowledge Base management endpoints.
[AC-ASA-01, AC-ASA-02, AC-ASA-08] Document upload, list, and index job status.
"""
import logging
import os
import uuid
from dataclasses import dataclass
from typing import Annotated, Optional
import tiktoken
from fastapi import APIRouter, BackgroundTasks, Depends, Query, UploadFile, File, Form
from fastapi.responses import JSONResponse
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id
from app.models import ErrorResponse
from app.models.entities import DocumentStatus, IndexJob, IndexJobStatus
from app.services.kb import KBService
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/kb", tags=["KB Management"])
@dataclass
class TextChunk:
    """A single retrievable span of text plus position metadata.

    Produced by the chunking helpers below: chunk_text_by_lines() stores line
    indices in start_token/end_token, while chunk_text_with_tiktoken() stores
    token offsets into the tiktoken-encoded text.
    """
    text: str  # chunk content (stripped line, or decoded token window)
    start_token: int  # inclusive start (line index or token offset, depending on chunker)
    end_token: int  # exclusive end
    page: int | None = None  # page number when the chunk came from a paginated parse; else None
    source: str | None = None  # originating file name/path, if known
def chunk_text_by_lines(
    text: str,
    min_line_length: int = 10,
    source: str | None = None,
) -> list[TextChunk]:
    """Split text into one retrieval chunk per line.

    Every line is stripped of surrounding whitespace; lines shorter than
    ``min_line_length`` after stripping are dropped.

    Args:
        text: Text to split.
        min_line_length: Minimum stripped line length; shorter lines are skipped.
        source: Optional source file path recorded on every chunk.

    Returns:
        One TextChunk per kept line. ``start_token``/``end_token`` hold the
        line index and line index + 1; ``page`` is always None here.
    """
    stripped = ((idx, raw.strip()) for idx, raw in enumerate(text.split("\n")))
    return [
        TextChunk(
            text=line,
            start_token=idx,
            end_token=idx + 1,
            page=None,
            source=source,
        )
        for idx, line in stripped
        if len(line) >= min_line_length
    ]
def chunk_text_with_tiktoken(
    text: str,
    chunk_size: int = 512,
    overlap: int = 100,
    page: int | None = None,
    source: str | None = None,
) -> list[TextChunk]:
    """Split text into overlapping windows of at most ``chunk_size`` tokens.

    Tokenizes with tiktoken's ``cl100k_base`` encoding and emits consecutive
    windows, advancing by ``chunk_size - overlap`` tokens per step (minimum 1).

    Args:
        text: Text to split.
        chunk_size: Maximum tokens per chunk; must be positive.
        overlap: Tokens shared between consecutive chunks. Values >= chunk_size
            are tolerated (the step is clamped to 1) rather than looping forever.
        page: Optional page number recorded on every chunk.
        source: Optional source file path recorded on every chunk.

    Returns:
        TextChunks whose ``start_token``/``end_token`` are token offsets.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    # Fix: the original advanced by chunk_size - overlap unconditionally, so
    # overlap >= chunk_size made the step non-positive and the loop infinite.
    step = max(1, chunk_size - overlap)
    encoding = tiktoken.get_encoding("cl100k_base")
    tokens = encoding.encode(text)
    chunks: list[TextChunk] = []
    start = 0
    while start < len(tokens):
        end = min(start + chunk_size, len(tokens))
        chunks.append(TextChunk(
            text=encoding.decode(tokens[start:end]),
            start_token=start,
            end_token=end,
            page=page,
            source=source,
        ))
        if end == len(tokens):
            break
        start += step
    return chunks
def get_current_tenant_id() -> str:
    """FastAPI dependency: resolve the current tenant ID from request context.

    Raises:
        MissingTenantIdException: If no tenant ID is present.
    """
    resolved = get_tenant_id()
    if resolved:
        return resolved
    raise MissingTenantIdException()
@router.get(
    "/knowledge-bases",
    operation_id="listKnowledgeBases",
    summary="Query knowledge base list",
    description="Get list of knowledge bases for the current tenant.",
    responses={
        200: {"description": "Knowledge base list"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def list_knowledge_bases(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
) -> JSONResponse:
    """Return every knowledge base of the current tenant with its document count."""
    logger.info(f"Listing knowledge bases: tenant={tenant_id}")
    kb_service = KBService(session)
    knowledge_bases = await kb_service.list_knowledge_bases(tenant_id)
    kb_ids = [str(kb.id) for kb in knowledge_bases]
    doc_counts: dict = {}
    if kb_ids:
        from sqlalchemy import func
        from app.models.entities import Document
        # Single grouped COUNT over all KBs instead of one query per KB.
        count_stmt = (
            select(Document.kb_id, func.count(Document.id).label("count"))
            .where(Document.tenant_id == tenant_id, Document.kb_id.in_(kb_ids))
            .group_by(Document.kb_id)
        )
        count_result = await session.execute(count_stmt)
        doc_counts = {row.kb_id: row.count for row in count_result}
    data = [
        {
            "id": str(kb.id),
            "name": kb.name,
            "documentCount": doc_counts.get(str(kb.id), 0),
            "createdAt": kb.created_at.isoformat() + "Z",
        }
        for kb in knowledge_bases
    ]
    return JSONResponse(content={"data": data})
@router.get(
    "/documents",
    operation_id="listDocuments",
    summary="Query document list",
    description="[AC-ASA-08] Get list of documents with pagination and filtering.",
    responses={
        200: {"description": "Document list with pagination"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def list_documents(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    kb_id: Annotated[Optional[str], Query()] = None,
    status: Annotated[Optional[str], Query()] = None,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
) -> JSONResponse:
    """
    [AC-ASA-08] List documents with filtering and pagination.

    Each row is augmented with the id of the document's most recent index job
    (``jobId`` is null for documents that were never indexed).
    """
    logger.info(
        f"[AC-ASA-08] Listing documents: tenant={tenant_id}, kb_id={kb_id}, "
        f"status={status}, page={page}, page_size={page_size}"
    )
    kb_service = KBService(session)
    documents, total = await kb_service.list_documents(
        tenant_id=tenant_id,
        kb_id=kb_id,
        status=status,
        page=page,
        page_size=page_size,
    )
    # Ceiling division; 0 pages for an empty result set.
    total_pages = (total + page_size - 1) // page_size if total > 0 else 0
    data = []
    # NOTE: one job lookup per document (N+1), bounded by page_size <= 100.
    for doc in documents:
        # Fix: fetch only the newest job with LIMIT 1. The previous
        # scalar_one_or_none() on the un-limited ordered query raises
        # MultipleResultsFound as soon as a document has been indexed twice.
        job_stmt = (
            select(IndexJob)
            .where(
                IndexJob.tenant_id == tenant_id,
                IndexJob.doc_id == doc.id,
            )
            .order_by(IndexJob.created_at.desc())
            .limit(1)
        )
        job_result = await session.execute(job_stmt)
        latest_job = job_result.scalars().first()
        data.append({
            "docId": str(doc.id),
            "kbId": doc.kb_id,
            "fileName": doc.file_name,
            "status": doc.status,
            "jobId": str(latest_job.id) if latest_job else None,
            "createdAt": doc.created_at.isoformat() + "Z",
            "updatedAt": doc.updated_at.isoformat() + "Z",
        })
    return JSONResponse(
        content={
            "data": data,
            "pagination": {
                "page": page,
                "pageSize": page_size,
                "total": total,
                "totalPages": total_pages,
            },
        }
    )
@router.post(
    "/documents",
    operation_id="uploadDocument",
    summary="Upload/import document",
    description="[AC-ASA-01] Upload document and trigger indexing job.",
    responses={
        202: {"description": "Accepted - async indexing job started"},
        400: {"description": "Bad Request - unsupported format"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def upload_document(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    kb_id: str = Form(...),
) -> JSONResponse:
    """
    [AC-ASA-01] Upload a document, persist it, and start an async index job.
    [AC-AISVC-33, AC-AISVC-34, AC-AISVC-35, AC-AISVC-37] Support multiple document formats.

    Returns 202 with the created job/document IDs; the actual parsing and
    embedding run in a background task (_index_document) after the response.
    """
    # Cleanup: UnsupportedFormatError was imported here but never used.
    from app.services.document import get_supported_document_formats
    from pathlib import Path
    logger.info(
        f"[AC-ASA-01] Uploading document: tenant={tenant_id}, "
        f"kb_id={kb_id}, filename={file.filename}"
    )
    file_ext = Path(file.filename or "").suffix.lower()
    supported_formats = get_supported_document_formats()
    # Extension-less uploads are deliberately allowed through: the indexer
    # treats them as plain text and tries several encodings.
    if file_ext and file_ext not in supported_formats:
        return JSONResponse(
            status_code=400,
            content={
                "code": "UNSUPPORTED_FORMAT",
                "message": f"Unsupported file format: {file_ext}",
                "details": {
                    "supported_formats": supported_formats,
                },
            },
        )
    kb_service = KBService(session)
    kb = await kb_service.get_or_create_kb(tenant_id, kb_id)
    # Whole file is buffered in memory before persisting and indexing.
    file_content = await file.read()
    document, job = await kb_service.upload_document(
        tenant_id=tenant_id,
        kb_id=str(kb.id),
        file_name=file.filename or "unknown",
        file_content=file_content,
        file_type=file.content_type,
    )
    # Commit before scheduling the task so the job row is visible to the
    # background task's own session.
    await session.commit()
    background_tasks.add_task(
        _index_document, tenant_id, str(job.id), str(document.id), file_content, file.filename
    )
    return JSONResponse(
        status_code=202,
        content={
            "jobId": str(job.id),
            "docId": str(document.id),
            "status": job.status,
        },
    )
async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: bytes, filename: str | None = None):
    """
    Background indexing task: decode/parse the uploaded bytes, chunk the text,
    embed each chunk, and upsert the vectors into the tenant's Qdrant
    collection, reporting progress on the IndexJob row throughout.
    [AC-AISVC-33, AC-AISVC-34, AC-AISVC-35] Uses document parsing and pluggable embedding.

    Args:
        tenant_id: Tenant owning the document and job.
        job_id: IndexJob to update (PROCESSING 10/15/20..90 -> COMPLETED/FAILED).
        doc_id: Document id; stored as the "source" field of every vector payload.
        content: Raw uploaded file bytes.
        filename: Original file name; its extension selects the text-decode
            path vs. the document-parser path.

    Runs in its own DB session (scheduled via BackgroundTasks after the HTTP
    response). Any exception marks the job FAILED with the error message.
    """
    # Function-local imports keep the heavy parsing/embedding/qdrant stack out
    # of module import time. PageText is imported but unused here.
    from app.core.database import async_session_maker
    from app.services.kb import KBService
    from app.core.qdrant_client import get_qdrant_client
    from app.services.embedding import get_embedding_provider
    from app.services.document import parse_document, UnsupportedFormatError, DocumentParseException, PageText
    from qdrant_client.models import PointStruct
    import asyncio
    import tempfile
    from pathlib import Path
    # NOTE(review): "(unknown)" below is literal log text, not a placeholder —
    # it was probably meant to be {filename}; confirm and fix separately.
    logger.info(f"[INDEX] Starting indexing: tenant={tenant_id}, job_id={job_id}, doc_id={doc_id}, filename=(unknown)")
    await asyncio.sleep(1)  # presumably lets the upload request finish first — TODO confirm why this delay is needed
    async with async_session_maker() as session:
        kb_service = KBService(session)
        try:
            # Progress 10%: job picked up.
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=10
            )
            await session.commit()
            parse_result = None
            text = None
            file_ext = Path(filename or "").suffix.lower()
            logger.info(f"[INDEX] File extension: {file_ext}, content size: {len(content)} bytes")
            # Known plain-text extensions (and extension-less files) are decoded
            # directly; everything else goes through the document parser.
            text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}
            if file_ext in text_extensions or not file_ext:
                logger.info(f"[INDEX] Treating as text file, trying multiple encodings")
                text = None
                # Try encodings in order; latin-1 is last and accepts any byte
                # sequence, so the replace-fallback below is effectively a
                # safety net in case this list changes.
                for encoding in ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]:
                    try:
                        text = content.decode(encoding)
                        logger.info(f"[INDEX] Successfully decoded with encoding: {encoding}")
                        break
                    except (UnicodeDecodeError, LookupError):
                        continue
                if text is None:
                    text = content.decode("utf-8", errors="replace")
                    logger.warning(f"[INDEX] Failed to decode with known encodings, using utf-8 with replacement")
            else:
                logger.info(f"[INDEX] Binary file detected, will parse with document parser")
                # Progress 15%: about to run the (potentially slow) parser.
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
                )
                await session.commit()
                # The parser takes a file path, so spill the bytes to a temp
                # file; delete=False because it is unlinked in the finally below.
                with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
                    tmp_file.write(content)
                    tmp_path = tmp_file.name
                logger.info(f"[INDEX] Temp file created: {tmp_path}")
                try:
                    logger.info(f"[INDEX] Starting document parsing for {file_ext}...")
                    parse_result = parse_document(tmp_path)
                    text = parse_result.text
                    # NOTE(review): "(unknown)" here is literal log text too —
                    # likely a lost {filename} placeholder.
                    logger.info(
                        f"[INDEX] Parsed document SUCCESS: (unknown), "
                        f"chars={len(text)}, format={parse_result.metadata.get('format')}, "
                        f"pages={len(parse_result.pages) if parse_result.pages else 'N/A'}, "
                        f"metadata={parse_result.metadata}"
                    )
                    if len(text) < 100:
                        logger.warning(f"[INDEX] Parsed text is very short, preview: {text[:200]}")
                # Every parser failure falls back to a best-effort UTF-8 decode
                # of the raw bytes (lossy for true binary formats).
                except UnsupportedFormatError as e:
                    logger.error(f"[INDEX] UnsupportedFormatError: {e}")
                    text = content.decode("utf-8", errors="ignore")
                except DocumentParseException as e:
                    logger.error(f"[INDEX] DocumentParseException: {e}, details={getattr(e, 'details', {})}")
                    text = content.decode("utf-8", errors="ignore")
                except Exception as e:
                    logger.error(f"[INDEX] Unexpected parsing error: {type(e).__name__}: {e}")
                    text = content.decode("utf-8", errors="ignore")
                finally:
                    Path(tmp_path).unlink(missing_ok=True)
                    logger.info(f"[INDEX] Temp file cleaned up")
            logger.info(f"[INDEX] Final text length: {len(text)} chars")
            if len(text) < 50:
                logger.warning(f"[INDEX] Text too short, preview: {repr(text[:200])}")
            # Progress 20%: text ready; chunk embedding covers the 20-90% range.
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=20
            )
            await session.commit()
            logger.info(f"[INDEX] Getting embedding provider...")
            embedding_provider = await get_embedding_provider()
            logger.info(f"[INDEX] Embedding provider: {type(embedding_provider).__name__}")
            all_chunks: list[TextChunk] = []
            if parse_result and parse_result.pages:
                # Paginated parse result: chunk each page separately so every
                # chunk carries its page number in the payload.
                logger.info(f"[INDEX] PDF with {len(parse_result.pages)} pages, using line-based chunking with page metadata")
                for page in parse_result.pages:
                    page_chunks = chunk_text_by_lines(
                        page.text,
                        min_line_length=10,
                        source=filename,
                    )
                    for pc in page_chunks:
                        pc.page = page.page
                    all_chunks.extend(page_chunks)
                logger.info(f"[INDEX] Total chunks from PDF: {len(all_chunks)}")
            else:
                logger.info(f"[INDEX] Using line-based chunking")
                all_chunks = chunk_text_by_lines(
                    text,
                    min_line_length=10,
                    source=filename,
                )
                logger.info(f"[INDEX] Total chunks: {len(all_chunks)}")
            qdrant = await get_qdrant_client()
            await qdrant.ensure_collection_exists(tenant_id)
            points = []
            total_chunks = len(all_chunks)
            # Embed chunks one at a time (no batching) and build Qdrant points.
            for i, chunk in enumerate(all_chunks):
                embedding = await embedding_provider.embed(chunk.text)
                payload = {
                    "text": chunk.text,
                    "source": doc_id,
                    "chunk_index": i,
                    "start_token": chunk.start_token,
                    "end_token": chunk.end_token,
                }
                if chunk.page is not None:
                    payload["page"] = chunk.page
                if chunk.source:
                    payload["filename"] = chunk.source
                points.append(
                    PointStruct(
                        id=str(uuid.uuid4()),
                        vector=embedding,
                        payload=payload,
                    )
                )
                # Map chunk progress onto 20-90%; persist every 10 chunks (and
                # on the last one) to limit commit traffic.
                progress = 20 + int((i + 1) / total_chunks * 70)
                if i % 10 == 0 or i == total_chunks - 1:
                    await kb_service.update_job_status(
                        tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=progress
                    )
                    await session.commit()
            if points:
                logger.info(f"[INDEX] Upserting {len(points)} vectors to Qdrant...")
                await qdrant.upsert_vectors(tenant_id, points)
            # Mark COMPLETED even when no chunks were produced (empty/short text).
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.COMPLETED.value, progress=100
            )
            await session.commit()
            logger.info(
                f"[INDEX] COMPLETED: tenant={tenant_id}, "
                f"job_id={job_id}, chunks={len(all_chunks)}, text_len={len(text)}"
            )
        except Exception as e:
            import traceback
            logger.error(f"[INDEX] FAILED: {e}\n{traceback.format_exc()}")
            await session.rollback()
            # Record the failure on a fresh session: the original one may be
            # unusable after the rollback.
            async with async_session_maker() as error_session:
                kb_service = KBService(error_session)
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.FAILED.value,
                    progress=0, error_msg=str(e)
                )
                await error_session.commit()
@router.get(
    "/index/jobs/{job_id}",
    operation_id="getIndexJob",
    summary="Query index job status",
    description="[AC-ASA-02] Get indexing job status and progress.",
    responses={
        200: {"description": "Job status details"},
        # Fix: the handler returns 404 below, but the OpenAPI responses
        # omitted it (delete_document already declares its 404).
        404: {"description": "Job not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def get_index_job(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    job_id: str,
) -> JSONResponse:
    """
    [AC-ASA-02] Get indexing job status with progress.

    Returns 404 when no matching job is found for the current tenant.
    """
    logger.info(
        f"[AC-ASA-02] Getting job status: tenant={tenant_id}, job_id={job_id}"
    )
    kb_service = KBService(session)
    job = await kb_service.get_index_job(tenant_id, job_id)
    if not job:
        return JSONResponse(
            status_code=404,
            content={
                "code": "JOB_NOT_FOUND",
                "message": f"Job {job_id} not found",
            },
        )
    return JSONResponse(
        content={
            "jobId": str(job.id),
            "docId": str(job.doc_id),
            "status": job.status,
            "progress": job.progress,
            "errorMsg": job.error_msg,
        }
    )
@router.delete(
    "/documents/{doc_id}",
    operation_id="deleteDocument",
    summary="Delete document",
    description="[AC-ASA-08] Delete a document and its associated files.",
    responses={
        200: {"description": "Document deleted"},
        404: {"description": "Document not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def delete_document(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    doc_id: str,
) -> JSONResponse:
    """[AC-ASA-08] Delete the given document for the current tenant; 404 if absent."""
    logger.info(
        f"[AC-ASA-08] Deleting document: tenant={tenant_id}, doc_id={doc_id}"
    )
    was_deleted = await KBService(session).delete_document(tenant_id, doc_id)
    if was_deleted:
        return JSONResponse(
            content={
                "success": True,
                "message": "Document deleted",
            }
        )
    return JSONResponse(
        status_code=404,
        content={
            "code": "DOCUMENT_NOT_FOUND",
            "message": f"Document {doc_id} not found",
        },
    )