# ai-robot-core/ai-service/scripts/resume_index_jobs.py
"""
恢复处理中断的索引任务
用于服务重启后继续处理pending/processing状态的任务
"""
import asyncio
import logging
import sys
from pathlib import Path
# 添加项目根目录到路径
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import select
from app.core.database import async_session_maker
from app.models.entities import IndexJob, Document, IndexJobStatus, DocumentStatus
from app.api.admin.kb import _index_document
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)
async def resume_pending_jobs():
    """Resume every index job stuck in PENDING or PROCESSING state.

    For each unfinished job: reload the source file from disk, reset the job
    to PENDING, and re-run the indexing pipeline synchronously via
    ``process_job``.  Jobs whose Document row or source file is missing are
    marked FAILED.  Any unexpected error marks the job (and its document,
    when loaded) FAILED as well instead of aborting the whole batch.
    """
    async with async_session_maker() as session:
        # Fetch every job that never reached a terminal state.
        result = await session.execute(
            select(IndexJob).where(
                IndexJob.status.in_(
                    [IndexJobStatus.PENDING.value, IndexJobStatus.PROCESSING.value]
                )
            )
        )
        pending_jobs = result.scalars().all()
        if not pending_jobs:
            logger.info("没有需要恢复的任务")
            return
        logger.info(f"发现 {len(pending_jobs)} 个未完成的任务")
        for job in pending_jobs:
            # Initialize before the try-block: the except handler below reads
            # `doc`, and without this a failure before the document lookup
            # would raise NameError and mask the original error.
            doc = None
            try:
                # Load the document the job points at.
                doc_result = await session.execute(
                    select(Document).where(Document.id == job.doc_id)
                )
                doc = doc_result.scalar_one_or_none()
                if not doc:
                    logger.error(f"找不到文档: {job.doc_id}")
                    continue
                if not doc.file_path or not Path(doc.file_path).exists():
                    logger.error(f"文档文件不存在: {doc.file_path}")
                    # Source file is gone: mark both job and document FAILED.
                    job.status = IndexJobStatus.FAILED.value
                    job.error_msg = "文档文件不存在"
                    doc.status = DocumentStatus.FAILED.value
                    doc.error_msg = "文档文件不存在"
                    await session.commit()
                    continue
                logger.info(f"恢复处理: job_id={job.id}, doc_id={doc.id}, file={doc.file_name}")
                # Re-read the original file content from disk.
                with open(doc.file_path, 'rb') as f:
                    file_content = f.read()
                # Reset the job so progress reporting starts from scratch.
                job.status = IndexJobStatus.PENDING.value
                job.progress = 0
                job.error_msg = None
                await session.commit()
                # Process synchronously (no BackgroundTasks in a script context).
                await process_job(
                    tenant_id=job.tenant_id,
                    kb_id=doc.kb_id,
                    job_id=str(job.id),
                    doc_id=str(doc.id),
                    file_content=file_content,
                    filename=doc.file_name,
                    metadata=doc.doc_metadata or {}
                )
                logger.info(f"任务处理完成: job_id={job.id}")
            except Exception as e:
                logger.error(f"处理任务失败: job_id={job.id}, error={e}")
                # Best-effort: record the failure on the job (and document,
                # if it was loaded before the error occurred).
                job.status = IndexJobStatus.FAILED.value
                job.error_msg = str(e)
                if doc:
                    doc.status = DocumentStatus.FAILED.value
                    doc.error_msg = str(e)
                await session.commit()
        logger.info("所有任务处理完成")
