"""
恢复处理中断的索引任务

用于服务重启后继续处理pending/processing状态的任务
"""
|
||
|
||
import asyncio
|
||
import logging
|
||
import sys
|
||
from pathlib import Path
|
||
|
||
# 添加项目根目录到路径
|
||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||
|
||
from sqlalchemy import select
|
||
from app.core.database import async_session_maker
|
||
from app.models.entities import IndexJob, Document, IndexJobStatus, DocumentStatus
|
||
from app.api.admin.kb import _index_document
|
||
|
||
# Configure root logging once for this script: timestamped, named, leveled lines.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Module-level logger; inherits the handler/format configured above.
logger = logging.getLogger(__name__)
|
||
async def resume_pending_jobs():
    """Resume every index job stuck in PENDING or PROCESSING state.

    Intended to run once after a service restart: for each unfinished job it
    re-reads the document's source file from disk, resets the job to PENDING,
    and re-runs the indexing pipeline synchronously via ``process_job``.
    Jobs whose document row or source file is missing are marked FAILED.
    """
    async with async_session_maker() as session:
        # Find every job that never reached a terminal state.
        result = await session.execute(
            select(IndexJob).where(
                IndexJob.status.in_([IndexJobStatus.PENDING.value, IndexJobStatus.PROCESSING.value])
            )
        )
        pending_jobs = result.scalars().all()

        if not pending_jobs:
            logger.info("没有需要恢复的任务")
            return

        logger.info(f"发现 {len(pending_jobs)} 个未完成的任务")

        for job in pending_jobs:
            # BUGFIX: reset per iteration. Previously `doc` was only bound
            # inside the try-block, so an exception raised before the query
            # completed made the except-handler hit an unbound name (first
            # iteration) or a stale document from the previous job — marking
            # the wrong document as FAILED.
            doc = None
            try:
                # Load the document this job refers to.
                doc_result = await session.execute(
                    select(Document).where(Document.id == job.doc_id)
                )
                doc = doc_result.scalar_one_or_none()

                if not doc:
                    logger.error(f"找不到文档: {job.doc_id}")
                    continue

                if not doc.file_path or not Path(doc.file_path).exists():
                    logger.error(f"文档文件不存在: {doc.file_path}")
                    # Source file is gone — nothing to index; fail both rows.
                    job.status = IndexJobStatus.FAILED.value
                    job.error_msg = "文档文件不存在"
                    doc.status = DocumentStatus.FAILED.value
                    doc.error_msg = "文档文件不存在"
                    await session.commit()
                    continue

                logger.info(f"恢复处理: job_id={job.id}, doc_id={doc.id}, file={doc.file_name}")

                # Re-read the original bytes from disk.
                with open(doc.file_path, 'rb') as f:
                    file_content = f.read()

                # Reset the job so progress reporting starts from scratch.
                job.status = IndexJobStatus.PENDING.value
                job.progress = 0
                job.error_msg = None
                await session.commit()

                # Process synchronously — no FastAPI background_tasks here.
                await process_job(
                    tenant_id=job.tenant_id,
                    kb_id=doc.kb_id,
                    job_id=str(job.id),
                    doc_id=str(doc.id),
                    file_content=file_content,
                    filename=doc.file_name,
                    metadata=doc.doc_metadata or {}
                )

                logger.info(f"任务处理完成: job_id={job.id}")

            except Exception as e:
                logger.error(f"处理任务失败: job_id={job.id}, error={e}")
                # Mark the job (and the document, when it was loaded) failed.
                job.status = IndexJobStatus.FAILED.value
                job.error_msg = str(e)
                if doc:
                    doc.status = DocumentStatus.FAILED.value
                    doc.error_msg = str(e)
                await session.commit()

        logger.info("所有任务处理完成")
