ai-robot-core/ai-service/scripts/profile_kb_search.py

260 lines
9.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
知识库检索性能分析脚本
详细分析每个环节的耗时
"""
import asyncio
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
from app.services.retrieval.vector_retriever import VectorRetriever
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient
async def profile_embedding_generation(query: str):
    """Measure how long it takes to produce an embedding for one query.

    Two phases are timed separately: obtaining the provider through
    ``get_embedding_provider()`` and the ``embed_query`` call itself.

    Args:
        query: Text to embed.

    Returns:
        A dict with ``init_time_ms`` (time spent obtaining the provider),
        ``embed_time_ms`` (time spent in ``embed_query``) and ``dimension``
        (length of the resulting vector).
    """
    from app.services.embedding import get_embedding_provider

    # perf_counter() is monotonic and high-resolution, so short intervals are
    # not distorted by system clock adjustments or a coarse wall-clock tick.
    start = time.perf_counter()
    embedding_service = await get_embedding_provider()
    init_time = time.perf_counter() - start

    start = time.perf_counter()
    embedding = await embedding_service.embed_query(query)
    embed_time = time.perf_counter() - start

    # Different providers return different shapes: a result object exposing
    # ``embedding_full`` or ``embedding``, or the bare vector itself.
    if hasattr(embedding, 'embedding_full'):
        vector = embedding.embedding_full
    elif hasattr(embedding, 'embedding'):
        vector = embedding.embedding
    else:
        vector = embedding

    return {
        "init_time_ms": init_time * 1000,
        "embed_time_ms": embed_time * 1000,
        "dimension": len(vector),
    }
async def profile_qdrant_search(tenant_id: str, query_vector: list, metadata_filter: dict = None):
    """Measure Qdrant latency for every collection belonging to a tenant.

    Collections are selected by name prefix ``kb_<tenant_id>``, with ``@`` in
    the tenant id replaced by ``_``. Each selected collection is checked for
    existence and then queried one after another (serially).

    Args:
        tenant_id: Tenant identifier used to derive the collection prefix.
        query_vector: Query embedding passed to ``query_points``.
        metadata_filter: Optional mapping of metadata field name to the exact
            value it must match. Each key is queried as ``metadata.<key>``.

    Returns:
        A dict with:
          * ``list_collections_time_ms``: time to obtain the client and list
            and filter the collections;
          * ``collections_count``: number of collections matching the prefix;
          * ``collection_times``: one entry per collection. Existing
            collections carry ``check_time_ms``, ``search_time_ms`` and
            ``results_count``; missing ones carry ``exists=False`` and
            ``time_ms`` (the duration of the existence check).

    Raises:
        Exception: Any error raised by ``query_points`` whose message does not
            mention "vector name" is propagated unchanged.
    """
    from app.core.qdrant_client import get_qdrant_client

    client = await get_qdrant_client()

    # Phase 1: list all collections and keep the ones owned by this tenant.
    start = time.perf_counter()
    qdrant_client = await client.get_client()
    collections = await qdrant_client.get_collections()
    safe_tenant_id = tenant_id.replace('@', '_')
    prefix = f"kb_{safe_tenant_id}"
    tenant_collections = [
        c.name for c in collections.collections
        if c.name.startswith(prefix)
    ]
    list_collections_time = time.perf_counter() - start

    # The filter depends only on ``metadata_filter``, so build it once up
    # front instead of once per collection inside the timed search region.
    qdrant_filter = None
    if metadata_filter:
        from qdrant_client.models import FieldCondition, Filter, MatchValue

        must_conditions = [
            FieldCondition(
                key=f"metadata.{key}",
                match=MatchValue(value=value),
            )
            for key, value in metadata_filter.items()
        ]
        qdrant_filter = Filter(must=must_conditions) if must_conditions else None

    # Phase 2: query the collections one by one.
    collection_times = []
    for collection_name in tenant_collections:
        start = time.perf_counter()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = time.perf_counter() - start

        if not exists:
            collection_times.append({
                "collection": collection_name,
                "exists": False,
                "time_ms": check_time * 1000,
            })
            continue

        query_kwargs = {
            "collection_name": collection_name,
            "query": query_vector,
            "limit": 5,
            "score_threshold": 0.5,
            "query_filter": qdrant_filter,
        }

        start = time.perf_counter()
        try:
            # First attempt targets the named vector "full".
            results = await qdrant_client.query_points(using="full", **query_kwargs)
        except Exception as e:
            # The client's exception type is not visible here, so the fallback
            # is selected by message: only "vector name" errors are retried
            # (presumably collections without a named "full" vector - confirm).
            if "vector name" not in str(e).lower():
                raise
            results = await qdrant_client.query_points(**query_kwargs)
        # Includes the failed first attempt when the fallback was needed.
        search_time = time.perf_counter() - start

        collection_times.append({
            "collection": collection_name,
            "exists": True,
            "check_time_ms": check_time * 1000,
            "search_time_ms": search_time * 1000,
            "results_count": len(results.points),
        })

    return {
        "list_collections_time_ms": list_collections_time * 1000,
        "collections_count": len(tenant_collections),
        "collection_times": collection_times,
    }
async def profile_full_kb_search():
    """Profile the knowledge-base search pipeline and print a report.

    The report has four sections:
      1. embedding generation timing (``profile_embedding_generation``);
      2. per-collection Qdrant timing (``profile_qdrant_search``);
      3. one end-to-end ``KbSearchDynamicTool.execute`` call;
      4. a breakdown of the measured time plus optimisation hints.

    The query, tenant id, metadata filter and scene are fixed sample values.
    Everything is written to stdout; nothing is returned.
    """
    settings = get_settings()

    print("=" * 80)
    print("知识库检索性能分析")
    print("=" * 80)

    # 1. Embedding generation
    print("\n📊 1. Embedding 生成分析")
    print("-" * 80)
    query = "三年级语文学习"
    embed_result = await profile_embedding_generation(query)
    print(f" 初始化时间: {embed_result['init_time_ms']:.2f} ms")
    print(f" Embedding 生成时间: {embed_result['embed_time_ms']:.2f} ms")
    print(f" 向量维度: {embed_result['dimension']}")

    # 2. Qdrant search
    print("\n📊 2. Qdrant 搜索分析")
    print("-" * 80)
    # The profiling helper only returns timings, so embed the query again to
    # get the actual vector needed for the Qdrant queries.
    from app.services.embedding import get_embedding_provider

    embedding_service = await get_embedding_provider()
    embedding_result = await embedding_service.embed_query(query)
    # Different providers return different shapes: a result object exposing
    # ``embedding_full`` or ``embedding``, or the bare vector itself.
    if hasattr(embedding_result, 'embedding_full'):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, 'embedding'):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result

    tenant_id = "szmp@ash@2026"
    metadata_filter = {"grade": "三年级", "subject": "语文"}
    qdrant_result = await profile_qdrant_search(tenant_id, query_vector, metadata_filter)
    print(f" 获取 collections 列表时间: {qdrant_result['list_collections_time_ms']:.2f} ms")
    print(f" Collections 数量: {qdrant_result['collections_count']}")
    print("\n 各 Collection 搜索耗时:")
    for ct in qdrant_result['collection_times']:
        if ct['exists']:
            print(f" - {ct['collection']}: {ct['search_time_ms']:.2f} ms (结果: {ct['results_count']} 条)")
        else:
            print(f" - {ct['collection']}: 不存在 ({ct['time_ms']:.2f} ms)")
    # Entries for missing collections have no ``search_time_ms`` and count as 0.
    total_search_time = sum(
        ct.get('search_time_ms', 0) for ct in qdrant_result['collection_times']
    )
    print(f"\n 总搜索时间(串行): {total_search_time:.2f} ms")

    # 3. End-to-end tool call
    print("\n📊 3. 完整 KB Search 流程分析")
    print("-" * 80)
    engine = create_async_engine(settings.database_url)
    try:
        async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
        async with async_session() as session:
            config = KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=30000,  # 30 seconds
                min_score_threshold=0.5,
            )
            tool = KbSearchDynamicTool(session=session, config=config)

            start_total = time.perf_counter()
            result = await tool.execute(
                query=query,
                tenant_id=tenant_id,
                scene="学习方案",
                top_k=5,
                context=metadata_filter,
            )
            total_time = (time.perf_counter() - start_total) * 1000

            print(f" 总耗时: {total_time:.2f} ms")
            print(f" 结果: success={result.success}, hits={len(result.hits)}")
            print(f" 工具内部耗时: {result.duration_ms} ms")
            # Gap between the externally measured time and the duration the
            # tool reports for itself.
            overhead = total_time - result.duration_ms
            print(f" 额外开销(初始化等): {overhead:.2f} ms")
    finally:
        # Release pooled connections while the event loop is still running;
        # otherwise they are left open when asyncio.run() closes the loop.
        await engine.dispose()

    # 4. Bottleneck breakdown
    print("\n📊 4. 性能瓶颈分析")
    print("-" * 80)
    embedding_time = embed_result['embed_time_ms']
    qdrant_time = total_search_time
    total_measured = embedding_time + qdrant_time
    # Both parts can be 0 (no tenant collections and a zero embed reading);
    # report 0% in that case instead of dividing by zero.
    if total_measured:
        embedding_share = embedding_time / total_measured * 100
        qdrant_share = qdrant_time / total_measured * 100
    else:
        embedding_share = 0.0
        qdrant_share = 0.0
    print(f" Embedding 生成: {embedding_time:.2f} ms ({embedding_share:.1f}%)")
    print(f" Qdrant 搜索: {qdrant_time:.2f} ms ({qdrant_share:.1f}%)")
    print(f" 其他开销: {total_time - total_measured:.2f} ms")

    print("\n" + "=" * 80)
    print("优化建议:")
    print("=" * 80)
    # Hints trigger at 1000 ms for each stage and at more than 3 collections.
    if embedding_time > 1000:
        print(" ⚠️ Embedding 生成较慢,考虑:")
        print(" - 使用更快的 embedding 模型")
        print(" - 增加 embedding 服务缓存")
    if qdrant_time > 1000:
        print(" ⚠️ Qdrant 搜索较慢,考虑:")
        print(" - 并行查询多个 collections")
        print(" - 优化 Qdrant 索引")
        print(" - 减少 collections 数量")
    if len(qdrant_result['collection_times']) > 3:
        print(f" ⚠️ Collections 数量较多 ({len(qdrant_result['collection_times'])} 个)")
        print(" - 建议合并或归档空/少数据的 collections")
if __name__ == "__main__":
    # Script entry point: run the full profiling report on a fresh event loop.
    report_coroutine = profile_full_kb_search()
    asyncio.run(report_coroutine)