feat(ASA-P5): smart PDF chunking, token-based via tiktoken, preserving page-number metadata [AC-ASA-01]

MerCry 2026-02-25 01:16:59 +08:00
parent e9fee2f80e
commit 559d8c0c53
4 changed files with 154 additions and 22 deletions


@@ -6,8 +6,10 @@ Knowledge Base management endpoints.
 import logging
 import os
 import uuid
+from dataclasses import dataclass
 from typing import Annotated, Optional
+import tiktoken
 from fastapi import APIRouter, BackgroundTasks, Depends, Query, UploadFile, File, Form
 from fastapi.responses import JSONResponse
 from sqlalchemy import select
@@ -25,6 +27,59 @@ logger = logging.getLogger(__name__)
 router = APIRouter(prefix="/admin/kb", tags=["KB Management"])
 
 
+@dataclass
+class TextChunk:
+    """Text chunk with metadata."""
+
+    text: str
+    start_token: int
+    end_token: int
+    page: int | None = None
+    source: str | None = None
+
+
+def chunk_text_with_tiktoken(
+    text: str,
+    chunk_size: int = 512,
+    overlap: int = 100,
+    page: int | None = None,
+    source: str | None = None,
+) -> list[TextChunk]:
+    """
+    Chunk text by token count using tiktoken, with overlap between chunks.
+
+    Args:
+        text: The text to chunk.
+        chunk_size: Maximum number of tokens per chunk.
+        overlap: Number of tokens shared between adjacent chunks.
+        page: Page number (optional).
+        source: Source file path (optional).
+
+    Returns:
+        List of chunks, each with its text and start/end token positions.
+    """
+    encoding = tiktoken.get_encoding("cl100k_base")
+    tokens = encoding.encode(text)
+    chunks: list[TextChunk] = []
+    start = 0
+    while start < len(tokens):
+        end = min(start + chunk_size, len(tokens))
+        chunk_tokens = tokens[start:end]
+        chunk_text = encoding.decode(chunk_tokens)
+        chunks.append(TextChunk(
+            text=chunk_text,
+            start_token=start,
+            end_token=end,
+            page=page,
+            source=source,
+        ))
+        if end == len(tokens):
+            break
+        start += chunk_size - overlap
+    return chunks
+
+
 def get_current_tenant_id() -> str:
     """Dependency to get current tenant ID or raise exception."""
     tenant_id = get_tenant_id()
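
A usage sketch of the helper above (the sample text and filename are made up, and the function is assumed to be in scope):

sample = "Retrieval-augmented generation pairs a vector store with an LLM. " * 100
chunks = chunk_text_with_tiktoken(sample, chunk_size=512, overlap=100,
                                  page=1, source="manual.pdf")

# start_token advances by 412 (= 512 - 100) per chunk, so each chunk
# repeats the last 100 tokens of its predecessor as boundary context.
for c in chunks[:3]:
    print(c.page, c.start_token, c.end_token, len(c.text))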
@@ -238,12 +293,13 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
     from app.services.kb import KBService
     from app.core.qdrant_client import get_qdrant_client
     from app.services.embedding import get_embedding_provider
-    from app.services.document import parse_document, UnsupportedFormatError, DocumentParseException
+    from app.services.document import parse_document, UnsupportedFormatError, DocumentParseException, PageText
     from qdrant_client.models import PointStruct
     import asyncio
     import tempfile
     from pathlib import Path
 
+    logger.info(f"[INDEX] Starting indexing: tenant={tenant_id}, job_id={job_id}, doc_id={doc_id}, filename={filename}")
+
     await asyncio.sleep(1)
 
     async with async_session_maker() as session:
@@ -254,14 +310,18 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
             )
             await session.commit()
 
+            parse_result = None
             text = None
             file_ext = Path(filename or "").suffix.lower()
+            logger.info(f"[INDEX] File extension: {file_ext}, content size: {len(content)} bytes")
+
             text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}
             if file_ext in text_extensions or not file_ext:
+                logger.info(f"[INDEX] Treating as text file, decoding with UTF-8")
                 text = content.decode("utf-8", errors="ignore")
             else:
+                logger.info(f"[INDEX] Binary file detected, will parse with document parser")
                 await kb_service.update_job_status(
                     tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
                 )
@@ -271,45 +331,95 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
                     tmp_file.write(content)
                     tmp_path = tmp_file.name
 
+                logger.info(f"[INDEX] Temp file created: {tmp_path}")
+
                 try:
+                    logger.info(f"[INDEX] Starting document parsing for {file_ext}...")
                     parse_result = parse_document(tmp_path)
                     text = parse_result.text
                     logger.info(
-                        f"[AC-AISVC-33] Parsed document: {filename}, "
-                        f"chars={len(text)}, format={parse_result.metadata.get('format')}"
+                        f"[INDEX] Parsed document SUCCESS: {filename}, "
+                        f"chars={len(text)}, format={parse_result.metadata.get('format')}, "
+                        f"pages={len(parse_result.pages) if parse_result.pages else 'N/A'}, "
+                        f"metadata={parse_result.metadata}"
                     )
-                except (UnsupportedFormatError, DocumentParseException) as e:
-                    logger.warning(f"Failed to parse document {filename}: {e}, falling back to text decode")
+                    if len(text) < 100:
+                        logger.warning(f"[INDEX] Parsed text is very short, preview: {text[:200]}")
+                except UnsupportedFormatError as e:
+                    logger.error(f"[INDEX] UnsupportedFormatError: {e}")
+                    text = content.decode("utf-8", errors="ignore")
+                except DocumentParseException as e:
+                    logger.error(f"[INDEX] DocumentParseException: {e}, details={getattr(e, 'details', {})}")
+                    text = content.decode("utf-8", errors="ignore")
+                except Exception as e:
+                    logger.error(f"[INDEX] Unexpected parsing error: {type(e).__name__}: {e}")
                     text = content.decode("utf-8", errors="ignore")
                 finally:
                     Path(tmp_path).unlink(missing_ok=True)
+                    logger.info(f"[INDEX] Temp file cleaned up")
+
+            logger.info(f"[INDEX] Final text length: {len(text)} chars")
+            if len(text) < 50:
+                logger.warning(f"[INDEX] Text too short, preview: {repr(text[:200])}")
 
             await kb_service.update_job_status(
                 tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=20
             )
             await session.commit()
 
+            logger.info(f"[INDEX] Getting embedding provider...")
             embedding_provider = await get_embedding_provider()
+            logger.info(f"[INDEX] Embedding provider: {type(embedding_provider).__name__}")
 
-            chunks = [text[i:i+500] for i in range(0, len(text), 500)]
+            all_chunks: list[TextChunk] = []
+            if parse_result and parse_result.pages:
+                logger.info(f"[INDEX] PDF with {len(parse_result.pages)} pages, using tiktoken chunking with page metadata")
+                for page in parse_result.pages:
+                    page_chunks = chunk_text_with_tiktoken(
+                        page.text,
+                        chunk_size=512,
+                        overlap=100,
+                        page=page.page,
+                        source=filename,
+                    )
+                    all_chunks.extend(page_chunks)
+                logger.info(f"[INDEX] Total chunks from PDF: {len(all_chunks)}")
+            else:
+                logger.info(f"[INDEX] Using tiktoken chunking without page metadata")
+                all_chunks = chunk_text_with_tiktoken(
+                    text,
+                    chunk_size=512,
+                    overlap=100,
+                    source=filename,
+                )
+                logger.info(f"[INDEX] Total chunks: {len(all_chunks)}")
 
             qdrant = await get_qdrant_client()
             await qdrant.ensure_collection_exists(tenant_id)
 
             points = []
-            total_chunks = len(chunks)
-            for i, chunk in enumerate(chunks):
-                embedding = await embedding_provider.embed(chunk)
+            total_chunks = len(all_chunks)
+            for i, chunk in enumerate(all_chunks):
+                embedding = await embedding_provider.embed(chunk.text)
+                payload = {
+                    "text": chunk.text,
+                    "source": doc_id,
+                    "chunk_index": i,
+                    "start_token": chunk.start_token,
+                    "end_token": chunk.end_token,
+                }
+                if chunk.page is not None:
+                    payload["page"] = chunk.page
+                if chunk.source:
+                    payload["filename"] = chunk.source
                 points.append(
                     PointStruct(
                         id=str(uuid.uuid4()),
                         vector=embedding,
-                        payload={
-                            "text": chunk,
-                            "source": doc_id,
-                            "chunk_index": i,
-                        },
+                        payload=payload,
                     )
                 )
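
The extra payload fields are what make page-scoped retrieval possible later. A minimal sketch of a filtered query with the qdrant-client API; the endpoint, collection name, vector size, and filename below are placeholders, not values from this commit:

from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue

client = QdrantClient(url="http://localhost:6333")  # placeholder endpoint

# Restrict search to chunks taken from page 3 of a given upload.
hits = client.search(
    collection_name="tenant-demo",  # placeholder collection
    query_vector=[0.0] * 1536,      # placeholder embedding vector
    query_filter=Filter(
        must=[
            FieldCondition(key="page", match=MatchValue(value=3)),
            FieldCondition(key="filename", match=MatchValue(value="manual.pdf")),
        ]
    ),
    limit=5,
)
for hit in hits:
    print(hit.payload["chunk_index"], hit.payload["text"][:60])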
@@ -321,6 +431,7 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
             await session.commit()
 
             if points:
+                logger.info(f"[INDEX] Upserting {len(points)} vectors to Qdrant...")
                 await qdrant.upsert_vectors(tenant_id, points)
 
             await kb_service.update_job_status(
@@ -329,12 +440,13 @@ async def _index_document(tenant_id: str, job_id: str, doc_id: str, content: byt
             await session.commit()
 
             logger.info(
-                f"[AC-ASA-01] Indexing completed: tenant={tenant_id}, "
-                f"job_id={job_id}, chunks={len(chunks)}"
+                f"[INDEX] COMPLETED: tenant={tenant_id}, "
+                f"job_id={job_id}, chunks={len(all_chunks)}, text_len={len(text)}"
             )
         except Exception as e:
-            logger.error(f"[AC-ASA-01] Indexing failed: {e}")
+            import traceback
+            logger.error(f"[INDEX] FAILED: {e}\n{traceback.format_exc()}")
             await session.rollback()
             async with async_session_maker() as error_session:
                 kb_service = KBService(error_session)


@@ -6,6 +6,7 @@ Document parsing services package.
 from app.services.document.base import (
     DocumentParseException,
     DocumentParser,
+    PageText,
     ParseResult,
     UnsupportedFormatError,
 )
@@ -22,6 +23,7 @@ from app.services.document.word_parser import WordParser
 __all__ = [
     "DocumentParseException",
     "DocumentParser",
+    "PageText",
     "ParseResult",
     "UnsupportedFormatError",
     "DocumentParserFactory",


@@ -13,6 +13,15 @@ from pathlib import Path
 from typing import Any
 
 
+@dataclass
+class PageText:
+    """
+    Text content from a single page.
+    """
+
+    page: int
+    text: str
+
+
 @dataclass
 class ParseResult:
     """
@@ -24,6 +33,7 @@ class ParseResult:
     file_size: int
     page_count: int | None = None
     metadata: dict[str, Any] = field(default_factory=dict)
+    pages: list[PageText] = field(default_factory=list)
 
 
 class DocumentParser(ABC):
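
To show the shape of the new contract, a hand-built result as a parser might now return it (values invented; the field list is limited to what this diff shows, so other ParseResult fields may exist):

result = ParseResult(
    text="[Page 1]\nIntro\n\n[Page 2]\nDetails",
    file_size=2048,
    page_count=2,
    metadata={"format": "pdf"},
    pages=[
        PageText(page=1, text="Intro"),
        PageText(page=2, text="Details"),
    ],
)
# pages preserves the per-page split that the flattened `text` field loses,
# which is what lets the indexer attach a page number to each chunk.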


@@ -12,6 +12,7 @@ from typing import Any
 from app.services.document.base import (
     DocumentParseException,
     DocumentParser,
+    PageText,
     ParseResult,
 )
@@ -68,13 +69,15 @@ class PDFParser(DocumentParser):
         try:
             doc = fitz.open(path)
 
+            pages: list[PageText] = []
             text_parts = []
             page_count = len(doc)
 
             for page_num in range(page_count):
                 page = doc[page_num]
-                text = page.get_text()
-                if text.strip():
+                text = page.get_text().strip()
+                if text:
+                    pages.append(PageText(page=page_num + 1, text=text))
                     text_parts.append(f"[Page {page_num + 1}]\n{text}")
 
             doc.close()
@@ -95,7 +98,8 @@ class PDFParser(DocumentParser):
                 metadata={
                     "format": "pdf",
                     "page_count": page_count,
-                }
+                },
+                pages=pages,
             )
 
         except DocumentParseException:
@@ -156,6 +160,7 @@ class PDFPlumberParser(DocumentParser):
         pdfplumber = self._get_pdfplumber()
 
         try:
+            pages: list[PageText] = []
             text_parts = []
             page_count = 0
@@ -171,7 +176,9 @@ class PDFPlumberParser(DocumentParser):
                         table_text = self._format_table(table)
                         text += f"\n\n{table_text}"
 
-                    if text.strip():
+                    text = text.strip()
+                    if text:
+                        pages.append(PageText(page=page_num + 1, text=text))
                         text_parts.append(f"[Page {page_num + 1}]\n{text}")
 
             full_text = "\n\n".join(text_parts)
@@ -191,7 +198,8 @@ class PDFPlumberParser(DocumentParser):
                     "format": "pdf",
                     "parser": "pdfplumber",
                     "page_count": page_count,
-                }
+                },
+                pages=pages,
             )
 
         except DocumentParseException: