# ai-robot-core/ai-service/test_kb_search_with_metadat...
# 218 lines, 6.7 KiB, Python

"""
测试KB元数据过滤查询使用正确的知识库集合
"""
import asyncio
import json
from app.core.database import async_session_maker
from app.core.qdrant_client import get_qdrant_client
from app.services.embedding import get_embedding_provider
async def test_kb_search_with_filter():
    """Search for a fixed query and apply a metadata filter on the client side.

    Steps, each reported on stdout:

    1. Embed the query text with the configured embedding provider.
    2. Run a vector search through the project's Qdrant wrapper.
    3. Keep only the hits whose ``payload["metadata"]`` satisfies every
       condition in ``metadata_filter`` (``$eq`` and ``$in`` operators, or a
       bare value compared for equality).
    4. Print the surviving hits; when none survive, print possible causes
       and up to three of the unfiltered hits for debugging.
    """
    tenant_id = "szmp@ash@2026"
    kb_id = "30c19c84-8f69-4768-9d23-7f4a5bc3627a"  # customer-service consultation KB
    query = "初二学生学习困难"

    # Filter under test: every field must match for a hit to be kept.
    metadata_filter = {
        "grade": {"$eq": "初二"},
        "kb_scene": {"$eq": "痛点"},
    }
    banner = "=" * 60

    def _passes(meta):
        """Return True when ``meta`` satisfies every condition of the filter."""
        for key, cond in metadata_filter.items():
            actual = meta.get(key)
            if not isinstance(cond, dict):
                # Bare value: plain equality.
                if actual != cond:
                    return False
            elif "$eq" in cond:
                if actual != cond["$eq"]:
                    return False
            elif "$in" in cond:
                if actual not in cond["$in"]:
                    return False
            # A dict with neither operator places no constraint on the hit.
        return True

    filter_json = json.dumps(metadata_filter, ensure_ascii=False)
    print(banner)
    print("测试带元数据过滤的KB检索")
    print(banner)
    print(f"租户: {tenant_id}")
    print(f"知识库: {kb_id}")
    print(f"查询: {query}")
    print(f"过滤器: {filter_json}")
    print()

    # 1. Build the query embedding.
    print("[1] 生成查询向量...")
    provider = await get_embedding_provider()
    query_vector = await provider.embed(query)
    print(f" 向量维度: {len(query_vector)}")
    print()

    # 2. Query Qdrant.
    print("[2] 搜索 Qdrant...")
    client = await get_qdrant_client()

    # Name of the knowledge-base-specific collection (shown for reference).
    collection_name = client.get_kb_collection_name(tenant_id, kb_id)
    print(f" 集合: {collection_name}")

    # Fetch candidates without any server-side filter.
    # NOTE(review): collection_name is only printed; the search is keyed by
    # tenant_id alone (the original author notes it is converted into a
    # collection name) -- confirm this targets the KB collection above.
    all_results = await client.search(
        tenant_id=tenant_id,
        query_vector=query_vector,
        limit=10,
        score_threshold=0.01,
    )
    print(f" 原始命中: {len(all_results)}")
    print()

    # 3. Apply the metadata filter in Python.
    print("[3] 应用元数据过滤...")
    filtered_results = [
        candidate
        for candidate in all_results
        if _passes(candidate.get("payload", {}).get("metadata", {}))
    ]
    print(f" 过滤后命中: {len(filtered_results)}")
    print()

    # 4. Report.
    print(banner)
    print("检索结果")
    print(banner)

    if not filtered_results:
        print("\n未命中任何文档")
        print("\n可能原因:")
        print(" 1. 向量相似度太低")
        print(" 2. 元数据不匹配")
        print(" 3. 数据不在主集合中")

        # Show the first few unfiltered hits to help debugging.
        print("\n原始命中(未过滤前):")
        for rank, raw in enumerate(all_results[:3], 1):
            raw_payload = raw.get("payload", {})
            raw_meta = raw_payload.get("metadata", {})
            snippet = raw_payload.get("text", "")[:60]
            raw_score = raw.get("score", 0)
            print(f" [{rank}] 分数:{raw_score:.3f} 元数据:{raw_meta} {snippet}...")
        return

    for rank, kept in enumerate(filtered_results, 1):
        kept_payload = kept.get("payload", {})
        kept_meta = kept_payload.get("metadata", {})
        body = kept_payload.get("text", "")
        similarity = kept.get("score", 0)
        print(f"\n[{rank}] 相似度: {similarity:.4f}")
        print(f" 年级: {kept_meta.get('grade', 'N/A')}")
        print(f" 学科: {kept_meta.get('subject', 'N/A')}")
        print(f" 场景: {kept_meta.get('kb_scene', 'N/A')}")
        print(f" 内容: {body}")
