167 lines
5.5 KiB
Python
167 lines
5.5 KiB
Python
"""
|
||
测试KB元数据过滤查询
|
||
"""
|
||
import asyncio
|
||
import json
|
||
from app.core.database import async_session_maker
|
||
from app.services.mid.metadata_filter_builder import MetadataFilterBuilder
|
||
from app.services.mid.default_kb_tool_runner import DefaultKbToolRunner
|
||
from app.core.qdrant_client import get_qdrant_client
|
||
|
||
|
||
async def test_metadata_filter():
    """Build a metadata filter for a sample query context and print the result.

    Manual smoke check: it prints its findings and asserts nothing. It opens
    a DB session, asks ``MetadataFilterBuilder.build_filter`` for a filter
    for a fixed tenant/context, prints the returned fields, then prints the
    tenant's filterable-field schema from ``get_filter_schema``.
    """
    tenant = "szmp@ash@2026"

    # Sample context modelled on the query "初二数学痛点" (grade-8 math pain
    # points). NOTE(review): the subject slot is "通用" (general) rather than
    # "数学" (math) — confirm this is intentional.
    context = {
        "grade": "初二",
        "subject": "通用",
        "kb_scene": "痛点",
    }

    rule = "=" * 60
    header = (
        rule,
        "测试元数据过滤器构建",
        rule,
        f"租户: {tenant}",
        f"查询上下文: {json.dumps(context, ensure_ascii=False)}",
    )
    for line in header:
        print(line)
    print()

    async with async_session_maker() as session:
        builder = MetadataFilterBuilder(session)

        # Step 1: build the filter and report each field of the result.
        outcome = await builder.build_filter(tenant, context)
        print("过滤器构建结果:")
        print(f" 成功: {outcome.success}")
        print(f" 应用的过滤器: {json.dumps(outcome.applied_filter, ensure_ascii=False, indent=2)}")
        print(f" 缺失的必填字段: {outcome.missing_required_slots}")
        print(f" 调试信息: {json.dumps(outcome.debug_info, ensure_ascii=False, indent=2)}")
        print()

        # Step 2: list the filterable fields configured for this tenant;
        # each entry is indexed as a mapping by the keys read below.
        schema = await builder.get_filter_schema(tenant)
        print("可过滤字段配置:")
        for spec in schema:
            key, label, kind, required = (
                spec[name] for name in ("field_key", "label", "type", "required")
            )
            print(f" - {key}: {label} (类型: {kind}, 必填: {required})")
            options = spec["options"]
            if options:
                print(f" 选项: {options}")
        print()
