ai-robot-core/ai-service/scripts/profile_kb_search.py

260 lines
9.0 KiB
Python
Raw Permalink Normal View History

"""
知识库检索性能分析脚本
详细分析每个环节的耗时
"""
import asyncio
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
from app.services.retrieval.vector_retriever import VectorRetriever
from app.core.config import get_settings
from app.core.qdrant_client import QdrantClient
async def profile_embedding_generation(query: str) -> dict:
    """Measure embedding-provider acquisition time and single-query embed time.

    Args:
        query: Text passed to the provider's ``embed_query``.

    Returns:
        A dict with:
            - ``"init_time_ms"``: time spent in ``get_embedding_provider()``, in ms.
            - ``"embed_time_ms"``: time spent in ``embed_query(query)``, in ms.
            - ``"dimension"``: length of the resulting embedding vector.
    """
    # Local import: the embedding module is only loaded when this profiler runs.
    from app.services.embedding import get_embedding_provider

    # perf_counter() is monotonic and high-resolution, which suits duration
    # measurement; time.time() is a wall clock that can jump on clock changes.
    start = time.perf_counter()
    embedding_service = await get_embedding_provider()
    # NOTE(review): may be near zero if the provider is cached — not visible here.
    init_time = time.perf_counter() - start

    start = time.perf_counter()
    embedding = await embedding_service.embed_query(query)
    embed_time = time.perf_counter() - start

    # Extract the vector in a provider-agnostic way: prefer `embedding_full`,
    # then `embedding`, otherwise treat the return value as the vector itself.
    if hasattr(embedding, "embedding_full"):
        vector = embedding.embedding_full
    elif hasattr(embedding, "embedding"):
        vector = embedding.embedding
    else:
        vector = embedding

    return {
        "init_time_ms": init_time * 1000,
        "embed_time_ms": embed_time * 1000,
        "dimension": len(vector),
    }
async def profile_qdrant_search(
    tenant_id: str,
    query_vector: list,
    metadata_filter: "dict | None" = None,
    *,
    limit: int = 5,
    score_threshold: float = 0.5,
) -> dict:
    """Time the Qdrant calls used to search a tenant's knowledge-base collections.

    Lists all collections and keeps those whose name starts with
    ``kb_<tenant_id>`` (with ``@`` replaced by ``_``). Then, one collection at a
    time, it times an existence check and a ``query_points`` search.

    Args:
        tenant_id: Tenant identifier used to build the collection-name prefix.
        query_vector: Query embedding passed to ``query_points``.
        metadata_filter: Optional mapping of metadata field -> exact value.
            Each pair becomes a ``metadata.<key>`` match condition, and all
            conditions are combined with ``Filter(must=...)``.
        limit: Maximum number of points returned per collection.
        score_threshold: Minimum score passed to ``query_points``.

    Returns:
        A dict with:
            - ``"list_collections_time_ms"``
            - ``"collections_count"``
            - ``"collection_times"``: one entry per matched collection.
              Existing collections report ``check_time_ms``,
              ``search_time_ms`` and ``results_count``. Missing ones report
              ``time_ms``, plus ``check_time_ms`` with the same value.

    Raises:
        Any exception from ``query_points`` whose message does not mention
        "vector name" (those trigger a retry without ``using``).
    """
    from app.core.qdrant_client import get_qdrant_client

    client = await get_qdrant_client()

    # This timing covers obtaining the underlying client and listing collections.
    start = time.perf_counter()
    qdrant_client = await client.get_client()
    collections = await qdrant_client.get_collections()
    safe_tenant_id = tenant_id.replace("@", "_")
    prefix = f"kb_{safe_tenant_id}"
    tenant_collections = [
        c.name for c in collections.collections if c.name.startswith(prefix)
    ]
    list_collections_time = time.perf_counter() - start

    # The filter depends only on metadata_filter, so build it once. Building it
    # inside the loop redid the work per collection and inflated search_time.
    qdrant_filter = None
    if metadata_filter:
        from qdrant_client.models import FieldCondition, Filter, MatchValue

        must_conditions = [
            FieldCondition(key=f"metadata.{key}", match=MatchValue(value=value))
            for key, value in metadata_filter.items()
        ]
        qdrant_filter = Filter(must=must_conditions) if must_conditions else None

    search_kwargs = {
        "limit": limit,
        "score_threshold": score_threshold,
        "query_filter": qdrant_filter,
    }

    # Collections are searched serially so each one's cost is isolated.
    collection_times = []
    for collection_name in tenant_collections:
        start = time.perf_counter()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = time.perf_counter() - start

        if not exists:
            collection_times.append({
                "collection": collection_name,
                "exists": False,
                "time_ms": check_time * 1000,
                "check_time_ms": check_time * 1000,
            })
            continue

        start = time.perf_counter()
        try:
            # Query the named vector "full".
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",
                **search_kwargs,
            )
        except Exception as e:
            # Presumably the collection has no vector named "full" — retry
            # against the default vector. The failed attempt stays in the timing.
            if "vector name" not in str(e).lower():
                raise
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                **search_kwargs,
            )
        search_time = time.perf_counter() - start

        collection_times.append({
            "collection": collection_name,
            "exists": True,
            "check_time_ms": check_time * 1000,
            "search_time_ms": search_time * 1000,
            "results_count": len(results.points),
        })

    return {
        "list_collections_time_ms": list_collections_time * 1000,
        "collections_count": len(tenant_collections),
        "collection_times": collection_times,
    }
