# Source path: ai-robot-core/ai-service/app/api/admin/kb.py
# (GitHub viewer chrome — "Raw Blame History", line/size counts, ambiguous-
# Unicode warning — was pasted in with the file; reduced to this comment.)

"""
Knowledge Base management endpoints.
[AC-ASA-01, AC-ASA-02, AC-ASA-08] Document upload, list, and index job status.
[AC-AISVC-59~AC-AISVC-64] Multi-knowledge-base management.
"""
import logging
import uuid
import json
import hashlib
from dataclasses import dataclass
from typing import Annotated, Any, Optional
import tiktoken
from fastapi import APIRouter, BackgroundTasks, Depends, File, Form, HTTPException, Query, UploadFile
from fastapi.responses import JSONResponse, Response
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.config import get_settings
from app.core.database import get_session
from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id
from app.models import ErrorResponse
from app.models.entities import (
Document,
DocumentStatus,
IndexJob,
IndexJobStatus,
KBType,
KnowledgeBase,
KnowledgeBaseCreate,
KnowledgeBaseUpdate,
)
from app.services.kb import KBService
from app.services.knowledge_base_service import KnowledgeBaseService
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/kb", tags=["KB Management"])
@dataclass
class TextChunk:
"""Text chunk with metadata."""
text: str
start_token: int
end_token: int
page: int | None = None
source: str | None = None
def chunk_text_by_lines(
text: str,
min_line_length: int = 10,
source: str | None = None,
) -> list[TextChunk]:
"""
按行分块,每行作为一个独立的检索单元。
Args:
text: 要分块的文本
min_line_length: 最小行长度,低于此长度的行会被跳过
source: 来源文件路径(可选)
Returns:
分块列表,每个块对应一行文本
"""
lines = text.split('\n')
chunks: list[TextChunk] = []
for i, line in enumerate(lines):
line = line.strip()
if len(line) < min_line_length:
continue
chunks.append(TextChunk(
text=line,
start_token=i,
end_token=i + 1,
page=None,
source=source,
))
return chunks
def chunk_text_with_tiktoken(
text: str,
chunk_size: int = 512,
overlap: int = 100,
page: int | None = None,
source: str | None = None,
) -> list[TextChunk]:
"""
使用 tiktoken 按 token 数分块,支持重叠分块。
Args:
text: 要分块的文本
chunk_size: 每个块的最大 token 数
overlap: 块之间的重叠 token 数
page: 页码(可选)
source: 来源文件路径(可选)
Returns:
分块列表,每个块包含文本及起始/结束位置
"""
encoding = tiktoken.get_encoding("cl100k_base")
tokens = encoding.encode(text)
chunks: list[TextChunk] = []
start = 0
while start < len(tokens):
end = min(start + chunk_size, len(tokens))
chunk_tokens = tokens[start:end]
chunk_text = encoding.decode(chunk_tokens)
chunks.append(TextChunk(
text=chunk_text,
start_token=start,
end_token=end,
page=page,
source=source,
))
if end == len(tokens):
break
start += chunk_size - overlap
return chunks
def get_current_tenant_id() -> str:
    """FastAPI dependency: resolve the tenant ID bound to the current request.

    Raises:
        MissingTenantIdException: If no tenant ID is present in the context.
    """
    current = get_tenant_id()
    if current:
        return current
    raise MissingTenantIdException()
