# Source: ai-robot-core/ai-service/scripts/profile_detailed.py
"""
详细性能分析 - 确认每个环节的耗时
"""
import asyncio
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool, KbSearchDynamicConfig
from app.core.config import get_settings
from app.core.qdrant_client import get_qdrant_client
from app.services.embedding import get_embedding_provider
async def profile_detailed():
    """Profile each stage of the KB search pipeline and print a timing report.

    Stages measured (all timings in milliseconds, printed to stdout):
      1. Embedding provider acquisition + query embedding generation.
      2. Tenant collection-list lookup (Redis cache hit vs. Qdrant fallback).
      3. Per-collection Qdrant vector search with a metadata filter.
      4. The full ``KbSearchDynamicTool.execute`` flow for comparison.

    Uses ``time.perf_counter()`` (monotonic) rather than ``time.time()`` so
    intervals are immune to wall-clock adjustments.  Side effects: may write
    the collection list into Redis (300 s TTL) and opens/disposes a DB engine.
    """
    import json  # hoisted: used by both the cache-hit and cache-miss branches

    settings = get_settings()
    print("=" * 80)
    print("详细性能分析")
    print("=" * 80)

    query = "三年级语文学习"
    tenant_id = "szmp@ash@2026"
    metadata_filter = {"grade": "三年级", "subject": "语文"}

    # --- 1. Embedding generation (provider should already be pre-initialized) ---
    print("\n📊 1. Embedding 生成")
    print("-" * 80)
    start = time.perf_counter()
    embedding_service = await get_embedding_provider()
    init_time = (time.perf_counter() - start) * 1000
    start = time.perf_counter()
    embedding_result = await embedding_service.embed_query(query)
    embed_time = (time.perf_counter() - start) * 1000
    # Providers differ in result shape: unwrap to the raw vector.
    if hasattr(embedding_result, 'embedding_full'):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, 'embedding'):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result
    print(f" 获取服务实例: {init_time:.2f} ms")
    print(f" Embedding 生成: {embed_time:.2f} ms")
    print(f" 向量维度: {len(query_vector)}")

    # --- 2. Tenant collection list (Redis-cached, Qdrant fallback) ---
    print("\n📊 2. 获取 collections 列表")
    print("-" * 80)
    client = await get_qdrant_client()
    qdrant_client = await client.get_client()
    start = time.perf_counter()
    from app.services.metadata_cache_service import get_metadata_cache_service
    cache_service = await get_metadata_cache_service()
    cache_key = f"collections:{tenant_id}"
    # Try the cache first; fall through to Qdrant on miss or disabled cache.
    redis_client = await cache_service._get_redis()
    cache_hit = False
    if redis_client and cache_service._enabled:
        cached = await redis_client.get(cache_key)
        if cached:
            tenant_collections = json.loads(cached)
            cache_hit = True
            cache_time = (time.perf_counter() - start) * 1000
            print(f" ✅ 缓存命中: {cache_time:.2f} ms")
            print(f" Collections: {tenant_collections}")
    if not cache_hit:
        # Query Qdrant directly and filter collections by tenant prefix.
        start = time.perf_counter()
        collections = await qdrant_client.get_collections()
        # '@' is not valid in collection names, so tenant ids are sanitized.
        safe_tenant_id = tenant_id.replace('@', '_')
        prefix = f"kb_{safe_tenant_id}"
        tenant_collections = [
            c.name for c in collections.collections
            if c.name.startswith(prefix)
        ]
        tenant_collections.sort()
        db_time = (time.perf_counter() - start) * 1000
        print(f" ❌ 缓存未命中,从 Qdrant 查询: {db_time:.2f} ms")
        print(f" Collections: {tenant_collections}")
        # Populate the cache for subsequent runs (5-minute TTL).
        if redis_client and cache_service._enabled:
            await redis_client.setex(cache_key, 300, json.dumps(tenant_collections))
            print(" 已缓存到 Redis")

    # --- 3. Qdrant search, one pass per collection ---
    print("\n📊 3. Qdrant 搜索")
    print("-" * 80)
    from qdrant_client.models import FieldCondition, Filter, MatchValue
    # Build the metadata filter (payload keys live under "metadata.*").
    start = time.perf_counter()
    must_conditions = []
    for key, value in metadata_filter.items():
        field_path = f"metadata.{key}"
        condition = FieldCondition(
            key=field_path,
            match=MatchValue(value=value),
        )
        must_conditions.append(condition)
    qdrant_filter = Filter(must=must_conditions) if must_conditions else None
    filter_time = (time.perf_counter() - start) * 1000
    print(f" 构建 filter: {filter_time:.2f} ms")

    total_search_time = 0
    for collection_name in tenant_collections:
        print(f"\n Collection: {collection_name}")
        # Existence check is timed separately from the search itself.
        start = time.perf_counter()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = (time.perf_counter() - start) * 1000
        print(f" 检查存在: {check_time:.2f} ms")
        if not exists:
            print(" ❌ 不存在")
            continue
        start = time.perf_counter()
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",
                limit=5,
                score_threshold=0.5,
                query_filter=qdrant_filter,
            )
            search_time = (time.perf_counter() - start) * 1000
            total_search_time += search_time
            print(f" 搜索时间: {search_time:.2f} ms")
            print(f" 结果数: {len(results.points)}")
        except Exception as e:
            # Best-effort profiling: report and continue with the next collection.
            print(f" ❌ 搜索失败: {e}")
    print(f"\n 总搜索时间: {total_search_time:.2f} ms")

    # --- 4. Full KB Search flow via the tool, for end-to-end comparison ---
    print("\n📊 4. 完整 KB Search 流程")
    print("-" * 80)
    engine = create_async_engine(settings.database_url)
    try:
        async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)
        async with async_session() as session:
            config = KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=30000,
                min_score_threshold=0.5,
            )
            tool = KbSearchDynamicTool(session=session, config=config)
            start = time.perf_counter()
            result = await tool.execute(
                query=query,
                tenant_id=tenant_id,
                scene="学习方案",
                top_k=5,
                context=metadata_filter,
            )
            total_time = (time.perf_counter() - start) * 1000
            print(f" 总耗时: {total_time:.2f} ms")
            print(f" 工具内部耗时: {result.duration_ms} ms")
            print(f" 结果数: {len(result.hits)}")
    finally:
        # Original leaked the engine's connection pool; always dispose it.
        await engine.dispose()

    # --- 5. Summary ---
    print("\n" + "=" * 80)
    print("📈 耗时总结")
    print("=" * 80)
    print("\n各环节耗时:")
    print(f" Embedding 获取服务: {init_time:.2f} ms")
    print(f" Embedding 生成: {embed_time:.2f} ms")
    print(f" Collections 获取: {cache_time if cache_hit else db_time:.2f} ms")
    print(f" Filter 构建: {filter_time:.2f} ms")
    print(f" Qdrant 搜索: {total_search_time:.2f} ms")
    print(f" 完整流程: {total_time:.2f} ms")
    # "Other" = end-to-end tool time minus the individually measured stages.
    other_time = total_time - embed_time - (cache_time if cache_hit else db_time) - filter_time - total_search_time
    print(f" 其他开销: {other_time:.2f} ms")
    print("\n" + "=" * 80)
# Script entry point: run the async profiling routine on a fresh event loop.
if __name__ == "__main__":
    asyncio.run(profile_detailed())