# ai-robot-core/ai-service/scripts/resume_index_jobs.py

"""Resume interrupted index jobs.

Run after a service restart to continue processing jobs left in the
pending/processing state.
"""
import asyncio
import logging
import sys
from pathlib import Path

# Make the project root importable when this file is executed as a script.
sys.path.insert(0, str(Path(__file__).parent.parent))

from sqlalchemy import select
from app.core.database import async_session_maker
from app.models.entities import IndexJob, Document, IndexJobStatus, DocumentStatus
from app.api.admin.kb import _index_document

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
async def resume_pending_jobs() -> None:
    """Resume every index job left in PENDING or PROCESSING state.

    Intended to run once after a service restart: for each unfinished job it
    re-reads the document's source file from disk, resets the job to PENDING,
    and re-runs the indexing pipeline synchronously via ``process_job``.
    Jobs whose document or source file is missing are marked FAILED.
    """
    async with async_session_maker() as session:
        # Fetch every job that has not reached a terminal state.
        result = await session.execute(
            select(IndexJob).where(
                IndexJob.status.in_([IndexJobStatus.PENDING.value, IndexJobStatus.PROCESSING.value])
            )
        )
        pending_jobs = result.scalars().all()
        if not pending_jobs:
            logger.info("没有需要恢复的任务")
            return
        logger.info(f"发现 {len(pending_jobs)} 个未完成的任务")
        for job in pending_jobs:
            # Initialize before the try block so the except handler can
            # safely reference `doc` even when the failure happens before
            # the document lookup completes (previously a NameError risk).
            doc = None
            try:
                # Load the document this job belongs to.
                doc_result = await session.execute(
                    select(Document).where(Document.id == job.doc_id)
                )
                doc = doc_result.scalar_one_or_none()
                if not doc:
                    logger.error(f"找不到文档: {job.doc_id}")
                    continue
                if not doc.file_path or not Path(doc.file_path).exists():
                    logger.error(f"文档文件不存在: {doc.file_path}")
                    # Source file is gone: fail both the job and the document.
                    job.status = IndexJobStatus.FAILED.value
                    job.error_msg = "文档文件不存在"
                    doc.status = DocumentStatus.FAILED.value
                    doc.error_msg = "文档文件不存在"
                    await session.commit()
                    continue
                logger.info(f"恢复处理: job_id={job.id}, doc_id={doc.id}, file={doc.file_name}")
                # Re-read the original file content from disk.
                with open(doc.file_path, 'rb') as f:
                    file_content = f.read()
                # Reset the job to a clean PENDING state before reprocessing.
                job.status = IndexJobStatus.PENDING.value
                job.progress = 0
                job.error_msg = None
                await session.commit()
                # Run the pipeline inline — no FastAPI BackgroundTasks here.
                await process_job(
                    tenant_id=job.tenant_id,
                    kb_id=doc.kb_id,
                    job_id=str(job.id),
                    doc_id=str(doc.id),
                    file_content=file_content,
                    filename=doc.file_name,
                    metadata=doc.doc_metadata or {}
                )
                logger.info(f"任务处理完成: job_id={job.id}")
            except Exception as e:
                logger.error(f"处理任务失败: job_id={job.id}, error={e}")
                # Discard any failed transaction state so the session is
                # usable again for the failure bookkeeping below.
                await session.rollback()
                job.status = IndexJobStatus.FAILED.value
                job.error_msg = str(e)
                if doc:
                    doc.status = DocumentStatus.FAILED.value
                    doc.error_msg = str(e)
                await session.commit()
        logger.info("所有任务处理完成")
