218 lines
6.7 KiB
Python
218 lines
6.7 KiB
Python
|
|
"""
|
|||
|
|
测试KB元数据过滤查询(使用正确的知识库集合)
|
|||
|
|
"""
|
|||
|
|
import asyncio
|
|||
|
|
import json
|
|||
|
|
from app.core.database import async_session_maker
|
|||
|
|
from app.core.qdrant_client import get_qdrant_client
|
|||
|
|
from app.services.embedding import get_embedding_provider
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def test_kb_search_with_filter():
    """Run a vector search for a fixed query and filter the hits by metadata.

    Steps, each reported on stdout:

    1. Embed the hard-coded query with the configured embedding provider.
    2. Ask the Qdrant wrapper for the ten best hits (score threshold 0.01).
    3. Keep only hits whose ``payload["metadata"]`` satisfies
       ``metadata_filter``; the filtering happens client-side, in Python.
    4. Print the surviving hits, or, when none survive, the likely causes
       followed by the first three unfiltered hits for debugging.

    Hits are read with ``dict.get`` and are therefore expected to be dicts
    carrying ``payload`` and ``score`` keys. A payload, metadata or text that
    is missing or explicitly ``None`` is treated as empty instead of raising.
    """
    tenant_id = "szmp@ash@2026"
    kb_id = "30c19c84-8f69-4768-9d23-7f4a5bc3627a"  # customer-service consultation KB

    query = "初二学生学习困难"

    # Filter under test: both conditions must hold for a hit to be kept.
    metadata_filter = {
        "grade": {"$eq": "初二"},
        "kb_scene": {"$eq": "痛点"},
    }

    print("=" * 60)
    print("测试带元数据过滤的KB检索")
    print("=" * 60)
    print(f"租户: {tenant_id}")
    print(f"知识库: {kb_id}")
    print(f"查询: {query}")
    print(f"过滤器: {json.dumps(metadata_filter, ensure_ascii=False)}")
    print()

    # 1. Build the query embedding.
    print("[1] 生成查询向量...")
    embedding_provider = await get_embedding_provider()
    query_vector = await embedding_provider.embed(query)
    print(f" 向量维度: {len(query_vector)}")
    print()

    # 2. Search Qdrant.
    print("[2] 搜索 Qdrant...")
    client = await get_qdrant_client()

    # Name of the knowledge-base specific collection (displayed only).
    collection_name = client.get_kb_collection_name(tenant_id, kb_id)
    print(f" 集合: {collection_name}")

    # Fetch the unfiltered hits first.
    # NOTE(review): the search is issued with tenant_id only (the original
    # author's comment says it is converted to a collection name), while
    # collection_name above is merely printed. Confirm that this call really
    # targets the knowledge-base collection.
    all_results = await client.search(
        tenant_id=tenant_id,
        query_vector=query_vector,
        limit=10,
        score_threshold=0.01,
    )
    print(f" 原始命中: {len(all_results)} 条")
    print()

    # 3. Apply the metadata filter client-side.
    print("[3] 应用元数据过滤...")
    filtered_results = [
        hit
        for hit in all_results
        if _hit_passes_metadata_filter(hit, metadata_filter)
    ]

    print(f" 过滤后命中: {len(filtered_results)} 条")
    print()

    # 4. Show the results.
    print("=" * 60)
    print("检索结果")
    print("=" * 60)

    if filtered_results:
        for i, hit in enumerate(filtered_results, 1):
            payload = hit.get("payload") or {}
            metadata = payload.get("metadata") or {}
            text = payload.get("text", "")
            score = hit.get("score", 0)

            print(f"\n[{i}] 相似度: {score:.4f}")
            print(f" 年级: {metadata.get('grade', 'N/A')}")
            print(f" 学科: {metadata.get('subject', 'N/A')}")
            print(f" 场景: {metadata.get('kb_scene', 'N/A')}")
            print(f" 内容: {text}")
    else:
        print("\n未命中任何文档")
        print("\n可能原因:")
        print(" 1. 向量相似度太低")
        print(" 2. 元数据不匹配")
        print(" 3. 数据不在主集合中")

        # Dump the top raw hits so a metadata mismatch is easy to spot.
        print("\n原始命中(未过滤前):")
        for i, hit in enumerate(all_results[:3], 1):
            payload = hit.get("payload") or {}
            metadata = payload.get("metadata") or {}
            # `or ""` keeps the slice safe when the stored text is None.
            text = (payload.get("text") or "")[:60]
            score = hit.get("score", 0)
            print(f" [{i}] 分数:{score:.3f} 元数据:{metadata} {text}...")