|
||
|
||
|
||
async def test_kb_search():
    """Run a filtered KB vector search for a sample query and print the hits.

    Manual smoke check: it prints its findings and asserts nothing. The steps:

    1. Build a metadata filter for a fixed tenant/context.
    2. Configure a ``DefaultKbToolRunner`` with a 10 s timeout.
    3. List the tenant's knowledge bases.
    4. Run ``execute`` with the built filter and print the hits.

    Returns early (after printing a hint) when the tenant has no knowledge base.
    """
    tenant_id = "szmp@ash@2026"

    # Sample query: "what difficulties do grade-8 students have learning math".
    query = "初二学生数学学习有什么困难"

    # Matching context: grade 8 (初二), math (数学), pain points (痛点).
    context = {
        "grade": "初二",
        "subject": "数学",
        "kb_scene": "痛点",
    }

    # Number of characters of each hit's text shown in the report.
    preview_chars = 200

    print("=" * 60)
    print("测试KB向量检索(带元数据过滤)")
    print("=" * 60)
    print(f"租户: {tenant_id}")
    # The knowledge base under test is chosen from the tenant's KB listing
    # below and printed there; no hard-coded placeholder id is reported here.
    print(f"查询: {query}")
    print(f"上下文: {json.dumps(context, ensure_ascii=False)}")
    print()

    async with async_session_maker() as session:
        # 1. Build the metadata filter first.
        filter_builder = MetadataFilterBuilder(session)
        filter_result = await filter_builder.build_filter(tenant_id, context)

        print(f"过滤器: {json.dumps(filter_result.applied_filter, ensure_ascii=False)}")
        print()

        # 2. Configure the KB runner with a longer timeout than usual.
        from app.services.mid.timeout_governor import TimeoutGovernor
        from app.services.mid.default_kb_tool_runner import KbToolConfig

        config = KbToolConfig(
            enabled=True,
            top_k=5,
            timeout_ms=10000,  # 10 s timeout
            min_score_threshold=0.5,
        )
        kb_runner = DefaultKbToolRunner(
            timeout_governor=TimeoutGovernor(),
            config=config,
        )

        # 3. List the tenant's knowledge bases; without one there is nothing to search.
        from app.services.knowledge_base_service import KnowledgeBaseService

        kb_service = KnowledgeBaseService(session)
        kbs = await kb_service.list_knowledge_bases(tenant_id)

        if not kbs:
            print("未找到知识库,请先创建知识库并上传文档")
            return

        print(f"找到 {len(kbs)} 个知识库:")
        for kb in kbs:
            print(f" - {kb.name} (ID: {kb.id})")
        print()

        # Report the first knowledge base as the one under test.
        # NOTE(review): this id is not passed to kb_runner.execute() below, so
        # the search scope is whatever execute() derives from tenant_id —
        # confirm against DefaultKbToolRunner.execute's signature.
        test_kb_id = str(kbs[0].id)
        print(f"使用知识库: {kbs[0].name} (ID: {test_kb_id})")
        print()

        # 4. Run the filtered search.
        result = await kb_runner.execute(
            tenant_id=tenant_id,
            query=query,
            metadata_filter=filter_result.applied_filter,
        )

        print("检索结果:")
        print(f" 成功: {result.success}")
        print(f" 命中数: {len(result.hits)}")
        print(f" 回退原因: {result.fallback_reason_code}")
        print()

        if result.hits:
            print("命中文档:")
            for i, hit in enumerate(result.hits, 1):
                # Only mark the preview as truncated when text was actually cut.
                suffix = "..." if len(hit.text) > preview_chars else ""
                print(f"\n [{i}] 分数: {hit.score:.4f}")
                print(f" 内容: {hit.text[:preview_chars]}{suffix}")
                print(f" 元数据: {json.dumps(hit.metadata, ensure_ascii=False)}")
        else:
            print("未命中任何文档")
            print("\n可能原因:")
            print(" 1. 知识库中没有匹配的文档")
            print(" 2. 元数据过滤器过于严格")
            print(" 3. 向量相似度阈值过高")
|
||
|
||
|
||
async def main():
    """Run both KB metadata-filter checks in order and report failures.

    Runs ``test_metadata_filter`` and then ``test_kb_search``. Any exception
    from either check is caught at this top-level boundary, and its message
    and traceback are printed.

    Returns:
        0 when both checks complete without raising, 1 otherwise. The
        ``__main__`` entry point below uses this as the process exit status.
    """
    print("\n" + "=" * 60)
    print("KB元数据过滤查询测试")
    print("=" * 60 + "\n")

    try:
        # Check 1: metadata filter construction.
        await test_metadata_filter()

        print("\n" + "=" * 60 + "\n")

        # Check 2: vector search with the metadata filter applied.
        await test_kb_search()

    except Exception as e:
        # Previously the failure was only printed and the script still exited
        # with status 0; now the failure is also signalled via the return value.
        print(f"\n测试失败: {e}")
        import traceback
        print(traceback.format_exc())
        return 1
    return 0


if __name__ == "__main__":
    # SystemExit(0) / SystemExit(1) becomes the process exit status, so
    # shells and CI can detect a failed run.
    raise SystemExit(asyncio.run(main()))
|