async def process_job(tenant_id: str, kb_id: str, job_id: str, doc_id: str,
                      file_content: bytes, filename: str, metadata: dict) -> None:
    """Process a single index job end-to-end: parse -> chunk -> embed -> upsert.

    Copied from the ``_index_document`` function so the resume script can run
    the same pipeline without FastAPI's BackgroundTasks.

    Args:
        tenant_id: Tenant that owns the job.
        kb_id: Knowledge-base id the document belongs to.
        job_id: IndexJob primary key (string form).
        doc_id: Document primary key (string form); stored as ``source`` on
            every chunk payload.
        file_content: Raw bytes of the source file.
        filename: Original file name; its extension selects the parsing path.
        metadata: Document-level metadata copied into every chunk payload.
    """
    # Local imports keep heavy project dependencies out of module import time.
    import tempfile
    from pathlib import Path
    from qdrant_client.models import PointStruct
    from app.core.qdrant_client import get_qdrant_client
    from app.services.document import DocumentParseException, UnsupportedFormatError, parse_document
    from app.services.embedding import get_embedding_provider
    from app.services.kb import KBService
    from app.api.admin.kb import chunk_text_by_lines, TextChunk

    logger.info(f"[RESUME] Starting indexing: tenant={tenant_id}, kb_id={kb_id}, job_id={job_id}, doc_id={doc_id}")
    async with async_session_maker() as session:
        kb_service = KBService(session)
        try:
            # Mark the job as PROCESSING; commits are interleaved so progress
            # is visible to other sessions while indexing runs.
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=10
            )
            await session.commit()
            parse_result = None
            text = None
            file_ext = Path(filename or "").suffix.lower()
            logger.info(f"[RESUME] File extension: {file_ext}, content size: {len(file_content)} bytes")
            text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}
            if file_ext in text_extensions or not file_ext:
                # Plain-text path: try common encodings in order, first hit wins.
                logger.info("[RESUME] Treating as text file")
                text = None
                for encoding in ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]:
                    try:
                        text = file_content.decode(encoding)
                        logger.info(f"[RESUME] Successfully decoded with encoding: {encoding}")
                        break
                    except (UnicodeDecodeError, LookupError):
                        continue
                if text is None:
                    # Last resort: lossy decode so indexing still proceeds.
                    text = file_content.decode("utf-8", errors="replace")
            else:
                # Binary path: write to a temp file and run the document parser.
                logger.info("[RESUME] Binary file detected, will parse with document parser")
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
                )
                await session.commit()
                # delete=False because the parser opens the file by path;
                # cleanup is handled in the finally below.
                with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
                    tmp_file.write(file_content)
                    tmp_path = tmp_file.name
                logger.info(f"[RESUME] Temp file created: {tmp_path}")
                try:
                    logger.info(f"[RESUME] Starting document parsing for {file_ext}...")
                    parse_result = parse_document(tmp_path)
                    text = parse_result.text
                    logger.info(
                        f"[RESUME] Parsed document SUCCESS: (unknown), chars={len(text)}"
                    )
                except UnsupportedFormatError as e:
                    # Parsing failures fall back to a lossy text decode rather
                    # than failing the whole job.
                    logger.error(f"[RESUME] UnsupportedFormatError: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                except DocumentParseException as e:
                    logger.error(f"[RESUME] DocumentParseException: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                except Exception as e:
                    logger.error(f"[RESUME] Unexpected parsing error: {type(e).__name__}: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                finally:
                    # Always remove the temp file created above.
                    Path(tmp_path).unlink(missing_ok=True)
            logger.info(f"[RESUME] Final text length: {len(text)} chars")
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=20
            )
            await session.commit()
            logger.info("[RESUME] Getting embedding provider...")
            embedding_provider = await get_embedding_provider()
            logger.info(f"[RESUME] Embedding provider: {type(embedding_provider).__name__}")
            all_chunks: list[TextChunk] = []
            if parse_result and parse_result.pages:
                # Paginated parse result (e.g. PDF): chunk each page separately
                # and tag every chunk with its page number.
                logger.info(f"[RESUME] PDF with {len(parse_result.pages)} pages")
                for page in parse_result.pages:
                    page_chunks = chunk_text_by_lines(
                        page.text,
                        min_line_length=10,
                        source=filename,
                    )
                    for pc in page_chunks:
                        pc.page = page.page
                    all_chunks.extend(page_chunks)
            else:
                # Flat text: chunk the whole document at once.
                logger.info("[RESUME] Using line-based chunking")
                all_chunks = chunk_text_by_lines(
                    text,
                    min_line_length=10,
                    source=filename,
                )
            logger.info(f"[RESUME] Total chunks: {len(all_chunks)}")
            qdrant = await get_qdrant_client()
            await qdrant.ensure_kb_collection_exists(tenant_id, kb_id, use_multi_vector=True)
            # Multi-vector (named-vector) point format is only used with the
            # Nomic provider, which returns embeddings at several dimensions.
            from app.services.embedding.nomic_provider import NomicEmbeddingProvider
            use_multi_vector = isinstance(embedding_provider, NomicEmbeddingProvider)
            logger.info(f"[RESUME] Using multi-vector format: {use_multi_vector}")
            import uuid
            points = []
            total_chunks = len(all_chunks)
            doc_metadata = metadata or {}
            for i, chunk in enumerate(all_chunks):
                payload = {
                    "text": chunk.text,
                    "source": doc_id,
                    "kb_id": kb_id,
                    "chunk_index": i,
                    "start_token": chunk.start_token,
                    "end_token": chunk.end_token,
                    "metadata": doc_metadata,
                }
                if chunk.page is not None:
                    payload["page"] = chunk.page
                if chunk.source:
                    payload["filename"] = chunk.source
                if use_multi_vector:
                    # Dict-form point with named vectors at multiple dims.
                    embedding_result = await embedding_provider.embed_document(chunk.text)
                    points.append({
                        "id": str(uuid.uuid4()),
                        "vector": {
                            "full": embedding_result.embedding_full,
                            "dim_256": embedding_result.embedding_256,
                            "dim_512": embedding_result.embedding_512,
                        },
                        "payload": payload,
                    })
                else:
                    # Single-vector PointStruct for conventional providers.
                    embedding = await embedding_provider.embed(chunk.text)
                    points.append(
                        PointStruct(
                            id=str(uuid.uuid4()),
                            vector=embedding,
                            payload=payload,
                        )
                    )
                # Embedding spans progress 20..90; report every 10 chunks and
                # on the final chunk.
                progress = 20 + int((i + 1) / total_chunks * 70)
                if i % 10 == 0 or i == total_chunks - 1:
                    await kb_service.update_job_status(
                        tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=progress
                    )
                    await session.commit()
            if points:
                logger.info(f"[RESUME] Upserting {len(points)} vectors to Qdrant...")
                if use_multi_vector:
                    await qdrant.upsert_multi_vector(tenant_id, points, kb_id=kb_id)
                else:
                    await qdrant.upsert_vectors(tenant_id, points, kb_id=kb_id)
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.COMPLETED.value, progress=100
            )
            await session.commit()
            logger.info(
                f"[RESUME] COMPLETED: tenant={tenant_id}, kb_id={kb_id}, "
                f"job_id={job_id}, chunks={len(all_chunks)}"
            )
        except Exception as e:
            import traceback
            logger.error(f"[RESUME] FAILED: {e}\n{traceback.format_exc()}")
            await session.rollback()
            # The main session may be unusable after the failure, so record
            # the FAILED status through a fresh session.
            async with async_session_maker() as error_session:
                kb_service = KBService(error_session)
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.FAILED.value,
                    progress=0, error_msg=str(e)
                )
                await error_session.commit()
if __name__ == "__main__":
    # Script entry point: run one resume pass over unfinished jobs, then exit.
    logger.info("开始恢复索引任务...")
    resume_run = resume_pending_jobs()
    asyncio.run(resume_run)
    logger.info("恢复脚本执行完成")