async def process_job(tenant_id: str, kb_id: str, job_id: str, doc_id: str,
                      file_content: bytes, filename: str, metadata: dict):
    """Run the full indexing pipeline for one resumed job.

    Mirrors the admin API's ``_index_document`` flow: decode or parse the
    file, chunk it, embed every chunk, upsert the vectors into Qdrant, and
    keep the IndexJob row's status/progress up to date along the way.

    Args:
        tenant_id: Tenant that owns the knowledge base.
        kb_id: Target knowledge-base id.
        job_id: IndexJob primary key (string form).
        doc_id: Document primary key (string form); stored as chunk ``source``.
        file_content: Raw bytes of the source file.
        filename: Original file name; its extension selects text vs. parser path.
        metadata: Document-level metadata copied into every chunk payload.

    On any failure the job is marked FAILED in a *fresh* session, since the
    original session may be unusable after the error.
    """
    # Local imports keep this heavy dependency chain out of module import time.
    import tempfile
    import uuid
    from qdrant_client.models import PointStruct
    from app.core.qdrant_client import get_qdrant_client
    from app.services.document import DocumentParseException, UnsupportedFormatError, parse_document
    from app.services.embedding import get_embedding_provider
    from app.services.embedding.nomic_provider import NomicEmbeddingProvider
    from app.services.kb import KBService
    from app.api.admin.kb import chunk_text_by_lines, TextChunk

    logger.info(f"[RESUME] Starting indexing: tenant={tenant_id}, kb_id={kb_id}, job_id={job_id}, doc_id={doc_id}")
    async with async_session_maker() as session:
        kb_service = KBService(session)
        try:
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=10
            )
            await session.commit()

            parse_result = None
            text = None
            file_ext = Path(filename or "").suffix.lower()
            logger.info(f"[RESUME] File extension: {file_ext}, content size: {len(file_content)} bytes")

            text_extensions = {".txt", ".md", ".markdown", ".rst", ".log", ".json", ".xml", ".yaml", ".yml"}
            if file_ext in text_extensions or not file_ext:
                # Plain-text path: try common encodings in priority order,
                # falling back to lossy UTF-8 so we never hard-fail on decode.
                logger.info("[RESUME] Treating as text file")
                text = None
                for encoding in ["utf-8", "gbk", "gb2312", "gb18030", "big5", "utf-16", "latin-1"]:
                    try:
                        text = file_content.decode(encoding)
                        logger.info(f"[RESUME] Successfully decoded with encoding: {encoding}")
                        break
                    except (UnicodeDecodeError, LookupError):
                        continue
                if text is None:
                    text = file_content.decode("utf-8", errors="replace")
            else:
                # Binary path: spill to a temp file and run the document parser.
                logger.info("[RESUME] Binary file detected, will parse with document parser")
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=15
                )
                await session.commit()
                with tempfile.NamedTemporaryFile(delete=False, suffix=file_ext) as tmp_file:
                    tmp_file.write(file_content)
                    tmp_path = tmp_file.name
                logger.info(f"[RESUME] Temp file created: {tmp_path}")
                try:
                    logger.info(f"[RESUME] Starting document parsing for {file_ext}...")
                    parse_result = parse_document(tmp_path)
                    text = parse_result.text
                    logger.info(
                        f"[RESUME] Parsed document SUCCESS: (unknown), chars={len(text)}"
                    )
                except UnsupportedFormatError as e:
                    # Unknown format: degrade to a lossy text decode.
                    logger.error(f"[RESUME] UnsupportedFormatError: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                except DocumentParseException as e:
                    logger.error(f"[RESUME] DocumentParseException: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                except Exception as e:
                    logger.error(f"[RESUME] Unexpected parsing error: {type(e).__name__}: {e}")
                    text = file_content.decode("utf-8", errors="ignore")
                finally:
                    # Always remove the temp file, even when parsing failed.
                    Path(tmp_path).unlink(missing_ok=True)

            logger.info(f"[RESUME] Final text length: {len(text)} chars")
            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=20
            )
            await session.commit()

            logger.info("[RESUME] Getting embedding provider...")
            embedding_provider = await get_embedding_provider()
            logger.info(f"[RESUME] Embedding provider: {type(embedding_provider).__name__}")

            # Chunking: page-aware when the parser produced pages (e.g. PDF),
            # otherwise plain line-based chunking of the whole text.
            all_chunks: list[TextChunk] = []
            if parse_result and parse_result.pages:
                logger.info(f"[RESUME] PDF with {len(parse_result.pages)} pages")
                for page in parse_result.pages:
                    page_chunks = chunk_text_by_lines(
                        page.text,
                        min_line_length=10,
                        source=filename,
                    )
                    for pc in page_chunks:
                        pc.page = page.page
                    all_chunks.extend(page_chunks)
            else:
                logger.info("[RESUME] Using line-based chunking")
                all_chunks = chunk_text_by_lines(
                    text,
                    min_line_length=10,
                    source=filename,
                )
            logger.info(f"[RESUME] Total chunks: {len(all_chunks)}")

            qdrant = await get_qdrant_client()
            await qdrant.ensure_kb_collection_exists(tenant_id, kb_id, use_multi_vector=True)

            # Nomic providers return multiple embedding resolutions per chunk;
            # other providers return a single vector.
            use_multi_vector = isinstance(embedding_provider, NomicEmbeddingProvider)
            logger.info(f"[RESUME] Using multi-vector format: {use_multi_vector}")

            points = []
            total_chunks = len(all_chunks)
            doc_metadata = metadata or {}
            for i, chunk in enumerate(all_chunks):
                payload = {
                    "text": chunk.text,
                    "source": doc_id,
                    "kb_id": kb_id,
                    "chunk_index": i,
                    "start_token": chunk.start_token,
                    "end_token": chunk.end_token,
                    "metadata": doc_metadata,
                }
                if chunk.page is not None:
                    payload["page"] = chunk.page
                if chunk.source:
                    payload["filename"] = chunk.source
                if use_multi_vector:
                    embedding_result = await embedding_provider.embed_document(chunk.text)
                    points.append({
                        "id": str(uuid.uuid4()),
                        "vector": {
                            "full": embedding_result.embedding_full,
                            "dim_256": embedding_result.embedding_256,
                            "dim_512": embedding_result.embedding_512,
                        },
                        "payload": payload,
                    })
                else:
                    embedding = await embedding_provider.embed(chunk.text)
                    points.append(
                        PointStruct(
                            id=str(uuid.uuid4()),
                            vector=embedding,
                            payload=payload,
                        )
                    )
                # Embedding spans progress 20..90; persist every 10 chunks
                # (and on the final chunk) to limit DB round-trips.
                progress = 20 + int((i + 1) / total_chunks * 70)
                if i % 10 == 0 or i == total_chunks - 1:
                    await kb_service.update_job_status(
                        tenant_id, job_id, IndexJobStatus.PROCESSING.value, progress=progress
                    )
                    await session.commit()

            if points:
                logger.info(f"[RESUME] Upserting {len(points)} vectors to Qdrant...")
                if use_multi_vector:
                    await qdrant.upsert_multi_vector(tenant_id, points, kb_id=kb_id)
                else:
                    await qdrant.upsert_vectors(tenant_id, points, kb_id=kb_id)

            await kb_service.update_job_status(
                tenant_id, job_id, IndexJobStatus.COMPLETED.value, progress=100
            )
            await session.commit()
            logger.info(
                f"[RESUME] COMPLETED: tenant={tenant_id}, kb_id={kb_id}, "
                f"job_id={job_id}, chunks={len(all_chunks)}"
            )
        except Exception as e:
            import traceback
            logger.error(f"[RESUME] FAILED: {e}\n{traceback.format_exc()}")
            await session.rollback()
            # Use a fresh session: the one above may be in a failed state.
            async with async_session_maker() as error_session:
                kb_service = KBService(error_session)
                await kb_service.update_job_status(
                    tenant_id, job_id, IndexJobStatus.FAILED.value,
                    progress=0, error_msg=str(e)
                )
                await error_session.commit()
if __name__ == "__main__":
logger.info("开始恢复索引任务...")
asyncio.run(resume_pending_jobs())
logger.info("恢复脚本执行完成")