ai-robot-core/ai-service/scripts/test_kb_search.py

94 lines
3.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
测试 kb_search_dynamic 工具是否能用给定的参数查出数据
"""
import asyncio
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool
from app.core.config import get_settings
def _print_result(result):
    """Print a readable summary of one ``KbSearchDynamicTool.execute`` result.

    Reads ``success``, ``hits``, ``applied_filter``, ``fallback_reason_code`` and
    ``duration_ms`` from *result* and previews at most the first three hits.

    NOTE(review): hits are read with ``.get(...)``, so they are presumably dicts
    carrying ``text`` / ``score`` / ``metadata`` keys — confirm against the tool's
    result type.
    """
    # Treat a missing/None hit list as empty instead of crashing on len(None).
    hits = result.hits or []
    print("\n结果:")
    print(f" success: {result.success}")
    print(f" hits count: {len(hits)}")
    print(f" applied_filter: {result.applied_filter}")
    print(f" fallback_reason_code: {result.fallback_reason_code}")
    print(f" duration_ms: {result.duration_ms}")

    if not hits:
        print("\n ⚠️ 没有命中任何结果")
        return

    print("\n 前3条结果:")
    for i, hit in enumerate(hits[:3], 1):
        raw_text = hit.get('text') or ''
        # Only mark text as truncated when something was actually cut off.
        if not raw_text:
            text = 'N/A'
        elif len(raw_text) > 80:
            text = raw_text[:80] + '...'
        else:
            text = raw_text

        score = hit.get('score', 0)
        # A None / non-numeric score would make ':.4f' raise; fall back to str().
        try:
            score_str = f"{score:.4f}"
        except (TypeError, ValueError):
            score_str = str(score)

        metadata = hit.get('metadata', {})
        print(f" {i}. [score={score_str}] {text}")
        print(f" metadata: {metadata}")


async def test_kb_search():
    """Smoke-test knowledge-base search with several parameter sets.

    Opens a session on ``settings.database_url``, runs
    ``KbSearchDynamicTool.execute`` once per test case and prints the outcome.
    A failure in one case is reported with a traceback and does not stop the
    remaining cases.

    NOTE(review): the ``test_`` prefix means pytest may try to collect this
    coroutine; it needs a reachable database and is meant to be run as a script.
    """
    import traceback

    settings = get_settings()
    # Create the database engine and a session factory bound to it.
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    tenant_id = "szmp@ash@2026"
    query = "三年级语文学习"
    # Each case narrows the parameters further to see which ones affect hits.
    test_cases = [
        {
            "name": "完整参数含context过滤",
            "params": {
                "query": query,
                "tenant_id": tenant_id,
                "scene": "学习方案",
                "top_k": 5,
                "context": {"grade": "三年级", "subject": "语文"},
            },
        },
        {
            "name": "简化参数无context",
            "params": {
                "query": query,
                "tenant_id": tenant_id,
                "scene": "学习方案",
                "top_k": 5,
            },
        },
        {
            "name": "仅query和tenant_id",
            "params": {
                "query": query,
                "tenant_id": tenant_id,
                "top_k": 5,
            },
        },
    ]
    separator = '=' * 80

    try:
        async with async_session() as session:
            tool = KbSearchDynamicTool(session=session)
            for test_case in test_cases:
                print(f"\n{separator}")
                print(f"测试: {test_case['name']}")
                print(separator)
                print(f"参数: {test_case['params']}")
                try:
                    result = await tool.execute(**test_case['params'])
                    _print_result(result)
                except Exception as e:
                    print(f"\n ❌ 错误: {e}")
                    traceback.print_exc()
                    # The session is shared across cases. If the failure came
                    # from the database, its transaction is presumably unusable
                    # now, so roll back to let the next case start cleanly.
                    await session.rollback()
    finally:
        # Close pooled connections explicitly before the event loop shuts down.
        await engine.dispose()
if __name__ == "__main__":
    # Script entry point: build the smoke-test coroutine, then drive it to
    # completion on a fresh event loop.
    _entry = test_kb_search()
    asyncio.run(_entry)