@router.get(
    "/knowledge-bases",
    operation_id="listKnowledgeBases",
    summary="Query knowledge base list",
    description="[AC-AISVC-60] Get list of knowledge bases for the current tenant with type and status filters.",
    responses={
        200: {"description": "Knowledge base list"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def list_knowledge_bases(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    kb_type: Annotated[Optional[str], Query()] = None,
    is_enabled: Annotated[Optional[bool], Query()] = None,
) -> JSONResponse:
    """
    [AC-AISVC-60] List all knowledge bases for the current tenant.
    Supports filtering by kb_type and is_enabled status.
    """
    try:
        logger.info(
            f"[AC-AISVC-60] Listing knowledge bases: tenant={tenant_id}, "
            f"kb_type={kb_type}, is_enabled={is_enabled}"
        )
        kb_service = KnowledgeBaseService(session)
        knowledge_bases = await kb_service.list_knowledge_bases(
            tenant_id=tenant_id,
            kb_type=kb_type,
            is_enabled=is_enabled,
        )
        # camelCase keys match the admin API response convention used below.
        data = [
            {
                "id": str(kb.id),
                "name": kb.name,
                "kbType": kb.kb_type,
                "description": kb.description,
                "priority": kb.priority,
                "isEnabled": kb.is_enabled,
                "docCount": kb.doc_count,
                "createdAt": kb.created_at.isoformat() + "Z",
                "updatedAt": kb.updated_at.isoformat() + "Z",
            }
            for kb in knowledge_bases
        ]
        logger.info(f"[AC-AISVC-60] Returning {len(data)} knowledge bases")
        return JSONResponse(content={"data": data})
    except Exception:
        # logger.exception records the full traceback; re-raise so FastAPI
        # turns the failure into a 500 response.
        logger.exception("[AC-AISVC-60] Error listing knowledge bases")
        raise
@router.post(
    "/knowledge-bases",
    operation_id="createKnowledgeBase",
    summary="Create knowledge base",
    description="[AC-AISVC-59] Create a new knowledge base with specified type and priority.",
    responses={
        201: {"description": "Knowledge base created"},
        400: {"description": "Bad Request - invalid kb_type"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
    status_code=201,
)
async def create_knowledge_base(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    kb_create: KnowledgeBaseCreate,
) -> JSONResponse:
    """[AC-AISVC-59] Create a new knowledge base for the tenant.

    Initializes the corresponding Qdrant Collection; rejects unknown
    kb_type values with a 400 before touching the database.
    """
    valid_types = [t.value for t in KBType]
    if kb_create.kb_type not in valid_types:
        return JSONResponse(
            status_code=400,
            content={
                "code": "INVALID_KB_TYPE",
                "message": f"Invalid kb_type: {kb_create.kb_type}",
                "details": {"valid_types": valid_types},
            },
        )
    logger.info(
        f"[AC-AISVC-59] Creating knowledge base: tenant={tenant_id}, "
        f"name={kb_create.name}, type={kb_create.kb_type}"
    )
    service = KnowledgeBaseService(session)
    created = await service.create_knowledge_base(tenant_id, kb_create)
    await session.commit()
    body = {
        "id": str(created.id),
        "name": created.name,
        "kbType": created.kb_type,
        "description": created.description,
        "priority": created.priority,
        "isEnabled": created.is_enabled,
        "docCount": created.doc_count,
        "createdAt": created.created_at.isoformat() + "Z",
        "updatedAt": created.updated_at.isoformat() + "Z",
    }
    return JSONResponse(status_code=201, content=body)
@router.get(
    "/knowledge-bases/{kb_id}",
    operation_id="getKnowledgeBase",
    summary="Get knowledge base details",
    description="Get detailed information about a specific knowledge base.",
    responses={
        200: {"description": "Knowledge base details"},
        404: {"description": "Knowledge base not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def get_knowledge_base(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    kb_id: str,
) -> JSONResponse:
    """Return one knowledge base (tenant-scoped), or a 404 error payload."""
    logger.info(f"Getting knowledge base: tenant={tenant_id}, kb_id={kb_id}")
    service = KnowledgeBaseService(session)
    kb = await service.get_knowledge_base(tenant_id, kb_id)
    if not kb:
        # Unknown ID, or the KB belongs to a different tenant.
        not_found = {
            "code": "KB_NOT_FOUND",
            "message": f"Knowledge base {kb_id} not found",
        }
        return JSONResponse(status_code=404, content=not_found)
    detail = {
        "id": str(kb.id),
        "name": kb.name,
        "kbType": kb.kb_type,
        "description": kb.description,
        "priority": kb.priority,
        "isEnabled": kb.is_enabled,
        "docCount": kb.doc_count,
        "createdAt": kb.created_at.isoformat() + "Z",
        "updatedAt": kb.updated_at.isoformat() + "Z",
    }
    return JSONResponse(content=detail)
@router.put(
    "/knowledge-bases/{kb_id}",
    operation_id="updateKnowledgeBase",
    summary="Update knowledge base",
    description="[AC-AISVC-61] Update knowledge base name, type, description, priority, or enabled status.",
    responses={
        200: {"description": "Knowledge base updated"},
        400: {"description": "Bad Request - invalid kb_type"},
        404: {"description": "Knowledge base not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def update_knowledge_base(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    kb_id: str,
    kb_update: KnowledgeBaseUpdate,
) -> JSONResponse:
    """[AC-AISVC-61] Apply a partial update to a knowledge base.

    Rejects unknown kb_type values with 400; returns 404 when the KB does
    not exist for this tenant.
    """
    # kb_type is optional on update; validate only when it was supplied.
    if kb_update.kb_type is not None:
        valid_types = [t.value for t in KBType]
        if kb_update.kb_type not in valid_types:
            return JSONResponse(
                status_code=400,
                content={
                    "code": "INVALID_KB_TYPE",
                    "message": f"Invalid kb_type: {kb_update.kb_type}",
                    "details": {"valid_types": valid_types},
                },
            )
    logger.info(
        f"[AC-AISVC-61] Updating knowledge base: tenant={tenant_id}, kb_id={kb_id}"
    )
    service = KnowledgeBaseService(session)
    updated = await service.update_knowledge_base(tenant_id, kb_id, kb_update)
    if not updated:
        return JSONResponse(
            status_code=404,
            content={
                "code": "KB_NOT_FOUND",
                "message": f"Knowledge base {kb_id} not found",
            },
        )
    await session.commit()
    body = {
        "id": str(updated.id),
        "name": updated.name,
        "kbType": updated.kb_type,
        "description": updated.description,
        "priority": updated.priority,
        "isEnabled": updated.is_enabled,
        "docCount": updated.doc_count,
        "createdAt": updated.created_at.isoformat() + "Z",
        "updatedAt": updated.updated_at.isoformat() + "Z",
    }
    return JSONResponse(content=body)
@router.delete(
    "/knowledge-bases/{kb_id}",
    operation_id="deleteKnowledgeBase",
    summary="Delete knowledge base",
    description="[AC-AISVC-62] Delete a knowledge base and its associated documents and Qdrant Collection.",
    responses={
        204: {"description": "Knowledge base deleted"},
        404: {"description": "Knowledge base not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def delete_knowledge_base(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    kb_id: str,
) -> Response:
    """
    [AC-AISVC-62] Delete a knowledge base.
    Also deletes associated documents and Qdrant Collection.

    Returns:
        204 No Content on success; 404 when the KB does not exist for this tenant.
    """
    logger.info(
        f"[AC-AISVC-62] Deleting knowledge base: tenant={tenant_id}, kb_id={kb_id}"
    )
    kb_service = KnowledgeBaseService(session)
    deleted = await kb_service.delete_knowledge_base(tenant_id, kb_id)
    if not deleted:
        return JSONResponse(
            status_code=404,
            content={
                "code": "KB_NOT_FOUND",
                "message": f"Knowledge base {kb_id} not found",
            },
        )
    await session.commit()
    # A 204 response must not carry a body; JSONResponse(content=None) would
    # serialize a literal `null` payload, which is invalid for 204.
    return Response(status_code=204)
@router.get(
    "/documents",
    operation_id="listDocuments",
    summary="Query document list",
    description="[AC-ASA-08] Get list of documents with pagination and filtering.",
    responses={
        200: {"description": "Document list with pagination"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def list_documents(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    kb_id: Annotated[Optional[str], Query()] = None,
    status: Annotated[Optional[str], Query()] = None,
    page: int = Query(1, ge=1),
    page_size: int = Query(20, ge=1, le=100),
) -> JSONResponse:
    """
    [AC-ASA-08] List documents with filtering and pagination.

    Each document is returned with the ID of its most recent index job (if
    any) so the caller can track indexing progress.
    """
    logger.info(
        f"[AC-ASA-08] Listing documents: tenant={tenant_id}, kb_id={kb_id}, "
        f"status={status}, page={page}, page_size={page_size}"
    )
    kb_service = KBService(session)
    documents, total = await kb_service.list_documents(
        tenant_id=tenant_id,
        kb_id=kb_id,
        status=status,
        page=page,
        page_size=page_size,
    )
    # Ceiling division; 0 pages when there are no results.
    total_pages = (total + page_size - 1) // page_size if total > 0 else 0
    data = []
    for doc in documents:
        # Fetch only the newest job: scalar_one_or_none() on an unlimited
        # query raises MultipleResultsFound once a doc has been re-indexed.
        job_stmt = (
            select(IndexJob)
            .where(
                IndexJob.tenant_id == tenant_id,
                IndexJob.doc_id == doc.id,
            )
            .order_by(IndexJob.created_at.desc())
            .limit(1)
        )
        job_result = await session.execute(job_stmt)
        latest_job = job_result.scalars().first()
        data.append({
            "docId": str(doc.id),
            "kbId": doc.kb_id,
            "fileName": doc.file_name,
            "status": doc.status,
            "metadata": doc.doc_metadata,
            "jobId": str(latest_job.id) if latest_job else None,
            "createdAt": doc.created_at.isoformat() + "Z",
            "updatedAt": doc.updated_at.isoformat() + "Z",
        })
    return JSONResponse(
        content={
            "data": data,
            "pagination": {
                "page": page,
                "pageSize": page_size,
                "total": total,
                "totalPages": total_pages,
            },
        }
    )
@router.post(
    "/documents",
    operation_id="uploadDocument",
    summary="Upload/import document",
    description="[AC-ASA-01, AC-AISVC-63, AC-IDSMETA-15] Upload document to specified knowledge base and trigger indexing job.",
    responses={
        202: {"description": "Accepted - async indexing job started"},
        400: {"description": "Bad Request - unsupported format or invalid kb_id"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def upload_document(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    kb_id: str = Form(...),
    metadata: str = Form(default="{}", description="元数据 JSON 字符串,根据元数据模式配置动态字段"),
) -> JSONResponse:
    """
    [AC-ASA-01, AC-AISVC-63, AC-IDSMETA-15] Upload document to specified knowledge base.
    Creates KB if not exists, indexes to corresponding Qdrant Collection.

    [AC-IDSMETA-15] Dynamic metadata validation:
    - metadata: JSON object whose fields follow the tenant's metadata schema
    - required/type/enum checks are driven by field definitions with scope=kb_document

    Example metadata:
    - Education: {"grade": "初一", "subject": "语文", "type": "痛点"}
    - Healthcare: {"department": "内科", "disease_type": "慢性病", "content_type": "科普"}
    """
    # "json" and MetadataFieldDefinitionService are already module-level imports;
    # only the parser-format helper and Path are imported lazily here.
    from pathlib import Path
    from app.services.document import get_supported_document_formats
    logger.info(
        f"[AC-IDSMETA-15] Uploading document: tenant={tenant_id}, "
        f"kb_id={kb_id}, filename={file.filename}"
    )
    # 1) Reject unsupported file extensions up front (extension-less names pass).
    file_ext = Path(file.filename or "").suffix.lower()
    supported_formats = get_supported_document_formats()
    if file_ext and file_ext not in supported_formats:
        return JSONResponse(
            status_code=400,
            content={
                "code": "UNSUPPORTED_FORMAT",
                "message": f"Unsupported file format: {file_ext}",
                "details": {
                    "supported_formats": supported_formats,
                },
            },
        )
    # 2) Parse the metadata form field as JSON.
    try:
        metadata_dict = json.loads(metadata) if metadata else {}
    except json.JSONDecodeError:
        return JSONResponse(
            status_code=400,
            content={
                "code": "INVALID_METADATA",
                "message": "Invalid JSON format for metadata",
            },
        )
    # 3) Validate metadata against the tenant's kb_document field definitions.
    field_def_service = MetadataFieldDefinitionService(session)
    is_valid, validation_errors = await field_def_service.validate_metadata_for_create(
        tenant_id, metadata_dict, "kb_document"
    )
    if not is_valid:
        logger.warning(f"[AC-IDSMETA-15] Metadata validation failed: {validation_errors}")
        return JSONResponse(
            status_code=400,
            content={
                "code": "METADATA_VALIDATION_ERROR",
                "message": "Metadata validation failed",
                "details": {
                    "errors": validation_errors,
                },
            },
        )
    # 4) Resolve the target KB; fall back to the tenant's default KB when the
    #    requested one is missing or the lookup fails (deliberate best-effort).
    kb_service = KnowledgeBaseService(session)
    try:
        kb = await kb_service.get_knowledge_base(tenant_id, kb_id)
        if not kb:
            kb = await kb_service.get_or_create_default_kb(tenant_id)
            kb_id = str(kb.id)
            logger.info(f"[AC-IDSMETA-15] KB not found, using default: {kb_id}")
        else:
            kb_id = str(kb.id)
    except Exception:
        kb = await kb_service.get_or_create_default_kb(tenant_id)
        kb_id = str(kb.id)
    # 5) Persist the document + index job, bump the KB doc counter, then
    #    schedule asynchronous indexing and answer 202 immediately.
    doc_kb_service = KBService(session)
    file_content = await file.read()
    document, job = await doc_kb_service.upload_document(
        tenant_id=tenant_id,
        kb_id=kb_id,
        file_name=file.filename or "unknown",
        file_content=file_content,
        file_type=file.content_type,
        metadata=metadata_dict,
    )
    await kb_service.update_doc_count(tenant_id, kb_id, delta=1)
    await session.commit()
    background_tasks.add_task(
        _index_document, tenant_id, kb_id, str(job.id), str(document.id), file_content, file.filename, metadata_dict
    )
    return JSONResponse(
        status_code=202,
        content={
            "jobId": str(job.id),
            "docId": str(document.id),
            "kbId": kb_id,
            "status": job.status,
            "metadata": metadata_dict,
        },
    )
async def _index_document(
    tenant_id: str,
    kb_id: str,
    job_id: str,
    doc_id: str,
    content: bytes,
    filename: str | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """
    Background indexing task.
    [AC-AISVC-33, AC-AISVC-34, AC-AISVC-35, AC-AISVC-63] Uses document parsing and pluggable embedding.
    Indexes to the specified knowledge base's Qdrant Collection.

    Pipeline: decode/parse the raw bytes -> chunk the text -> embed each chunk
    -> upsert the vectors into the KB's Qdrant collection, reporting IndexJob
    progress along the way. On any failure the job is marked FAILED (with the
    error message) via a fresh session.

    Args:
        tenant_id: Owning tenant.
        kb_id: Target knowledge base (selects the Qdrant collection).
        job_id: IndexJob row to report progress/status on.
        doc_id: Document the vectors belong to (stored in payload "source").
        content: Raw uploaded file bytes.
        filename: Original file name; its extension decides text vs. binary handling.
        metadata: Dynamic metadata dict (schema-driven); copied into every chunk payload.
    """
    # Imported lazily so module import stays light; this runs as a background task.
    import asyncio
    import tempfile
    from pathlib import Path
    from qdrant_client.models import PointStruct
    from app.core.database import async_session_maker
    from app.core.qdrant_client import get_qdrant_client
    from app.services.document import DocumentParseException, UnsupportedFormatError, parse_document
    from app.services.embedding import get_embedding_provider
    from app.services.kb import KBService
    # NOTE(review): the literal "(unknown)" here looks like a redacted
    # {filename} placeholder — confirm against the original source.
    logger.info(f"[INDEX] Starting indexing: tenant={tenant_id}, kb_id={kb_id}, job_id={job_id}, doc_id={doc_id}, filename=(unknown)")
    # Brief delay before starting; presumably lets the upload request/commit settle — confirm.
    await asyncio.sleep(1)
    # Background tasks run outside the request scope, so open a dedicated session.
    async with async_session_maker() as session:
        kb_service = KBService(session)
        try:
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=10
            )
            await session.commit()
            parse_result = None
            text = None
            file_ext = Path(filename or "").suffix.lower()
            logger.info(f"[INDEX] File extension: {file_ext}, content size: {len(content)} bytes")
            # Extensions handled as plain text (no document parser required).
            text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}
            if file_ext in text_extensions or not file_ext:
                logger.info("[INDEX] Treating as text file, trying multiple encodings")
                text = None
                # Try common encodings in order; the first successful decode wins.
                for encoding in ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]:
                    try:
                        text = content.decode(encoding)
                        logger.info(f"[INDEX] Successfully decoded with encoding: {encoding}")
                        break
                    except (UnicodeDecodeError, LookupError):
                        continue
                if text is None:
                    # Last resort: lossy decode so indexing can still proceed.
                    text = content.decode("utf-8", errors="replace")
                    logger.warning("[INDEX] Failed to decode with known encodings, using utf-8 with replacement")
            else:
                logger.info("[INDEX] Binary file detected, will parse with document parser")
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
                )
                await session.commit()
                # The parser works on file paths, so spill the bytes to a temp file.
                with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
                    tmp_file.write(content)
                    tmp_path = tmp_file.name
                logger.info(f"[INDEX] Temp file created: {tmp_path}")
                try:
                    logger.info(f"[INDEX] Starting document parsing for {file_ext}...")
                    parse_result = parse_document(tmp_path)
                    text = parse_result.text
                    logger.info(
                        f"[INDEX] Parsed document SUCCESS: (unknown), "
                        f"chars={len(text)}, format={parse_result.metadata.get('format')}, "
                        f"pages={len(parse_result.pages) if parse_result.pages else 'N/A'}, "
                        f"metadata={parse_result.metadata}"
                    )
                    if len(text) < 100:
                        logger.warning(f"[INDEX] Parsed text is very short, preview: {text[:200]}")
                # Parsing failures degrade to a lossy text decode instead of
                # aborting the whole job.
                except UnsupportedFormatError as e:
                    logger.error(f"[INDEX] UnsupportedFormatError: {e}")
                    text = content.decode("utf-8", errors="ignore")
                except DocumentParseException as e:
                    logger.error(f"[INDEX] DocumentParseException: {e}, details={getattr(e, 'details', {})}")
                    text = content.decode("utf-8", errors="ignore")
                except Exception as e:
                    logger.error(f"[INDEX] Unexpected parsing error: {type(e).__name__}: {e}")
                    text = content.decode("utf-8", errors="ignore")
                finally:
                    # Always remove the temp file, even when parsing failed.
                    Path(tmp_path).unlink(missing_ok=True)
                    logger.info("[INDEX] Temp file cleaned up")
            logger.info(f"[INDEX] Final text length: {len(text)} chars")
            if len(text) < 50:
                logger.warning(f"[INDEX] Text too short, preview: {repr(text[:200])}")
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=20
            )
            await session.commit()
            logger.info("[INDEX] Getting embedding provider...")
            embedding_provider = await get_embedding_provider()
            logger.info(f"[INDEX] Embedding provider: {type(embedding_provider).__name__}")
            all_chunks: list[TextChunk] = []
            if parse_result and parse_result.pages:
                # Page-aware path (e.g. PDFs): chunk per page and tag each chunk
                # with its page number for citation purposes.
                logger.info(f"[INDEX] PDF with {len(parse_result.pages)} pages, using line-based chunking with page metadata")
                for page in parse_result.pages:
                    page_chunks = chunk_text_by_lines(
                        page.text,
                        min_line_length=10,
                        source=filename,
                    )
                    for pc in page_chunks:
                        pc.page = page.page
                    all_chunks.extend(page_chunks)
                logger.info(f"[INDEX] Total chunks from PDF: {len(all_chunks)}")
            else:
                logger.info("[INDEX] Using line-based chunking")
                all_chunks = chunk_text_by_lines(
                    text,
                    min_line_length=10,
                    source=filename,
                )
            logger.info(f"[INDEX] Total chunks: {len(all_chunks)}")
            qdrant = await get_qdrant_client()
            await qdrant.ensure_kb_collection_exists(tenant_id, kb_id, use_multi_vector=True)
            # The Nomic provider produces several embedding resolutions per
            # chunk, stored as named vectors; other providers use one vector.
            from app.services.embedding.nomic_provider import NomicEmbeddingProvider
            use_multi_vector = isinstance(embedding_provider, NomicEmbeddingProvider)
            logger.info(f"[INDEX] Using multi-vector format: {use_multi_vector}")
            points = []
            total_chunks = len(all_chunks)
            doc_metadata = metadata or {}
            logger.info(f"[INDEX] Document metadata: {doc_metadata}")
            for i, chunk in enumerate(all_chunks):
                # Payload stored alongside each vector for retrieval/filtering.
                payload = {
                    "text": chunk.text,
                    "source": doc_id,
                    "kb_id": kb_id,
                    "chunk_index": i,
                    "start_token": chunk.start_token,
                    "end_token": chunk.end_token,
                    "metadata": doc_metadata,
                }
                if chunk.page is not None:
                    payload["page"] = chunk.page
                if chunk.source:
                    payload["filename"] = chunk.source
                if use_multi_vector:
                    embedding_result = await embedding_provider.embed_document(chunk.text)
                    points.append({
                        "id": str(uuid.uuid4()),
                        "vector": {
                            "full": embedding_result.embedding_full,
                            "dim_256": embedding_result.embedding_256,
                            "dim_512": embedding_result.embedding_512,
                        },
                        "payload": payload,
                    })
                else:
                    embedding = await embedding_provider.embed(chunk.text)
                    points.append(
                        PointStruct(
                            id=str(uuid.uuid4()),
                            vector=embedding,
                            payload=payload,
                        )
                    )
                # Map embedding progress onto the 20–90% band of the job.
                progress = 20 + int((i + 1) / total_chunks * 70)
                # Throttle DB writes: report every 10 chunks and on the last one.
                if i % 10 == 0 or i == total_chunks - 1:
                    await kb_service.update_job_status(
                        tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=progress
                    )
                    await session.commit()
            if points:
                logger.info(f"[INDEX] Upserting {len(points)} vectors to Qdrant for kb_id={kb_id}...")
                if use_multi_vector:
                    await qdrant.upsert_multi_vector(tenant_id, points, kb_id=kb_id)
                else:
                    await qdrant.upsert_vectors(tenant_id, points, kb_id=kb_id)
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.COMPLETED.value, progress=100
            )
            await session.commit()
            logger.info(
                f"[INDEX] COMPLETED: tenant={tenant_id}, kb_id={kb_id}, "
                f"job_id={job_id}, chunks={len(all_chunks)}, text_len={len(text)}"
            )
        except Exception as e:
            import traceback
            logger.error(f"[INDEX] FAILED: {e}\n{traceback.format_exc()}")
            await session.rollback()
            # The original session may be unusable after the rollback, so mark
            # the job FAILED through a fresh session.
            async with async_session_maker() as error_session:
                kb_service = KBService(error_session)
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.FAILED.value,
                    progress=0, error_msg=str(e)
                )
                await error_session.commit()
@router.get(
    "/index/jobs/{job_id}",
    operation_id="getIndexJob",
    summary="Query index job status",
    description="[AC-ASA-02] Get indexing job status and progress.",
    responses={
        200: {"description": "Job status details"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def get_index_job(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    job_id: str,
) -> JSONResponse:
    """[AC-ASA-02] Report an indexing job's status, progress, and error (if any)."""
    logger.info(
        f"[AC-ASA-02] Getting job status: tenant={tenant_id}, job_id={job_id}"
    )
    job = await KBService(session).get_index_job(tenant_id, job_id)
    if not job:
        return JSONResponse(
            status_code=404,
            content={
                "code": "JOB_NOT_FOUND",
                "message": f"Job {job_id} not found",
            },
        )
    status_body = {
        "jobId": str(job.id),
        "docId": str(job.doc_id),
        "status": job.status,
        "progress": job.progress,
        "errorMsg": job.error_msg,
    }
    return JSONResponse(content=status_body)
@router.delete(
    "/documents/{doc_id}",
    operation_id="deleteDocument",
    summary="Delete document",
    description="[AC-ASA-08] Delete a document and its associated files.",
    responses={
        200: {"description": "Document deleted"},
        404: {"description": "Document not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def delete_document(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    doc_id: str,
) -> JSONResponse:
    """[AC-ASA-08] Remove a document; 404 when it does not exist for this tenant."""
    logger.info(
        f"[AC-ASA-08] Deleting document: tenant={tenant_id}, doc_id={doc_id}"
    )
    # NOTE(review): no session.commit() here — presumably KBService.delete_document
    # commits internally; confirm.
    removed = await KBService(session).delete_document(tenant_id, doc_id)
    if removed:
        return JSONResponse(
            content={
                "success": True,
                "message": "Document deleted",
            }
        )
    return JSONResponse(
        status_code=404,
        content={
            "code": "DOCUMENT_NOT_FOUND",
            "message": f"Document {doc_id} not found",
        },
    )
@router.put(
    "/documents/{doc_id}/metadata",
    operation_id="updateDocumentMetadata",
    summary="Update document metadata",
    description="[AC-ASA-08] Update metadata for a specific document.",
    responses={
        200: {"description": "Metadata updated"},
        404: {"description": "Document not found"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def update_document_metadata(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    doc_id: str,
    body: dict,
) -> JSONResponse:
    """
    [AC-ASA-08] Update document metadata.

    Accepts {"metadata": {...}} or {"metadata": "<json string>"}; a string
    value is parsed as JSON before being stored on the document row.
    """
    # "json", "select", and "Document" are module-level imports; no local
    # re-imports needed.
    metadata = body.get("metadata")
    # Tolerate a JSON-encoded string payload in addition to a plain object.
    if metadata is not None and not isinstance(metadata, dict):
        try:
            metadata = json.loads(metadata) if isinstance(metadata, str) else metadata
        except json.JSONDecodeError:
            return JSONResponse(
                status_code=400,
                content={
                    "code": "INVALID_METADATA",
                    "message": "Invalid JSON format for metadata",
                },
            )
    logger.info(
        f"[AC-ASA-08] Updating document metadata: tenant={tenant_id}, doc_id={doc_id}"
    )
    stmt = select(Document).where(
        Document.tenant_id == tenant_id,
        Document.id == doc_id,
    )
    result = await session.execute(stmt)
    document = result.scalar_one_or_none()
    if not document:
        return JSONResponse(
            status_code=404,
            content={
                "code": "DOCUMENT_NOT_FOUND",
                "message": f"Document {doc_id} not found",
            },
        )
    document.doc_metadata = metadata
    await session.commit()
    return JSONResponse(
        content={
            "success": True,
            "message": "Metadata updated",
            "metadata": document.doc_metadata,
        }
    )
@router.post(
    "/documents/batch-upload",
    operation_id="batchUploadDocuments",
    summary="Batch upload documents from zip",
    description="Upload a zip file containing multiple folders, each with a markdown file and metadata.json",
    responses={
        200: {"description": "Batch upload result"},
        400: {"description": "Bad Request - invalid zip or missing files"},
        401: {"description": "Unauthorized", "model": ErrorResponse},
        403: {"description": "Forbidden", "model": ErrorResponse},
    },
)
async def batch_upload_documents(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    background_tasks: BackgroundTasks,
    file: UploadFile = File(...),
    kb_id: str = Form(...),
) -> JSONResponse:
    """
    Batch upload documents from a zip file.

    Zip structure:
    - Each folder contains one .md file and one metadata.json
    - metadata.json uses field_key from MetadataFieldDefinition as keys

    Example metadata.json:
    {
        "grade": "高一",
        "subject": "数学",
        "type": "痛点"
    }

    Each valid folder yields one Document + IndexJob; indexing runs as a
    background task per document. Returns per-folder results plus
    succeeded/failed counts (200 even when some folders fail).
    """
    import json
    import tempfile
    import zipfile
    from pathlib import Path
    from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
    logger.info(
        f"[BATCH-UPLOAD] Starting batch upload: tenant={tenant_id}, "
        f"kb_id={kb_id}, filename={file.filename}"
    )
    # Only .zip archives are accepted.
    if not file.filename or not file.filename.lower().endswith('.zip'):
        return JSONResponse(
            status_code=400,
            content={
                "code": "INVALID_FORMAT",
                "message": "Only .zip files are supported",
            },
        )
    # Unlike single upload, the target KB must already exist here (no default fallback).
    kb_service = KnowledgeBaseService(session)
    kb = await kb_service.get_knowledge_base(tenant_id, kb_id)
    if not kb:
        return JSONResponse(
            status_code=404,
            content={
                "code": "KB_NOT_FOUND",
                "message": f"Knowledge base {kb_id} not found",
            },
        )
    file_content = await file.read()
    results = []
    succeeded = 0
    failed = 0
    # Extract into a temp dir that is removed automatically when done.
    with tempfile.TemporaryDirectory() as temp_dir:
        zip_path = Path(temp_dir) / "upload.zip"
        with open(zip_path, "wb") as f:
            f.write(file_content)
        try:
            with zipfile.ZipFile(zip_path, 'r') as zf:
                zf.extractall(temp_dir)
        except zipfile.BadZipFile as e:
            return JSONResponse(
                status_code=400,
                content={
                    "code": "INVALID_ZIP",
                    "message": f"Invalid zip file: {str(e)}",
                },
            )
        extracted_path = Path(temp_dir)
        # List all extracted entries, for debugging.
        all_items = list(extracted_path.iterdir())
        logger.info(f"[BATCH-UPLOAD] Extracted items: {[item.name for item in all_items]}")
        # Recursively find every folder containing a content file (.txt/.md) and metadata.json.
        def find_document_folders(path: Path) -> list[Path]:
            """Recursively collect folders that contain a document content file."""
            doc_folders = []
            # Does the current folder hold a content file?
            content_files = (
                list(path.glob("*.md")) +
                list(path.glob("*.markdown")) +
                list(path.glob("*.txt"))
            )
            if content_files:
                # This folder holds a content file, so it is a document folder.
                doc_folders.append(path)
                logger.info(f"[BATCH-UPLOAD] Found document folder: {path.name}, files: {[f.name for f in content_files]}")
            # Recurse into subfolders.
            for subfolder in [p for p in path.iterdir() if p.is_dir()]:
                doc_folders.extend(find_document_folders(subfolder))
            return doc_folders
        folders = find_document_folders(extracted_path)
        if not folders:
            logger.error(f"[BATCH-UPLOAD] No document folders found in zip. Items found: {[item.name for item in all_items]}")
            return JSONResponse(
                status_code=400,
                content={
                    "code": "NO_DOCUMENTS_FOUND",
                    "message": "压缩包中没有找到包含 .txt/.md 文件的文件夹",
                    "details": {
                        "expected_structure": "每个文件夹应包含 content.txt (或 .md) 和 metadata.json",
                        "found_items": [item.name for item in all_items],
                    },
                },
            )
        logger.info(f"[BATCH-UPLOAD] Found {len(folders)} document folders")
        # Process each folder independently; one failure does not abort the batch.
        for folder in folders:
            folder_name = folder.name if folder != extracted_path else "root"
            content_files = (
                list(folder.glob("*.md")) +
                list(folder.glob("*.markdown")) +
                list(folder.glob("*.txt"))
            )
            if not content_files:
                # Should not happen: find_document_folders already filtered for content files.
                failed += 1
                results.append({
                    "folder": folder_name,
                    "status": "failed",
                    "error": "No content file found",
                })
                continue
            # Only the first content file per folder is imported.
            content_file = content_files[0]
            metadata_file = folder / "metadata.json"
            metadata_dict = {}
            if metadata_file.exists():
                try:
                    with open(metadata_file, 'r', encoding='utf-8') as f:
                        metadata_dict = json.load(f)
                except json.JSONDecodeError as e:
                    failed += 1
                    results.append({
                        "folder": folder_name,
                        "status": "failed",
                        "error": f"Invalid metadata.json: {str(e)}",
                    })
                    continue
            else:
                logger.warning(f"[BATCH-UPLOAD] No metadata.json in folder {folder_name}, using empty metadata")
            # Same schema-driven validation as the single-document upload.
            field_def_service = MetadataFieldDefinitionService(session)
            is_valid, validation_errors = await field_def_service.validate_metadata_for_create(
                tenant_id, metadata_dict, "kb_document"
            )
            if not is_valid:
                failed += 1
                results.append({
                    "folder": folder_name,
                    "status": "failed",
                    "error": f"Metadata validation failed: {validation_errors}",
                })
                continue
            try:
                with open(content_file, 'rb') as f:
                    doc_content = f.read()
                file_ext = content_file.suffix.lower()
                if file_ext == '.txt':
                    file_type = "text/plain"
                else:
                    file_type = "text/markdown"
                doc_kb_service = KBService(session)
                document, job = await doc_kb_service.upload_document(
                    tenant_id=tenant_id,
                    kb_id=kb_id,
                    file_name=content_file.name,
                    file_content=doc_content,
                    file_type=file_type,
                    metadata=metadata_dict,
                )
                await kb_service.update_doc_count(tenant_id, kb_id, delta=1)
                # Commit per document so each one is durable before indexing starts.
                await session.commit()
                background_tasks.add_task(
                    _index_document,
                    tenant_id,
                    kb_id,
                    str(job.id),
                    str(document.id),
                    doc_content,
                    content_file.name,
                    metadata_dict,
                )
                succeeded += 1
                results.append({
                    "folder": folder_name,
                    "docId": str(document.id),
                    "jobId": str(job.id),
                    "status": "created",
                    "fileName": content_file.name,
                })
                logger.info(
                    f"[BATCH-UPLOAD] Created document: folder={folder_name}, "
                    f"doc_id={document.id}, job_id={job.id}"
                )
            except Exception as e:
                failed += 1
                results.append({
                    "folder": folder_name,
                    "status": "failed",
                    "error": str(e),
                })
                logger.error(f"[BATCH-UPLOAD] Failed to create document: folder={folder_name}, error={e}")
    logger.info(
        f"[BATCH-UPLOAD] Completed: total={len(results)}, succeeded={succeeded}, failed={failed}"
    )
    return JSONResponse(
        content={
            "success": True,
            "total": len(results),
            "succeeded": succeeded,
            "failed": failed,
            "results": results,
        }
    )
@router.post(
    "/{kb_id}/documents/json-batch",
    summary="[AC-KB-03] JSON批量上传文档",
    description="上传JSONL格式文件每行一个JSON对象包含text和元数据字段",
)
async def upload_json_batch(
    kb_id: str,
    tenant_id: str = Query(..., description="租户ID"),
    file: UploadFile = File(..., description="JSONL格式文件每行一个JSON对象"),
    session: AsyncSession = Depends(get_session),
    # NOTE(review): FastAPI injects BackgroundTasks when annotated, so the
    # None default should never be hit in practice — confirm.
    background_tasks: BackgroundTasks = None,
):
    """
    Batch-upload documents from a JSONL file (one JSON object per line).

    Required field per line: "text" — the content to add to the knowledge base.
    Optional fields: metadata keys (e.g. grade, subject, kb_scene); keys not in
    the tenant's active metadata field definitions are silently dropped.

    Example lines:
    {"text": "课程内容...", "grade": "初二", "subject": "数学", "kb_scene": "课程咨询"}
    {"text": "另一条课程内容...", "grade": "初三", "info_type": "课程概述"}
    """
    # Verify the KB exists and belongs to the caller's tenant.
    kb = await session.get(KnowledgeBase, kb_id)
    if not kb:
        raise HTTPException(status_code=404, detail="知识库不存在")
    if kb.tenant_id != tenant_id:
        raise HTTPException(status_code=403, detail="无权访问此知识库")
    # Collect the tenant's active metadata field keys, used to filter unknown keys.
    valid_field_keys = set()
    try:
        field_defs = await MetadataFieldDefinitionService(session).get_fields(
            tenant_id=tenant_id,
            include_inactive=False,
        )
        valid_field_keys = {f.field_key for f in field_defs}
        logger.info(f"[AC-KB-03] Valid metadata fields for tenant {tenant_id}: {valid_field_keys}")
    except Exception as e:
        # Best-effort: when field definitions cannot be loaded, every
        # metadata key is accepted (valid_field_keys stays empty).
        logger.warning(f"[AC-KB-03] Failed to get metadata fields: {e}")
    content = await file.read()
    # Accept UTF-8 first, then GBK; anything else is rejected with 400.
    try:
        text_content = content.decode("utf-8")
    except UnicodeDecodeError:
        try:
            text_content = content.decode("gbk")
        except UnicodeDecodeError:
            raise HTTPException(status_code=400, detail="文件编码不支持请使用UTF-8编码")
    lines = text_content.strip().split("\n")
    if not lines:
        raise HTTPException(status_code=400, detail="文件内容为空")
    results = []
    succeeded = 0
    failed = 0
    kb_service = KBService(session)
    # Process each JSONL line independently; one bad line does not abort the batch.
    for line_num, line in enumerate(lines, 1):
        line = line.strip()
        if not line:
            continue
        try:
            json_obj = json.loads(line)
        except json.JSONDecodeError as e:
            failed += 1
            results.append({
                "line": line_num,
                "success": False,
                "error": f"JSON解析失败: {e}",
            })
            continue
        text = json_obj.get("text")
        if not text:
            failed += 1
            results.append({
                "line": line_num,
                "success": False,
                "error": "缺少必填字段: text",
            })
            continue
        # Everything except "text" is treated as metadata, filtered by the
        # tenant's field definitions (when available) and stripped of nulls.
        metadata = {}
        for key, value in json_obj.items():
            if key == "text":
                continue
            if valid_field_keys and key not in valid_field_keys:
                logger.debug(f"[AC-KB-03] Skipping invalid metadata field: {key}")
                continue
            if value is not None:
                metadata[key] = value
        try:
            # Each line becomes its own small plain-text document.
            file_name = f"json_batch_line_{line_num}.txt"
            file_content = text.encode("utf-8")
            document, job = await kb_service.upload_document(
                tenant_id=tenant_id,
                kb_id=kb_id,
                file_name=file_name,
                file_content=file_content,
                file_type="text/plain",
                metadata=metadata,
            )
            if background_tasks:
                background_tasks.add_task(
                    _index_document,
                    tenant_id,
                    kb_id,
                    str(job.id),
                    str(document.id),
                    file_content,
                    file_name,
                    metadata,
                )
            succeeded += 1
            results.append({
                "line": line_num,
                "success": True,
                "doc_id": str(document.id),
                "job_id": str(job.id),
                "metadata": metadata,
            })
        except Exception as e:
            failed += 1
            results.append({
                "line": line_num,
                "success": False,
                "error": str(e),
            })
            logger.error(f"[AC-KB-03] Failed to upload document at line {line_num}: {e}")
    # Single commit for the whole batch (per-line failures were already recorded).
    await session.commit()
    # NOTE(review): "total" counts raw lines including skipped blank lines — confirm intended.
    logger.info(f"[AC-KB-03] JSON batch upload completed: kb_id={kb_id}, total={len(lines)}, succeeded={succeeded}, failed={failed}")
    return JSONResponse(
        content={
            "success": True,
            "total": len(lines),
            "succeeded": succeeded,
            "failed": failed,
            "valid_metadata_fields": list(valid_field_keys) if valid_field_keys else [],
            "results": results,
        }
    )