"""
测试 kb_search_dynamic 工具是否能用给定的参数查出数据

Check whether the kb_search_dynamic tool returns data for the given parameters.
"""
import asyncio
|
|||
|
|
import sys
|
|||
|
|
from pathlib import Path
|
|||
|
|
|
|||
|
|
# Put the directory two levels above this file (presumably the project root)
# first on sys.path so the `app.*` imports below resolve when the script is
# run directly rather than as an installed package.
sys.path.insert(0, str(Path(__file__).parent.parent))
|
|||
|
|
|
|||
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
|||
|
|
from sqlalchemy.orm import sessionmaker
|
|||
|
|
from app.services.mid.kb_search_dynamic_tool import KbSearchDynamicTool
|
|||
|
|
from app.core.config import get_settings
|
|||
|
|
|
|||
|
|
|
|||
|
|
def _shorten_text(text, limit=80):
    """Return a short console preview of a hit's text.

    Returns ``'N/A'`` when *text* is missing or empty. Otherwise returns the
    first *limit* characters, appending ``'...'`` only when something was cut.
    """
    if not text:
        return 'N/A'
    if len(text) <= limit:
        return text
    return text[:limit] + '...'


def _print_result(result):
    """Print the summary fields of one search result and up to three hits.

    Reads the ``success``, ``hits``, ``applied_filter``,
    ``fallback_reason_code`` and ``duration_ms`` attributes of *result*.
    Each hit is read with ``.get``; presumably it is a dict with optional
    ``text`` / ``score`` / ``metadata`` keys (confirm against the tool).
    """
    print("\n结果:")
    print(f" success: {result.success}")
    print(f" hits count: {len(result.hits)}")
    print(f" applied_filter: {result.applied_filter}")
    print(f" fallback_reason_code: {result.fallback_reason_code}")
    print(f" duration_ms: {result.duration_ms}")

    if not result.hits:
        print("\n ⚠️ 没有命中任何结果")
        return

    print("\n 前3条结果:")
    for i, hit in enumerate(result.hits[:3], 1):
        text = _shorten_text(hit.get('text'))
        score = hit.get('score', 0)
        metadata = hit.get('metadata', {})
        print(f" {i}. [score={score:.4f}] {text}")
        print(f" metadata: {metadata}")


async def test_kb_search():
    """Smoke-test KbSearchDynamicTool against the configured database.

    Runs three parameter sets through ``KbSearchDynamicTool.execute`` and
    prints each result:

    - with a ``context`` filter;
    - without ``context``;
    - with only ``query`` / ``tenant_id`` / ``top_k``.

    A failing case is reported with its traceback and does not stop the
    remaining cases. Requires a reachable database at
    ``get_settings().database_url``.
    """
    import traceback

    settings = get_settings()

    # Create the database engine and session factory.
    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    try:
        async with async_session() as session:
            tool = KbSearchDynamicTool(session=session)

            # Test parameter sets, from most to least specific.
            test_cases = [
                {
                    "name": "完整参数(含context过滤)",
                    "params": {
                        "query": "三年级语文学习",
                        "tenant_id": "szmp@ash@2026",
                        "scene": "学习方案",
                        "top_k": 5,
                        "context": {"grade": "三年级", "subject": "语文"},
                    },
                },
                {
                    "name": "简化参数(无context)",
                    "params": {
                        "query": "三年级语文学习",
                        "tenant_id": "szmp@ash@2026",
                        "scene": "学习方案",
                        "top_k": 5,
                    },
                },
                {
                    "name": "仅query和tenant_id",
                    "params": {
                        "query": "三年级语文学习",
                        "tenant_id": "szmp@ash@2026",
                        "top_k": 5,
                    },
                },
            ]

            for test_case in test_cases:
                print(f"\n{'='*80}")
                print(f"测试: {test_case['name']}")
                print(f"{'='*80}")
                print(f"参数: {test_case['params']}")

                try:
                    result = await tool.execute(**test_case['params'])
                    _print_result(result)
                except Exception as e:
                    print(f"\n ❌ 错误: {e}")
                    traceback.print_exc()
                    # A failed statement can leave the shared session's
                    # transaction unusable on some backends (e.g. PostgreSQL
                    # "current transaction is aborted"). Roll back so later
                    # cases start clean. Guard the rollback itself so one bad
                    # case never aborts the whole run.
                    try:
                        await session.rollback()
                    except Exception:
                        traceback.print_exc()
    finally:
        # Release pooled connections while the event loop is still running.
        # Otherwise they are only cleaned up after asyncio.run() has closed
        # the loop.
        await engine.dispose()
|
|||
|
|
|
|||
|
|
|
|||
|
|
if __name__ == "__main__":
    # Script entry point: build the smoke-test coroutine and drive it to
    # completion on a fresh event loop.
    main_coroutine = test_kb_search()
    asyncio.run(main_coroutine)
|