ai-robot-core/ai-service/test_kb_search_with_metadat...

218 lines
6.7 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试KB元数据过滤查询使用正确的知识库集合
"""
import asyncio
import json
from app.core.database import async_session_maker
from app.core.qdrant_client import get_qdrant_client
from app.services.embedding import get_embedding_provider
def _hit_matches_filter(hit, metadata_filter):
    """Return True when the hit's payload metadata satisfies every filter condition.

    Each filter value is either a literal (compared with ``==``) or an
    operator dict containing ``$eq`` or ``$in``. When both operators are
    present, ``$eq`` takes precedence.

    Args:
        hit: A search hit dict with an optional ``payload`` -> ``metadata`` mapping.
        metadata_filter: Mapping of metadata field name to condition.

    Raises:
        ValueError: If an operator dict contains neither ``$eq`` nor ``$in``.
    """
    # A stored payload/metadata of None must behave like "no metadata"
    # instead of crashing on ``None.get``.
    payload = hit.get("payload") or {}
    hit_metadata = payload.get("metadata") or {}

    for field_key, condition in metadata_filter.items():
        hit_value = hit_metadata.get(field_key)
        if not isinstance(condition, dict):
            if hit_value != condition:
                return False
        elif "$eq" in condition:
            if hit_value != condition["$eq"]:
                return False
        elif "$in" in condition:
            if hit_value not in condition["$in"]:
                return False
        else:
            # Silently accepting an unknown operator would make the
            # filtered count look valid while the filter was ignored.
            raise ValueError(
                f"Unsupported filter condition for {field_key!r}: {condition!r}"
            )
    return True


async def test_kb_search_with_filter():
    """Run a vector search and apply a metadata filter to the hits client-side.

    Steps, each reported on stdout:
      1. Embed the hard-coded query with the configured embedding provider.
      2. Search Qdrant without any filter (limit 10, score threshold 0.01).
      3. Keep only hits whose payload metadata satisfies ``metadata_filter``.
      4. Print the surviving hits, or likely causes plus the top three raw
         hits when nothing survives.

    Raises:
        ValueError: If ``metadata_filter`` contains an unsupported operator.
    """
    tenant_id = "szmp@ash@2026"
    kb_id = "30c19c84-8f69-4768-9d23-7f4a5bc3627a"  # customer-service consultation KB
    query = "初二学生学习困难"
    # Filter under test.
    metadata_filter = {
        "grade": {"$eq": "初二"},
        "kb_scene": {"$eq": "痛点"},
    }

    print("=" * 60)
    print("测试带元数据过滤的KB检索")
    print("=" * 60)
    print(f"租户: {tenant_id}")
    print(f"知识库: {kb_id}")
    print(f"查询: {query}")
    print(f"过滤器: {json.dumps(metadata_filter, ensure_ascii=False)}")
    print()

    # 1. Build the query embedding.
    print("[1] 生成查询向量...")
    embedding_provider = await get_embedding_provider()
    query_vector = await embedding_provider.embed(query)
    print(f" 向量维度: {len(query_vector)}")
    print()

    # 2. Search Qdrant.
    print("[2] 搜索 Qdrant...")
    client = await get_qdrant_client()
    # Name of the knowledge-base-specific collection (reported only).
    collection_name = client.get_kb_collection_name(tenant_id, kb_id)
    print(f" 集合: {collection_name}")
    # NOTE(review): collection_name is only printed; the search below is
    # scoped by tenant_id alone and never receives kb_id. The original author
    # noted that tenant_id is converted to a collection name inside search()
    # -- confirm it really targets the KB-specific collection.
    # Fetch unfiltered hits first so the filter can be applied locally.
    all_results = await client.search(
        tenant_id=tenant_id,
        query_vector=query_vector,
        limit=10,
        score_threshold=0.01,
    )
    print(f" 原始命中: {len(all_results)}")
    print()

    # 3. Apply the metadata filter client-side.
    print("[3] 应用元数据过滤...")
    filtered_results = [
        hit for hit in all_results if _hit_matches_filter(hit, metadata_filter)
    ]
    print(f" 过滤后命中: {len(filtered_results)}")
    print()

    # 4. Report results.
    print("=" * 60)
    print("检索结果")
    print("=" * 60)
    if filtered_results:
        for i, hit in enumerate(filtered_results, 1):
            payload = hit.get("payload") or {}
            metadata = payload.get("metadata") or {}
            text = payload.get("text", "")
            score = hit.get("score", 0)
            print(f"\n[{i}] 相似度: {score:.4f}")
            print(f" 年级: {metadata.get('grade', 'N/A')}")
            print(f" 学科: {metadata.get('subject', 'N/A')}")
            print(f" 场景: {metadata.get('kb_scene', 'N/A')}")
            print(f" 内容: {text}")
    else:
        print("\n未命中任何文档")
        print("\n可能原因:")
        print(" 1. 向量相似度太低")
        print(" 2. 元数据不匹配")
        print(" 3. 数据不在主集合中")
        # Show the top raw (unfiltered) hits to help debugging.
        print("\n原始命中(未过滤前):")
        for i, hit in enumerate(all_results[:3], 1):
            payload = hit.get("payload") or {}
            # Printed as stored so a missing/None metadata stays visible.
            metadata = payload.get("metadata", {})
            text = (payload.get("text") or "")[:60]
            score = hit.get("score", 0)
            print(f" [{i}] 分数:{score:.3f} 元数据:{metadata} {text}...")
async def test_different_filters():
    """Scroll points from the KB collection and count matches per filter.

    Reads up to 100 points (with payload) from the knowledge-base collection,
    prints each point's metadata with a 50-character text preview, then
    reports how many points satisfy each of several metadata filters.
    Filters support literal values and the ``$eq`` / ``$in`` operators.

    Raises:
        ValueError: If a filter uses an operator other than ``$eq`` / ``$in``.
    """

    def condition_holds(value, condition):
        """Evaluate one filter condition against one metadata value."""
        if not isinstance(condition, dict):
            return value == condition
        if "$eq" in condition:
            return value == condition["$eq"]
        if "$in" in condition:
            # Previously a ``$in`` dict was compared by equality against the
            # value itself and therefore could never match.
            return value in condition["$in"]
        raise ValueError(f"Unsupported filter condition: {condition!r}")

    def hit_matches(hit, filter_def):
        """Return True when every condition in filter_def holds for the hit."""
        # Metadata stored as None is treated as empty instead of crashing.
        metadata = hit["payload"].get("metadata") or {}
        return all(
            condition_holds(metadata.get(field_key), condition)
            for field_key, condition in filter_def.items()
        )

    tenant_id = "szmp@ash@2026"
    kb_id = "30c19c84-8f69-4768-9d23-7f4a5bc3627a"

    # Read the data straight from Qdrant.
    client = await get_qdrant_client()
    qdrant = await client.get_client()
    collection_name = client.get_kb_collection_name(tenant_id, kb_id)

    print("\n" + "=" * 60)
    print("测试不同过滤条件")
    print("=" * 60)
    print(f"集合: {collection_name}")
    print()

    # NOTE: only the first page (at most 100 points) is fetched, so the
    # "total" printed below is capped at 100 for larger collections.
    results = await qdrant.scroll(
        collection_name=collection_name,
        limit=100,
        with_payload=True,
    )
    # results[0] holds the points of this page.
    all_hits = [
        {"id": str(point.id), "payload": point.payload or {}}
        for point in results[0]
    ]
    print(f"总数据量: {len(all_hits)}")
    print()

    # Dump the metadata of every fetched point.
    print("所有数据的元数据:")
    for hit in all_hits:
        payload = hit["payload"]
        # Printed as stored so a missing/None metadata stays visible.
        metadata = payload.get("metadata", {})
        text = (payload.get("text") or "")[:50]
        print(f" - {metadata}: {text}...")

    # Filters to evaluate.
    test_cases = [
        {"name": "初二", "filter": {"grade": {"$eq": "初二"}}},
        {"name": "痛点", "filter": {"kb_scene": {"$eq": "痛点"}}},
        {"name": "初二+痛点", "filter": {"grade": {"$eq": "初二"}, "kb_scene": {"$eq": "痛点"}}},
        {"name": "通用学科", "filter": {"subject": {"$eq": "通用"}}},
    ]

    print("\n过滤测试:")
    for case in test_cases:
        filter_def = case["filter"]
        filtered = [hit for hit in all_hits if hit_matches(hit, filter_def)]
        print(f" {case['name']}: {len(filtered)}/{len(all_hits)} 条命中")
async def main():
    """Print a banner and run both KB filter checks in order.

    Any exception raised by a check is reported on stdout together with its
    traceback; it is not re-raised, and the remaining check is skipped.
    """
    import traceback

    rule = "=" * 60
    print("\n" + rule)
    print("KB元数据过滤查询测试")
    print(rule + "\n")

    try:
        for scenario in (test_kb_search_with_filter, test_different_filters):
            await scenario()
    except Exception as exc:
        print(f"\n测试失败: {exc}")
        print(traceback.format_exc())
# Script entry point: drive the async checks when the file is executed directly.
if __name__ == "__main__":
    asyncio.run(main())