async def process_job(tenant_id: str, kb_id: str, job_id: str, doc_id: str,
                      file_content: bytes, filename: str, metadata: dict):
    """
    处理单个索引任务
    复制自 _index_document 函数

    Pipeline: decode/parse the raw bytes -> chunk the text -> embed each
    chunk -> upsert the vectors into Qdrant, updating job progress along the
    way. On any failure the job is marked FAILED in a fresh session.
    """
    import tempfile
    from pathlib import Path

    from qdrant_client.models import PointStruct

    from app.core.qdrant_client import get_qdrant_client
    from app.services.document import DocumentParseException, UnsupportedFormatError, parse_document
    from app.services.embedding import get_embedding_provider
    from app.services.kb import KBService
    from app.api.admin.kb import chunk_text_by_lines, TextChunk

    logger.info(f"[RESUME] Starting indexing: tenant={tenant_id}, kb_id={kb_id}, job_id={job_id}, doc_id={doc_id}")

    async with async_session_maker() as session:
        kb_service = KBService(session)
        try:
            # Mark the job as actively processing.
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=10
            )
            await session.commit()

            parse_result = None
            text = None
            file_ext = Path(filename or "").suffix.lower()
            logger.info(f"[RESUME] File extension: {file_ext}, content size: {len(file_content)} bytes")

            text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}

            if file_ext in text_extensions or not file_ext:
                # Text path: try a list of encodings until one succeeds.
                logger.info("[RESUME] Treating as text file")
                text = None
                for encoding in ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]:
                    try:
                        text = file_content.decode(encoding)
                        logger.info(f"[RESUME] Successfully decoded with encoding: {encoding}")
                        break
                    except (UnicodeDecodeError, LookupError):
                        continue

                if text is None:
                    # Last resort: lossy decode rather than failing the job.
                    text = file_content.decode("utf-8", errors="replace")
            else:
                # Binary path: hand the bytes to the document parser via a temp file.
                logger.info("[RESUME] Binary file detected, will parse with document parser")
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
                )
                await session.commit()

                with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
                    tmp_file.write(file_content)
                    tmp_path = tmp_file.name

                logger.info(f"[RESUME] Temp file created: {tmp_path}")

                try:
                    logger.info(f"[RESUME] Starting document parsing for {file_ext}...")
                    parse_result = parse_document(tmp_path)
                    text = parse_result.text
                    logger.info(
                        f"[RESUME] Parsed document SUCCESS: (unknown), chars={len(text)}"
                    )
                except UnsupportedFormatError as e:
                    logger.error(f"[RESUME] UnsupportedFormatError: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                except DocumentParseException as e:
                    logger.error(f"[RESUME] DocumentParseException: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                except Exception as e:
                    logger.error(f"[RESUME] Unexpected parsing error: {type(e).__name__}: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                finally:
                    # Always remove the temp file, even when parsing failed.
                    Path(tmp_path).unlink(missing_ok=True)

            logger.info(f"[RESUME] Final text length: {len(text)} chars")

            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=20
            )
            await session.commit()

            logger.info("[RESUME] Getting embedding provider...")
            embedding_provider = await get_embedding_provider()
            logger.info(f"[RESUME] Embedding provider: {type(embedding_provider).__name__}")

            all_chunks: list[TextChunk] = []

            if parse_result and parse_result.pages:
                # Page-aware chunking: tag each chunk with its page number.
                logger.info(f"[RESUME] PDF with {len(parse_result.pages)} pages")
                for page in parse_result.pages:
                    page_chunks = chunk_text_by_lines(
                        page.text,
                        min_line_length=10,
                        source=filename,
                    )
                    for page_chunk in page_chunks:
                        page_chunk.page = page.page
                    all_chunks.extend(page_chunks)
            else:
                logger.info("[RESUME] Using line-based chunking")
                all_chunks = chunk_text_by_lines(
                    text,
                    min_line_length=10,
                    source=filename,
                )

            logger.info(f"[RESUME] Total chunks: {len(all_chunks)}")

            qdrant = await get_qdrant_client()
            await qdrant.ensure_kb_collection_exists(tenant_id, kb_id, use_multi_vector=True)

            # The Nomic provider emits multiple vector resolutions per chunk.
            from app.services.embedding.nomic_provider import NomicEmbeddingProvider
            use_multi_vector = isinstance(embedding_provider, NomicEmbeddingProvider)
            logger.info(f"[RESUME] Using multi-vector format: {use_multi_vector}")

            import uuid
            points = []
            total_chunks = len(all_chunks)
            doc_metadata = metadata or {}

            for idx, chunk in enumerate(all_chunks):
                payload = {
                    "text": chunk.text,
                    "source": doc_id,
                    "kb_id": kb_id,
                    "chunk_index": idx,
                    "start_token": chunk.start_token,
                    "end_token": chunk.end_token,
                    "metadata": doc_metadata,
                }
                if chunk.page is not None:
                    payload["page"] = chunk.page
                if chunk.source:
                    payload["filename"] = chunk.source

                if use_multi_vector:
                    embedding_result = await embedding_provider.embed_document(chunk.text)
                    points.append({
                        "id": str(uuid.uuid4()),
                        "vector": {
                            "full": embedding_result.embedding_full,
                            "dim_256": embedding_result.embedding_256,
                            "dim_512": embedding_result.embedding_512,
                        },
                        "payload": payload,
                    })
                else:
                    embedding = await embedding_provider.embed(chunk.text)
                    points.append(
                        PointStruct(
                            id=str(uuid.uuid4()),
                            vector=embedding,
                            payload=payload,
                        )
                    )

                # Progress climbs from 20 to 90 across the embedding loop;
                # only persist every tenth chunk (and the last) to limit DB writes.
                progress = 20 + int((idx + 1) / total_chunks * 70)
                if idx % 10 == 0 or idx == total_chunks - 1:
                    await kb_service.update_job_status(
                        tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=progress
                    )
                    await session.commit()

            if points:
                logger.info(f"[RESUME] Upserting {len(points)} vectors to Qdrant...")
                if use_multi_vector:
                    await qdrant.upsert_multi_vector(tenant_id, points, kb_id=kb_id)
                else:
                    await qdrant.upsert_vectors(tenant_id, points, kb_id=kb_id)

            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.COMPLETED.value, progress=100
            )
            await session.commit()

            logger.info(
                f"[RESUME] COMPLETED: tenant={tenant_id}, kb_id={kb_id}, "
                f"job_id={job_id}, chunks={len(all_chunks)}"
            )

        except Exception as e:
            import traceback
            logger.error(f"[RESUME] FAILED: {e}\n{traceback.format_exc()}")
            await session.rollback()
            # Record the failure in a fresh session — the rolled-back one may
            # no longer be usable for writes.
            async with async_session_maker() as error_session:
                kb_service = KBService(error_session)
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.FAILED.value,
                    progress=0, error_msg=str(e)
                )
                await error_session.commit()
if __name__ == "__main__":
    # Script entry point: replay every unfinished job, then exit.
    logger.info("开始恢复索引任务...")
    asyncio.run(resume_pending_jobs())
    logger.info("恢复脚本执行完成")