"""Detailed performance profiling: measure the latency of each stage of the KB search pipeline."""
import asyncio
import sys
import time
from pathlib import Path

# Make the project root importable when this script is run directly.
sys.path.insert(0, str(Path(__file__).parent.parent))

from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

from app.core.config import get_settings
from app.core.qdrant_client import get_qdrant_client
from app.services.embedding import get_embedding_provider
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicConfig, KbSearchDynamicTool
||
async def profile_detailed():
    """Profile each stage of the KB search pipeline and print a timing report.

    Stages measured (all timings printed in milliseconds):
      1. Embedding provider lookup + query-embedding generation.
      2. Tenant collection-list retrieval (Redis cache hit vs. Qdrant fallback).
      3. Per-collection Qdrant vector search with a metadata filter.
      4. The end-to-end ``KbSearchDynamicTool.execute`` flow.

    Returns nothing; output is printed to stdout.
    """
    settings = get_settings()

    print("=" * 80)
    print("详细性能分析")
    print("=" * 80)

    # Fixed sample inputs for the profiling run.
    query = "三年级语文学习"
    tenant_id = "szmp@ash@2026"
    metadata_filter = {"grade": "三年级", "subject": "语文"}

    # --- 1. Embedding generation (provider is expected to be pre-initialized) ---
    print("\n📊 1. Embedding 生成")
    print("-" * 80)
    start = time.time()
    embedding_service = await get_embedding_provider()
    init_time = (time.time() - start) * 1000

    start = time.time()
    embedding_result = await embedding_service.embed_query(query)
    embed_time = (time.time() - start) * 1000

    # Providers differ in return shape: prefer the full-dimension vector,
    # fall back to `.embedding`, else assume the result IS the vector.
    if hasattr(embedding_result, 'embedding_full'):
        query_vector = embedding_result.embedding_full
    elif hasattr(embedding_result, 'embedding'):
        query_vector = embedding_result.embedding
    else:
        query_vector = embedding_result

    print(f" 获取服务实例: {init_time:.2f} ms")
    print(f" Embedding 生成: {embed_time:.2f} ms")
    print(f" 向量维度: {len(query_vector)}")

    # --- 2. Tenant collection list (Redis-cached) ---
    print("\n📊 2. 获取 collections 列表")
    print("-" * 80)
    client = await get_qdrant_client()
    qdrant_client = await client.get_client()

    # Hoisted: json is needed by both the cache-hit and cache-miss paths below
    # (the original imported it separately in each branch).
    import json

    start = time.time()
    from app.services.metadata_cache_service import get_metadata_cache_service
    cache_service = await get_metadata_cache_service()
    cache_key = f"collections:{tenant_id}"

    # Try the Redis cache first.
    redis_client = await cache_service._get_redis()
    cache_hit = False
    if redis_client and cache_service._enabled:
        cached = await redis_client.get(cache_key)
        if cached:
            tenant_collections = json.loads(cached)
            cache_hit = True
            cache_time = (time.time() - start) * 1000
            print(f" ✅ 缓存命中: {cache_time:.2f} ms")
            print(f" Collections: {tenant_collections}")

    if not cache_hit:
        # Fall back to listing all collections from Qdrant and filtering
        # by the tenant prefix ('@' is not valid in collection names).
        start = time.time()
        collections = await qdrant_client.get_collections()
        safe_tenant_id = tenant_id.replace('@', '_')
        prefix = f"kb_{safe_tenant_id}"
        tenant_collections = [
            c.name for c in collections.collections
            if c.name.startswith(prefix)
        ]
        tenant_collections.sort()
        db_time = (time.time() - start) * 1000
        print(f" ❌ 缓存未命中,从 Qdrant 查询: {db_time:.2f} ms")
        print(f" Collections: {tenant_collections}")

        # Populate the cache for subsequent runs (5-minute TTL).
        if redis_client and cache_service._enabled:
            await redis_client.setex(cache_key, 300, json.dumps(tenant_collections))
            print(f" 已缓存到 Redis")

    # --- 3. Qdrant search, one query per collection ---
    print("\n📊 3. Qdrant 搜索")
    print("-" * 80)
    from qdrant_client.models import FieldCondition, Filter, MatchValue

    # Build the payload filter: exact match on each key under "metadata.".
    start = time.time()
    must_conditions = []
    for key, value in metadata_filter.items():
        field_path = f"metadata.{key}"
        condition = FieldCondition(
            key=field_path,
            match=MatchValue(value=value),
        )
        must_conditions.append(condition)
    qdrant_filter = Filter(must=must_conditions) if must_conditions else None
    filter_time = (time.time() - start) * 1000
    print(f" 构建 filter: {filter_time:.2f} ms")

    # Search each tenant collection individually and accumulate the total time.
    total_search_time = 0
    for collection_name in tenant_collections:
        print(f"\n Collection: {collection_name}")

        # Existence check (the cached list may be stale).
        start = time.time()
        exists = await qdrant_client.collection_exists(collection_name)
        check_time = (time.time() - start) * 1000
        print(f" 检查存在: {check_time:.2f} ms")

        if not exists:
            print(f" ❌ 不存在")
            continue

        start = time.time()
        try:
            results = await qdrant_client.query_points(
                collection_name=collection_name,
                query=query_vector,
                using="full",
                limit=5,
                score_threshold=0.5,
                query_filter=qdrant_filter,
            )
            search_time = (time.time() - start) * 1000
            total_search_time += search_time
            print(f" 搜索时间: {search_time:.2f} ms")
            print(f" 结果数: {len(results.points)}")
        except Exception as e:
            # Best-effort profiling: report the failure and keep going.
            print(f" ❌ 搜索失败: {e}")

    print(f"\n 总搜索时间: {total_search_time:.2f} ms")

    # --- 4. Full KB search flow through the tool ---
    print("\n📊 4. 完整 KB Search 流程")
    print("-" * 80)

    engine = create_async_engine(settings.database_url)
    try:
        async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

        async with async_session() as session:
            config = KbSearchDynamicConfig(
                enabled=True,
                top_k=5,
                timeout_ms=30000,
                min_score_threshold=0.5,
            )

            tool = KbSearchDynamicTool(session=session, config=config)

            start = time.time()
            result = await tool.execute(
                query=query,
                tenant_id=tenant_id,
                scene="学习方案",
                top_k=5,
                context=metadata_filter,
            )
            total_time = (time.time() - start) * 1000

            print(f" 总耗时: {total_time:.2f} ms")
            print(f" 工具内部耗时: {result.duration_ms} ms")
            print(f" 结果数: {len(result.hits)}")
    finally:
        # Fix: the original never disposed the engine, leaking its connection
        # pool and risking "unclosed connector" warnings at loop shutdown.
        await engine.dispose()

    # --- 5. Summary ---
    print("\n" + "=" * 80)
    print("📈 耗时总结")
    print("=" * 80)
    print(f"\n各环节耗时:")
    print(f" Embedding 获取服务: {init_time:.2f} ms")
    print(f" Embedding 生成: {embed_time:.2f} ms")
    print(f" Collections 获取: {cache_time if cache_hit else db_time:.2f} ms")
    print(f" Filter 构建: {filter_time:.2f} ms")
    print(f" Qdrant 搜索: {total_search_time:.2f} ms")
    print(f" 完整流程: {total_time:.2f} ms")

    # Whatever the measured stages don't account for: session setup, tool
    # internals, result assembly, event-loop scheduling.
    other_time = total_time - embed_time - (cache_time if cache_hit else db_time) - filter_time - total_search_time
    print(f" 其他开销: {other_time:.2f} ms")

    print("\n" + "=" * 80)
||
if __name__ == "__main__":
    # Script entry point: run the async profiler to completion on a fresh
    # event loop.
    asyncio.run(profile_detailed())