def _hit_passes_metadata_filter(hit, metadata_filter):
    """Return True when the hit's metadata satisfies every filter condition.

    Each entry of *metadata_filter* maps a metadata field to a condition:

    * ``{"$eq": v}`` - the field must equal ``v``;
    * ``{"$in": seq}`` - the field must be a member of ``seq`` (only
      consulted when ``$eq`` is absent);
    * any non-dict value - the field must equal that value.

    A dict condition carrying neither operator imposes no constraint. A
    field absent from the metadata is compared as ``None``.
    """
    payload = hit.get("payload") or {}
    metadata = payload.get("metadata") or {}

    for field_key, condition in metadata_filter.items():
        value = metadata.get(field_key)

        if not isinstance(condition, dict):
            if value != condition:
                return False
        elif "$eq" in condition:
            if value != condition["$eq"]:
                return False
        elif "$in" in condition:
            if value not in condition["$in"]:
                return False

    return True
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def test_different_filters():
    """Scroll the whole KB collection and count matches for several filters.

    Every point of the knowledge-base collection is read straight from
    Qdrant, its metadata is printed, and then each entry of ``test_cases``
    is applied client-side, reporting how many points satisfy it.

    The collection is read page by page (100 points per request) until the
    scroll API reports no further page, so the printed total covers the full
    collection rather than only its first page.
    """
    tenant_id = "szmp@ash@2026"
    kb_id = "30c19c84-8f69-4768-9d23-7f4a5bc3627a"

    # Read directly from Qdrant.
    client = await get_qdrant_client()
    # presumably the underlying Qdrant client behind the wrapper - confirm
    qdrant = await client.get_client()
    collection_name = client.get_kb_collection_name(tenant_id, kb_id)

    print("\n" + "=" * 60)
    print("测试不同过滤条件")
    print("=" * 60)
    print(f"集合: {collection_name}")
    print()

    # Fetch all points. scroll() yields (points, next_page_offset); the
    # offset is None once the last page has been returned.
    all_hits = []
    next_offset = None
    while True:
        points, next_offset = await qdrant.scroll(
            collection_name=collection_name,
            limit=100,
            offset=next_offset,
            with_payload=True,
        )
        all_hits.extend(
            {"id": str(point.id), "payload": point.payload or {}}
            for point in points
        )
        if next_offset is None:
            break

    print(f"总数据量: {len(all_hits)} 条")
    print()

    # Show the metadata of every point.
    print("所有数据的元数据:")
    for hit in all_hits:
        payload = hit["payload"]
        metadata = payload.get("metadata") or {}
        # `or ""` keeps the slice safe when the stored text is None.
        text = (payload.get("text") or "")[:50]
        print(f" - {metadata}: {text}...")

    # Filters to try.
    test_cases = [
        {"name": "初二", "filter": {"grade": {"$eq": "初二"}}},
        {"name": "痛点", "filter": {"kb_scene": {"$eq": "痛点"}}},
        {"name": "初二+痛点", "filter": {"grade": {"$eq": "初二"}, "kb_scene": {"$eq": "痛点"}}},
        {"name": "通用学科", "filter": {"subject": {"$eq": "通用"}}},
    ]

    print("\n过滤测试:")
    for case in test_cases:
        filter_def = case["filter"]

        filtered = [
            hit
            for hit in all_hits
            if _metadata_satisfies(hit["payload"].get("metadata") or {}, filter_def)
        ]

        print(f" {case['name']}: {len(filtered)}/{len(all_hits)} 条命中")


def _metadata_satisfies(metadata, filter_def):
    """Return True when *metadata* meets every condition in *filter_def*.

    ``{"$eq": v}`` requires equality with ``v``; ``{"$in": seq}`` requires
    membership in ``seq`` (only consulted when ``$eq`` is absent); any other
    condition, including a dict without these operators, is compared to the
    field value for plain equality. A missing field is compared as ``None``.
    """
    for field_key, condition in filter_def.items():
        value = metadata.get(field_key)

        if isinstance(condition, dict) and "$eq" in condition:
            if value != condition["$eq"]:
                return False
        elif isinstance(condition, dict) and "$in" in condition:
            if value not in condition["$in"]:
                return False
        elif value != condition:
            return False

    return True
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def main():
    """Print the suite banner, then run both KB filter checks in order.

    An exception raised by either check is caught here: its message and the
    formatted traceback are printed to stdout and the exception is not
    re-raised. A failure in the first check means the second is not run.
    """
    import traceback

    rule = "=" * 60
    print("\n" + rule)
    print("KB元数据过滤查询测试")
    print(rule + "\n")

    try:
        await test_kb_search_with_filter()
        await test_different_filters()
    except Exception as exc:
        print(f"\n测试失败: {exc}")
        print(traceback.format_exc())
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
|
|||
|
|
asyncio.run(main())
|