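"""
测试KB元数据过滤查询

Manual test script for knowledge-base (KB) metadata-filtered retrieval:
builds a metadata filter for a sample context, then runs a filtered vector
search against the tenant's first knowledge base.
"""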
|
|||
|
|
import asyncio
|
|||
|
|
import json
|
|||
|
|
from app.core.database import async_session_maker
|
|||
|
|
from app.services.mid.metadata_filter_builder import MetadataFilterBuilder
|
|||
|
|
from app.services.mid.default_kb_tool_runner import DefaultKbToolRunner
|
|||
|
|
from app.core.qdrant_client import get_qdrant_client
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def test_metadata_filter():
    """Build a metadata filter for a sample context and print the outcome.

    After printing the build result (success flag, applied filter, missing
    required slots, debug info), it also prints the tenant's filterable-field
    schema so the two can be compared by eye.
    """
    tenant = "szmp@ash@2026"
    divider = "=" * 60

    # Sample context mimicking a user query for "初二数学痛点"
    # (grade-8 math pain points).
    # NOTE(review): the query mentions math, but "subject" is "通用" (general)
    # here while test_kb_search uses "数学" — confirm this is intended.
    sample_context = {
        "grade": "初二",
        "subject": "通用",
        "kb_scene": "痛点",
    }

    for banner_line in (divider, "测试元数据过滤器构建", divider):
        print(banner_line)
    print(f"租户: {tenant}")
    print(f"查询上下文: {json.dumps(sample_context, ensure_ascii=False)}")
    print()

    async with async_session_maker() as db:
        # Step 1: build the filter from the sample context.
        builder = MetadataFilterBuilder(db)
        outcome = await builder.build_filter(tenant, sample_context)

        print("过滤器构建结果:")
        print(f" 成功: {outcome.success}")
        applied_json = json.dumps(outcome.applied_filter, ensure_ascii=False, indent=2)
        print(f" 应用的过滤器: {applied_json}")
        print(f" 缺失的必填字段: {outcome.missing_required_slots}")
        debug_json = json.dumps(outcome.debug_info, ensure_ascii=False, indent=2)
        print(f" 调试信息: {debug_json}")
        print()

        # Step 2: list the fields this tenant can filter on.
        schema = await builder.get_filter_schema(tenant)
        print("可过滤字段配置:")
        for entry in schema:
            print(
                f" - {entry['field_key']}: {entry['label']} "
                f"(类型: {entry['type']}, 必填: {entry['required']})"
            )
            choices = entry['options']
            if choices:
                print(f" 选项: {choices}")
        print()
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def test_kb_search():
    """Run a metadata-filtered KB vector search and print the hits.

    Steps:
      1. Build a metadata filter from a sample query context.
      2. Configure a ``DefaultKbToolRunner`` with a 10 s timeout.
      3. List the tenant's knowledge bases; return early with a hint if none.
      4. Execute the search with the applied filter and print each hit.
    """
    tenant_id = "szmp@ash@2026"

    # Max number of characters of each hit's text shown in the output.
    preview_limit = 200

    # Sample query: "what difficulties do grade-8 students have learning math".
    query = "初二学生数学学习有什么困难"

    # Context the metadata filter is built from.
    context = {
        "grade": "初二",
        "subject": "数学",
        "kb_scene": "痛点"
    }

    # The knowledge base is discovered below (first KB of the tenant), so no
    # hard-coded placeholder KB id is printed here.
    print("=" * 60)
    print("测试KB向量检索(带元数据过滤)")
    print("=" * 60)
    print(f"租户: {tenant_id}")
    print(f"查询: {query}")
    print(f"上下文: {json.dumps(context, ensure_ascii=False)}")
    print()

    async with async_session_maker() as session:
        # 1. Build the metadata filter first.
        filter_builder = MetadataFilterBuilder(session)
        filter_result = await filter_builder.build_filter(tenant_id, context)

        print(f"过滤器: {json.dumps(filter_result.applied_filter, ensure_ascii=False)}")
        print()

        # 2. Configure the runner with a longer timeout for this manual run.
        from app.services.mid.timeout_governor import TimeoutGovernor
        from app.services.mid.default_kb_tool_runner import KbToolConfig

        config = KbToolConfig(
            enabled=True,
            top_k=5,
            timeout_ms=10000,  # 10 s timeout
            min_score_threshold=0.5,
        )
        kb_runner = DefaultKbToolRunner(
            timeout_governor=TimeoutGovernor(),
            config=config,
        )

        # 3. List the tenant's knowledge bases; nothing to search without one.
        from app.services.knowledge_base_service import KnowledgeBaseService

        kb_service = KnowledgeBaseService(session)
        kbs = await kb_service.list_knowledge_bases(tenant_id)

        if not kbs:
            print("未找到知识库,请先创建知识库并上传文档")
            return

        print(f"找到 {len(kbs)} 个知识库:")
        for kb in kbs:
            print(f" - {kb.name} (ID: {kb.id})")
        print()

        # Use the first knowledge base for the test.
        test_kb_id = str(kbs[0].id)
        print(f"使用知识库: {kbs[0].name} (ID: {test_kb_id})")
        print()

        # 4. Execute the search.
        # NOTE(review): test_kb_id is only printed; it is not passed to
        # kb_runner.execute(), so the search scope is whatever the runner
        # uses for the tenant. If execute() accepts a KB id, pass it here.
        result = await kb_runner.execute(
            tenant_id=tenant_id,
            query=query,
            metadata_filter=filter_result.applied_filter
        )

        print("检索结果:")
        print(f" 成功: {result.success}")
        print(f" 命中数: {len(result.hits)}")
        print(f" 回退原因: {result.fallback_reason_code}")
        print()

        if result.hits:
            print("命中文档:")
            for i, hit in enumerate(result.hits, 1):
                print(f"\n [{i}] 分数: {hit.score:.4f}")
                # Only append an ellipsis when the preview was actually cut.
                text = hit.text
                ellipsis = "..." if len(text) > preview_limit else ""
                print(f" 内容: {text[:preview_limit]}{ellipsis}")
                print(f" 元数据: {json.dumps(hit.metadata, ensure_ascii=False)}")
        else:
            print("未命中任何文档")
            print("\n可能原因:")
            print(" 1. 知识库中没有匹配的文档")
            print(" 2. 元数据过滤器过于严格")
            print(" 3. 向量相似度阈值过高")
|
|||
|
|
|
|||
|
|
|
|||
|
|
async def main():
    """Run both KB metadata-filter tests in sequence.

    Prints a banner, then awaits ``test_metadata_filter`` followed by
    ``test_kb_search``. If either raises an ``Exception``, the error message
    and traceback are printed and the process exits with status 1.

    Raises:
        SystemExit: with code 1 when a test raises an ``Exception``.
    """
    print("\n" + "=" * 60)
    print("KB元数据过滤查询测试")
    print("=" * 60 + "\n")

    try:
        # Test 1: filter construction.
        await test_metadata_filter()

        print("\n" + "=" * 60 + "\n")

        # Test 2: vector search.
        await test_kb_search()

    except Exception as e:
        import traceback

        print(f"\n测试失败: {e}")
        print(traceback.format_exc())
        # Exit non-zero. Otherwise a crashed run would look successful to the
        # shell/CI, because the exception is otherwise fully handled here.
        raise SystemExit(1) from e
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Script entry point: drive the async test runner on a fresh event loop.
    entry_coro = main()
    asyncio.run(entry_coro)
|