# ai-robot-core/ai-service/test_kb_metadata_search.py
#
# 167 lines
# 5.5 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试KB元数据过滤查询
"""
import asyncio
import json
from app.core.database import async_session_maker
from app.services.mid.metadata_filter_builder import MetadataFilterBuilder
from app.services.mid.default_kb_tool_runner import DefaultKbToolRunner
from app.core.qdrant_client import get_qdrant_client
async def test_metadata_filter():
    """Build a metadata filter for a sample context and print the outcome.

    Opens a database session, asks ``MetadataFilterBuilder.build_filter`` to
    build a filter for a hard-coded tenant and context, and prints the
    result's ``success``, ``applied_filter``, ``missing_required_slots`` and
    ``debug_info``. It then prints every entry returned by
    ``get_filter_schema`` for the same tenant.

    Nothing is asserted or returned; the output is meant to be read by a
    person running the script.
    """
    tenant_id = "szmp@ash@2026"
    # Sample context simulating the user query "初二数学痛点".
    # NOTE(review): the simulated query mentions 数学, but ``subject`` is
    # "通用" here — confirm this is intended.
    sample_context = dict(grade="初二", subject="通用", kb_scene="痛点")

    rule = "=" * 60
    banner_lines = (
        rule,
        "测试元数据过滤器构建",
        rule,
        f"租户: {tenant_id}",
        f"查询上下文: {json.dumps(sample_context, ensure_ascii=False)}",
    )
    for line in banner_lines:
        print(line)
    print()

    async with async_session_maker() as db_session:
        # 1. Build the filter from the sample context and show each result field.
        builder = MetadataFilterBuilder(db_session)
        outcome = await builder.build_filter(tenant_id, sample_context)
        print("过滤器构建结果:")
        print(f" 成功: {outcome.success}")
        print(f" 应用的过滤器: {json.dumps(outcome.applied_filter, ensure_ascii=False, indent=2)}")
        print(f" 缺失的必填字段: {outcome.missing_required_slots}")
        print(f" 调试信息: {json.dumps(outcome.debug_info, ensure_ascii=False, indent=2)}")
        print()

        # 2. List the filterable fields configured for this tenant.
        schema = await builder.get_filter_schema(tenant_id)
        print("可过滤字段配置:")
        for spec in schema:
            print(f" - {spec['field_key']}: {spec['label']} (类型: {spec['type']}, 必填: {spec['required']})")
            options = spec['options']
            # Only fields with a non-empty options value get the extra line.
            if options:
                print(f" 选项: {options}")
        print()
async def test_kb_search():
    """Run a KB retrieval with a metadata filter and print the hits.

    Steps, all inside one database session:

    1. Build a metadata filter from a hard-coded context.
    2. Create a ``DefaultKbToolRunner`` with an explicit ``KbToolConfig``.
    3. List the tenant's knowledge bases; stop early if there are none.
    4. Call ``execute`` with the tenant, query and built filter.
    5. Print the result summary and each hit (score, first 200 characters of
       text, metadata), or a list of likely reasons when nothing matched.

    Nothing is asserted or returned; the output is meant to be read by a
    person running the script.
    """
    tenant_id = "szmp@ash@2026"
    # Placeholder: must be replaced with a real knowledge base ID.
    # NOTE(review): this value is only echoed in the banner below; the
    # ``execute`` call in this function is not given any KB ID — confirm
    # whether it should be.
    kb_id = "your_kb_id"
    # Test query.
    query = "初二学生数学学习有什么困难"
    # Test context used to build the metadata filter.
    search_context = dict(grade="初二", subject="数学", kb_scene="痛点")

    rule = "=" * 60
    banner_lines = (
        rule,
        "测试KB向量检索带元数据过滤",
        rule,
        f"租户: {tenant_id}",
        f"知识库: {kb_id}",
        f"查询: {query}",
        f"上下文: {json.dumps(search_context, ensure_ascii=False)}",
    )
    for line in banner_lines:
        print(line)
    print()

    async with async_session_maker() as db_session:
        # 1. Build the filter first.
        builder = MetadataFilterBuilder(db_session)
        built = await builder.build_filter(tenant_id, search_context)
        print(f"过滤器: {json.dumps(built.applied_filter, ensure_ascii=False)}")
        print()

        # Prepare the retrieval runner with a longer timeout.
        from app.services.mid.timeout_governor import TimeoutGovernor
        from app.services.mid.default_kb_tool_runner import KbToolConfig

        runner_config = KbToolConfig(
            enabled=True,
            top_k=5,
            timeout_ms=10000,  # 10-second timeout
            min_score_threshold=0.5,
        )
        runner = DefaultKbToolRunner(
            timeout_governor=TimeoutGovernor(),
            config=runner_config,
        )

        # Fetch the knowledge bases available to this tenant.
        from app.services.knowledge_base_service import KnowledgeBaseService

        kb_lookup = KnowledgeBaseService(db_session)
        available_kbs = await kb_lookup.list_knowledge_bases(tenant_id)
        if not available_kbs:
            print("未找到知识库,请先创建知识库并上传文档")
            return

        print(f"找到 {len(available_kbs)} 个知识库:")
        for entry in available_kbs:
            print(f" - {entry.name} (ID: {entry.id})")
        print()

        # Use the first knowledge base for the test.
        # NOTE(review): ``test_kb_id`` is printed but not passed to
        # ``execute`` below — confirm how the runner selects a KB.
        chosen = available_kbs[0]
        test_kb_id = str(chosen.id)
        print(f"使用知识库: {chosen.name} (ID: {test_kb_id})")
        print()

        # Run the retrieval.
        search = await runner.execute(
            tenant_id=tenant_id,
            query=query,
            metadata_filter=built.applied_filter,
        )
        print("检索结果:")
        print(f" 成功: {search.success}")
        print(f" 命中数: {len(search.hits)}")
        print(f" 回退原因: {search.fallback_reason_code}")
        print()

        # No hits: print the likely causes and stop.
        if not search.hits:
            print("未命中任何文档")
            print("\n可能原因:")
            print(" 1. 知识库中没有匹配的文档")
            print(" 2. 元数据过滤器过于严格")
            print(" 3. 向量相似度阈值过高")
            return

        print("命中文档:")
        for rank, hit in enumerate(search.hits, start=1):
            print(f"\n [{rank}] 分数: {hit.score:.4f}")
            # Show only the first 200 characters of the hit text.
            print(f" 内容: {hit.text[:200]}...")
            print(f" 元数据: {json.dumps(hit.metadata, ensure_ascii=False)}")
async def main():
    """Run both manual tests in order and print a traceback on failure.

    Prints a title banner, awaits ``test_metadata_filter`` and then
    ``test_kb_search`` with a separator between them. Any ``Exception`` raised
    by either test is caught here: its message and full traceback are printed
    and the coroutine returns normally instead of re-raising.
    """
    import traceback

    rule = "=" * 60
    print(f"\n{rule}")
    print("KB元数据过滤查询测试")
    print(f"{rule}\n")

    try:
        # Test 1: filter construction.
        await test_metadata_filter()
        print(f"\n{rule}\n")
        # Test 2: vector retrieval.
        await test_kb_search()
    except Exception as exc:
        # Top-level boundary of a manual script: report the failure rather
        # than propagating it.
        print(f"\n测试失败: {exc}")
        print(traceback.format_exc())
# Script entry point: run the async test sequence only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    asyncio.run(main())