async def test_different_filters():
    """Count how many stored points match each of several metadata filters.

    Scrolls points (with payloads) out of the knowledge-base collection,
    prints each point's metadata next to a short text preview, then
    evaluates every entry of ``test_cases`` in Python and prints
    ``matched/total`` for it.

    Filter conditions are evaluated per field and all must hold:

    * ``{"$eq": value}`` -- the metadata value must equal ``value``.
    * ``{"$in": values}`` -- the metadata value must be contained in ``values``.
    * anything else -- compared to the metadata value for plain equality.
    """
    tenant_id = "szmp@ash@2026"
    kb_id = "30c19c84-8f69-4768-9d23-7f4a5bc3627a"

    def _matches(metadata, filter_def):
        """Return True when ``metadata`` satisfies every condition in ``filter_def``."""
        for field_key, condition in filter_def.items():
            hit_value = metadata.get(field_key)
            if isinstance(condition, dict) and "$eq" in condition:
                if hit_value != condition["$eq"]:
                    return False
            elif isinstance(condition, dict) and "$in" in condition:
                # Same operator the search test understands; without this
                # branch an "$in" condition would be compared as a literal
                # dict below and could never match.
                if hit_value not in condition["$in"]:
                    return False
            elif hit_value != condition:
                return False
        return True

    # Read the points straight from Qdrant via the underlying client.
    client = await get_qdrant_client()
    qdrant = await client.get_client()
    collection_name = client.get_kb_collection_name(tenant_id, kb_id)

    print("\n" + "=" * 60)
    print("测试不同过滤条件")
    print("=" * 60)
    print(f"集合: {collection_name}")
    print()

    # Only one page is requested, so at most 100 points are inspected and
    # the printed total is capped at that number.
    results = await qdrant.scroll(
        collection_name=collection_name,
        limit=100,
        with_payload=True,
    )

    # results[0] holds the returned points; a missing payload becomes {}.
    all_hits = [
        {"id": str(point.id), "payload": point.payload or {}}
        for point in results[0]
    ]

    print(f"总数据量: {len(all_hits)}")
    print()

    # Dump every point's metadata with a 50-character text preview.
    print("所有数据的元数据:")
    for hit in all_hits:
        payload = hit["payload"]
        metadata = payload.get("metadata", {})
        text = payload.get("text", "")[:50]
        print(f" - {metadata}: {text}...")

    # Filters to evaluate against the fetched points.
    test_cases = [
        {"name": "初二", "filter": {"grade": {"$eq": "初二"}}},
        {"name": "痛点", "filter": {"kb_scene": {"$eq": "痛点"}}},
        {"name": "初二+痛点", "filter": {"grade": {"$eq": "初二"}, "kb_scene": {"$eq": "痛点"}}},
        {"name": "通用学科", "filter": {"subject": {"$eq": "通用"}}},
    ]

    print("\n过滤测试:")
    for case in test_cases:
        filter_def = case["filter"]
        # `or {}` keeps a point whose metadata is stored as None from
        # raising inside the predicate; it simply fails to match.
        filtered = [
            hit
            for hit in all_hits
            if _matches(hit["payload"].get("metadata") or {}, filter_def)
        ]
        print(f" {case['name']}: {len(filtered)}/{len(all_hits)} 条命中")
async def main():
    """Print the suite banner and run both diagnostic checks in order.

    The second check runs only if the first one finishes without raising.
    Any exception from either check is caught here, reported on stdout with
    its message and full traceback, and not re-raised.
    """
    rule = "=" * 60
    print(f"\n{rule}")
    print("KB元数据过滤查询测试")
    print(f"{rule}\n")

    try:
        await test_kb_search_with_filter()
        await test_different_filters()
    except Exception as exc:
        # Top-level boundary of the script: report the failure and return.
        print(f"\n测试失败: {exc}")
        import traceback

        details = traceback.format_exc()
        print(details)
# Script entry point: run the async main() to completion when executed directly.
if __name__ == "__main__":
    asyncio.run(main())