""" 测试KB元数据过滤查询(使用正确的知识库集合) """ import asyncio import json from app.core.database import async_session_maker from app.core.qdrant_client import get_qdrant_client from app.services.embedding import get_embedding_provider async def test_kb_search_with_filter(): """测试带元数据过滤的KB检索""" tenant_id = "szmp@ash@2026" kb_id = "30c19c84-8f69-4768-9d23-7f4a5bc3627a" # 客服咨询知识库 query = "初二学生学习困难" # 测试过滤器 metadata_filter = { "grade": {"$eq": "初二"}, "kb_scene": {"$eq": "痛点"} } print("=" * 60) print("测试带元数据过滤的KB检索") print("=" * 60) print(f"租户: {tenant_id}") print(f"知识库: {kb_id}") print(f"查询: {query}") print(f"过滤器: {json.dumps(metadata_filter, ensure_ascii=False)}") print() # 1. 获取 embedding print("[1] 生成查询向量...") embedding_provider = await get_embedding_provider() query_vector = await embedding_provider.embed(query) print(f" 向量维度: {len(query_vector)}") print() # 2. 搜索 Qdrant print("[2] 搜索 Qdrant...") client = await get_qdrant_client() # 使用知识库特定集合 collection_name = client.get_kb_collection_name(tenant_id, kb_id) print(f" 集合: {collection_name}") # 先获取所有数据(不带过滤) all_results = await client.search( tenant_id=tenant_id, # 这里会被转换为集合名 query_vector=query_vector, limit=10, score_threshold=0.01, ) print(f" 原始命中: {len(all_results)} 条") print() # 3. 应用元数据过滤 print("[3] 应用元数据过滤...") filtered_results = [] for hit in all_results: payload = hit.get("payload", {}) hit_metadata = payload.get("metadata", {}) match = True for field_key, condition in metadata_filter.items(): hit_value = hit_metadata.get(field_key) if isinstance(condition, dict): if "$eq" in condition: if hit_value != condition["$eq"]: match = False break elif "$in" in condition: if hit_value not in condition["$in"]: match = False break else: if hit_value != condition: match = False break if match: filtered_results.append(hit) print(f" 过滤后命中: {len(filtered_results)} 条") print() # 4. 显示结果 print("=" * 60) print("检索结果") print("=" * 60) if filtered_results: for i, hit in enumerate(filtered_results, 1): payload = hit.get("payload", {}) metadata = payload.get("metadata", {}) text = payload.get("text", "") score = hit.get("score", 0) print(f"\n[{i}] 相似度: {score:.4f}") print(f" 年级: {metadata.get('grade', 'N/A')}") print(f" 学科: {metadata.get('subject', 'N/A')}") print(f" 场景: {metadata.get('kb_scene', 'N/A')}") print(f" 内容: {text}") else: print("\n未命中任何文档") print("\n可能原因:") print(" 1. 向量相似度太低") print(" 2. 元数据不匹配") print(" 3. 数据不在主集合中") # 显示原始命中(用于调试) print("\n原始命中(未过滤前):") for i, hit in enumerate(all_results[:3], 1): payload = hit.get("payload", {}) metadata = payload.get("metadata", {}) text = payload.get("text", "")[:60] score = hit.get("score", 0) print(f" [{i}] 分数:{score:.3f} 元数据:{metadata} {text}...") async def test_different_filters(): """测试不同过滤条件""" tenant_id = "szmp@ash@2026" kb_id = "30c19c84-8f69-4768-9d23-7f4a5bc3627a" # 从 Qdrant 直接获取所有数据 client = await get_qdrant_client() qdrant = await client.get_client() collection_name = client.get_kb_collection_name(tenant_id, kb_id) print("\n" + "=" * 60) print("测试不同过滤条件") print("=" * 60) print(f"集合: {collection_name}") print() # 获取所有数据 results = await qdrant.scroll( collection_name=collection_name, limit=100, with_payload=True, ) all_hits = [] for point in results[0]: all_hits.append({ "id": str(point.id), "payload": point.payload or {} }) print(f"总数据量: {len(all_hits)} 条") print() # 显示所有数据的元数据 print("所有数据的元数据:") for hit in all_hits: payload = hit["payload"] metadata = payload.get("metadata", {}) text = payload.get("text", "")[:50] print(f" - {metadata}: {text}...") # 测试各种过滤条件 test_cases = [ {"name": "初二", "filter": {"grade": {"$eq": "初二"}}}, {"name": "痛点", "filter": {"kb_scene": {"$eq": "痛点"}}}, {"name": "初二+痛点", "filter": {"grade": {"$eq": "初二"}, "kb_scene": {"$eq": "痛点"}}}, {"name": "通用学科", "filter": {"subject": {"$eq": "通用"}}}, ] print("\n过滤测试:") for case in test_cases: filter_def = case["filter"] # 应用过滤 filtered = [] for hit in all_hits: payload = hit["payload"] metadata = payload.get("metadata", {}) match = True for field_key, condition in filter_def.items(): hit_value = metadata.get(field_key) if isinstance(condition, dict) and "$eq" in condition: if hit_value != condition["$eq"]: match = False break else: if hit_value != condition: match = False break if match: filtered.append(hit) print(f" {case['name']}: {len(filtered)}/{len(all_hits)} 条命中") async def main(): print("\n" + "=" * 60) print("KB元数据过滤查询测试") print("=" * 60 + "\n") try: await test_kb_search_with_filter() await test_different_filters() except Exception as e: print(f"\n测试失败: {e}") import traceback print(traceback.format_exc()) if __name__ == "__main__": asyncio.run(main())