ai-robot-core/ai-service/scripts/profile_full_params.py

247 lines
8.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
详细分析完整参数查询的耗时
对比带 metadata_filter 和不带的区别
"""
import asyncio
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
from app.core.config import get_settings
from app.core.qdrant_client import get_qdrant_client
from qdrant_client.models import FieldCondition, Filter, MatchValue
def _elapsed_ms(start):
    """Return the milliseconds elapsed since ``start``.

    ``start`` must be a reading taken from ``time.perf_counter()``; that clock
    is monotonic, so the interval cannot be skewed by wall-clock adjustments.
    """
    return (time.perf_counter() - start) * 1000


def _build_metadata_filter(metadata_filter):
    """Build a Qdrant ``Filter`` from a ``{field: value}`` mapping.

    Every entry becomes a ``must`` condition on the payload key
    ``metadata.<field>``. Returns ``None`` when the mapping is empty so the
    caller can pass it straight through as "no filter".
    """
    must_conditions = [
        FieldCondition(key=f"metadata.{key}", match=MatchValue(value=value))
        for key, value in metadata_filter.items()
    ]
    return Filter(must=must_conditions) if must_conditions else None


async def profile_step_by_step(
    query="三年级语文学习",
    tenant_id="szmp@ash@2026",
    metadata_filter=None,
):
    """Profile a full-parameter KB search step by step and print the timings.

    The steps measured, in order:

    1. Embedding provider initialisation and query embedding.
    2. Listing Qdrant collections and selecting the tenant's ``kb_*`` ones.
    3. Building the Qdrant metadata filter.
    4. Per-collection Qdrant search with the metadata filter.
    5. Per-collection Qdrant search without the filter (baseline).
    6. Full ``KbSearchDynamicTool.execute`` call with ``context``.
    7. Full ``KbSearchDynamicTool.execute`` call without ``context``.
    8. A printed summary comparing the filtered and unfiltered runs.

    Args:
        query: Text to embed and search for.
        tenant_id: Tenant identifier; ``@`` is replaced by ``_`` to derive the
            ``kb_<tenant>`` collection-name prefix.
        metadata_filter: ``{field: value}`` mapping used both as the Qdrant
            payload filter and as the tool ``context``. Defaults to
            ``{"grade": "三年级", "subject": "语文"}`` when ``None``.

    Returns:
        None. All results are written to stdout.

    Raises:
        Exceptions from individual ``query_points`` calls are caught and
        printed; failures in any other step propagate to the caller.
    """
    if metadata_filter is None:
        metadata_filter = {"grade": "三年级", "subject": "语文"}

    settings = get_settings()
    print("=" * 80)
    print("完整参数查询耗时分析")
    print("=" * 80)

    # 1. Embedding generation
    print("\n📊 1. Embedding 生成")
    print("-" * 80)
    from app.services.embedding import get_embedding_provider

    start = time.perf_counter()
    embedding_service = await get_embedding_provider()
    init_time_ms = _elapsed_ms(start)

    start = time.perf_counter()
    embedding_result = await embedding_service.embed_query(query)
    embed_time = _elapsed_ms(start)

    # Extract the embedding vector: the provider may return an object exposing
    # ``embedding_full`` or ``embedding``, or the raw vector itself.
    if hasattr(embedding_result, "embedding_full"):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, "embedding"):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result

    print(f" 初始化时间: {init_time_ms:.2f} ms")
    print(f" Embedding 生成: {embed_time:.2f} ms")
    print(f" 向量维度: {len(query_vector)}")

    # 2. Fetch the list of collections
    print("\n📊 2. 获取 collections 列表")
    print("-" * 80)
    client = await get_qdrant_client()
    qdrant_client = await client.get_client()

    start = time.perf_counter()
    collections = await qdrant_client.get_collections()
    safe_tenant_id = tenant_id.replace("@", "_")
    prefix = f"kb_{safe_tenant_id}"
    tenant_collections = [
        c.name for c in collections.collections
        if c.name.startswith(prefix)
    ]
    list_time = _elapsed_ms(start)
    print(f" 获取 collections: {list_time:.2f} ms")
    print(f" Collections: {tenant_collections}")

    # 3. Build the metadata filter
    print("\n📊 3. 构建 metadata filter")
    print("-" * 80)
    start = time.perf_counter()
    qdrant_filter = _build_metadata_filter(metadata_filter)
    filter_time = _elapsed_ms(start)
    print(f" 构建 filter: {filter_time:.2f} ms")
    print(f" Filter: {qdrant_filter}")

    # 4. Search each collection (with filter)
    print("\n📊 4. Qdrant 搜索(带 metadata filter")
    print("-" * 80)
    total_search_time = 0
    total_results = 0
    for collection_name in tenant_collections:
        print(f"\n Collection: {collection_name}")

        # Check that the collection exists; this time is reported separately
        # and is not added to the search total.
        start = time.perf_counter()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = _elapsed_ms(start)
        print(f" 检查存在: {check_time:.2f} ms")
        if not exists:
            print(" ❌ 不存在")
            continue

        # Search (with filter)
        start = time.perf_counter()
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",
                limit=5,
                score_threshold=0.5,
                query_filter=qdrant_filter,
            )
            search_time = _elapsed_ms(start)
            total_search_time += search_time
            total_results += len(results.points)
            print(f" 搜索时间: {search_time:.2f} ms")
            print(f" 结果数: {len(results.points)}")
        except Exception as e:
            # Best effort: one failing collection must not abort the profile.
            print(f" ❌ 搜索失败: {e}")
    print(f"\n 总搜索时间: {total_search_time:.2f} ms")
    print(f" 总结果数: {total_results}")

    # 5. Baseline: search without the filter
    print("\n📊 5. Qdrant 搜索(不带 metadata filter对比")
    print("-" * 80)
    total_search_time_no_filter = 0
    total_results_no_filter = 0
    for collection_name in tenant_collections:
        start = time.perf_counter()
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",
                limit=5,
                score_threshold=0.5,
                # no filter
            )
            search_time = _elapsed_ms(start)
            total_search_time_no_filter += search_time
            total_results_no_filter += len(results.points)
        except Exception as e:
            # Best effort: report the failure and continue with the next one.
            print(f" {collection_name}: 失败 {e}")
    print(f" 总搜索时间(无 filter: {total_search_time_no_filter:.2f} ms")
    print(f" 总结果数(无 filter: {total_results_no_filter}")

    # 6. Full KB Search flow (with context)
    print("\n📊 6. 完整 KB Search 流程(带 context")
    print("-" * 80)
    engine = create_async_engine(settings.database_url)
    try:
        async_session = sessionmaker(
            engine, class_=AsyncSession, expire_on_commit=False
        )
        config = KbSearchDynamicConfig(
            enabled=True,
            top_k=5,
            timeout_ms=30000,
            min_score_threshold=0.5,
        )

        async with async_session() as session:
            tool = KbSearchDynamicTool(session=session, config=config)
            start = time.perf_counter()
            result = await tool.execute(
                query=query,
                tenant_id=tenant_id,
                scene="学习方案",
                top_k=5,
                context=metadata_filter,
            )
            total_time = _elapsed_ms(start)
            print(f" 总耗时: {total_time:.2f} ms")
            print(f" 工具内部耗时: {result.duration_ms} ms")
            print(f" 结果数: {len(result.hits)}")
            print(f" 应用的 filter: {result.applied_filter}")

        # 7. Baseline: full flow without context
        print("\n📊 7. 完整 KB Search 流程(不带 context")
        print("-" * 80)
        async with async_session() as session:
            tool = KbSearchDynamicTool(session=session, config=config)
            start = time.perf_counter()
            result = await tool.execute(
                query=query,
                tenant_id=tenant_id,
                scene="学习方案",
                top_k=5,
                # no context
            )
            total_time_no_context = _elapsed_ms(start)
            print(f" 总耗时: {total_time_no_context:.2f} ms")
            print(f" 工具内部耗时: {result.duration_ms} ms")
            print(f" 结果数: {len(result.hits)}")
    finally:
        # Release the engine's connection pool even if a tool call fails.
        await engine.dispose()

    # 8. Summary
    print("\n" + "=" * 80)
    print("📈 耗时分析总结")
    print("=" * 80)
    print("\n带 metadata filter:")
    print(f" Embedding: {embed_time:.2f} ms")
    print(f" 获取 collections: {list_time:.2f} ms")
    print(f" Qdrant 搜索: {total_search_time:.2f} ms")
    print(f" 完整流程: {total_time:.2f} ms")
    print("\n不带 metadata filter:")
    print(f" Qdrant 搜索: {total_search_time_no_filter:.2f} ms")
    print(f" 完整流程: {total_time_no_context:.2f} ms")
    print("\nMetadata filter 额外开销:")
    print(f" Qdrant 搜索: {total_search_time - total_search_time_no_filter:.2f} ms")
    print(f" 完整流程: {total_time - total_time_no_context:.2f} ms")
    if total_search_time > total_search_time_no_filter:
        print("\n⚠️ 带 filter 的搜索更慢,可能原因:")
        print(" - Filter 增加了索引查找的复杂度")
        print(" - 需要匹配 metadata 字段")
        print(" - 建议: 检查 Qdrant 的 payload 索引配置")
# Script entry point: run the profiling coroutine on a fresh event loop when
# this file is executed directly (not when it is imported).
if __name__ == "__main__":
    asyncio.run(profile_step_by_step())