async def profile_full_kb_search():
    """Run the end-to-end knowledge-base search profiling report and print it.

    Sections:
        1. Embedding timing via ``profile_embedding_generation``.
        2. Per-collection Qdrant timing via ``profile_qdrant_search``, using a
           freshly generated query embedding.
        3. One full ``KbSearchDynamicTool.execute`` call. The externally
           measured time is compared with the tool's ``duration_ms``.
        4. Embedding vs. Qdrant breakdown plus optimisation hints.

    The query, tenant id and metadata filter are hard-coded sample values.
    """
    settings = get_settings()
    print("=" * 80)
    print("知识库检索性能分析")
    print("=" * 80)

    # 1. Embedding generation.
    print("\n📊 1. Embedding 生成分析")
    print("-" * 80)
    query = "三年级语文学习"
    embed_result = await profile_embedding_generation(query)
    print(f" 初始化时间: {embed_result['init_time_ms']:.2f} ms")
    print(f" Embedding 生成时间: {embed_result['embed_time_ms']:.2f} ms")
    print(f" 向量维度: {embed_result['dimension']}")

    # 2. Qdrant search.
    print("\n📊 2. Qdrant 搜索分析")
    print("-" * 80)
    # profile_embedding_generation() does not return the vector, so the query
    # is embedded again here to obtain one.
    from app.services.embedding import get_embedding_provider

    embedding_service = await get_embedding_provider()
    embedding_result = await embedding_service.embed_query(query)
    # Extract the vector in a provider-agnostic way: prefer `embedding_full`,
    # then `embedding`, otherwise treat the return value as the vector itself.
    if hasattr(embedding_result, "embedding_full"):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, "embedding"):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result

    tenant_id = "szmp@ash@2026"
    metadata_filter = {"grade": "三年级", "subject": "语文"}
    qdrant_result = await profile_qdrant_search(tenant_id, query_vector, metadata_filter)
    print(f" 获取 collections 列表时间: {qdrant_result['list_collections_time_ms']:.2f} ms")
    print(f" Collections 数量: {qdrant_result['collections_count']}")
    print("\n 各 Collection 搜索耗时:")
    for ct in qdrant_result["collection_times"]:
        if ct["exists"]:
            print(f" - {ct['collection']}: {ct['search_time_ms']:.2f} ms (结果: {ct['results_count']} 条)")
        else:
            print(f" - {ct['collection']}: 不存在 ({ct['time_ms']:.2f} ms)")
    # Missing collections have no "search_time_ms" and contribute 0.
    total_search_time = sum(
        ct.get("search_time_ms", 0) for ct in qdrant_result["collection_times"]
    )
    print(f"\n 总搜索时间(串行): {total_search_time:.2f} ms")

    # 3. Full tool call.
    print("\n📊 3. 完整 KB Search 流程分析")
    print("-" * 80)
    engine = create_async_engine(settings.database_url)
    try:
        async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
        async with async_session() as session:
            config = KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=30000,  # 30 seconds
                min_score_threshold=0.5,
            )
            tool = KbSearchDynamicTool(session=session, config=config)
            start_total = time.perf_counter()
            result = await tool.execute(
                query=query,
                tenant_id=tenant_id,
                scene="学习方案",
                top_k=5,
                context=metadata_filter,
            )
            total_time = (time.perf_counter() - start_total) * 1000
    finally:
        # Release pooled DB connections while the event loop is still running;
        # otherwise they are cleaned up (or leaked) after asyncio.run() closes it.
        await engine.dispose()

    print(f" 总耗时: {total_time:.2f} ms")
    print(f" 结果: success={result.success}, hits={len(result.hits)}")
    print(f" 工具内部耗时: {result.duration_ms} ms")
    # Externally measured time minus the tool's self-reported duration.
    overhead = total_time - result.duration_ms
    print(f" 额外开销(初始化等): {overhead:.2f} ms")

    # 4. Bottleneck breakdown.
    print("\n📊 4. 性能瓶颈分析")
    print("-" * 80)
    embedding_time = embed_result["embed_time_ms"]
    qdrant_time = total_search_time
    total_measured = embedding_time + qdrant_time

    def _pct(part):
        # Guard the denominator: it is 0 when no collection was searched and
        # the embedding time rounds to 0.
        return part / total_measured * 100 if total_measured else 0.0

    print(f" Embedding 生成: {embedding_time:.2f} ms ({_pct(embedding_time):.1f}%)")
    print(f" Qdrant 搜索: {qdrant_time:.2f} ms ({_pct(qdrant_time):.1f}%)")
    print(f" 其他开销: {total_time - total_measured:.2f} ms")

    print("\n" + "=" * 80)
    print("优化建议:")
    print("=" * 80)
    if embedding_time > 1000:
        print(" ⚠️ Embedding 生成较慢,考虑:")
        print(" - 使用更快的 embedding 模型")
        print(" - 增加 embedding 服务缓存")
    if qdrant_time > 1000:
        print(" ⚠️ Qdrant 搜索较慢,考虑:")
        print(" - 并行查询多个 collections")
        print(" - 优化 Qdrant 索引")
        print(" - 减少 collections 数量")
    if len(qdrant_result["collection_times"]) > 3:
        print(f" ⚠️ Collections 数量较多 ({len(qdrant_result['collection_times'])} 个)")
        print(" - 建议合并或归档空/少数据的 collections")
if __name__ == "__main__":
    # Script entry point: drive the whole profiling report on a new event loop.
    report = profile_full_kb_search()
    asyncio.run(report)