121 lines
4.5 KiB
Python
121 lines
4.5 KiB
Python
|
|
"""
|
||
|
|
测试 kb_search_dynamic 工具 - 课程咨询场景
|
||
|
|
"""
|
||
|
|
|
||
|
|
import asyncio
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
# Prepend the parent of this file's directory (presumably the project root)
# to sys.path; this runs before the `app.*` imports that follow it.
_project_root = Path(__file__).parent.parent
sys.path.insert(0, str(_project_root))
|
||
|
|
|
||
|
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||
|
|
from sqlalchemy.orm import sessionmaker
|
||
|
|
from app.services.mid.kb_search_dynamic_tool import (
|
||
|
|
KbSearchDynamicTool,
|
||
|
|
KbSearchDynamicConfig,
|
||
|
|
StepKbConfig,
|
||
|
|
)
|
||
|
|
from app.core.config import get_settings
|
||
|
|
|
||
|
|
|
||
|
|
async def test_kb_search() -> None:
    """Smoke-test knowledge-base search for the course-consultation scenario.

    Creates an async SQLAlchemy engine from ``get_settings().database_url``,
    runs ``KbSearchDynamicTool.execute`` restricted to a single course
    knowledge base, and prints the parameters, result summary, tool trace
    and hits to stdout.

    Any exception raised while executing the tool or printing its result is
    caught, reported and its traceback printed instead of propagating. This
    is a diagnostic script, not an assertion-based test. The engine is always
    disposed before returning, so its connection pool is not left open when
    the event loop shuts down.
    """
    settings = get_settings()

    engine = create_async_engine(settings.database_url)
    async_session = sessionmaker(engine, class_=AsyncSession, expire_on_commit=False)

    try:
        async with async_session() as session:
            config = KbSearchDynamicConfig(
                enabled=True,
                top_k=10,
                timeout_ms=15000,
                min_score_threshold=0.3,
            )
            tool = KbSearchDynamicTool(session=session, config=config)

            course_kb_id = "75c465fe-277d-455d-a30b-4b168adcc03b"

            # Restrict (and prefer) retrieval to the single course KB under test.
            step_kb_config = StepKbConfig(
                allowed_kb_ids=[course_kb_id],
                preferred_kb_ids=[course_kb_id],
                step_id="test_course_query",
            )

            test_params = {
                "query": "课程介绍",
                "tenant_id": "szmp@ash@2026",
                "top_k": 10,
                "context": {
                    "grade": "五年级",
                },
                "step_kb_config": step_kb_config,
            }

            _print_params(test_params, step_kb_config, config)

            try:
                result = await tool.execute(**test_params)
                _print_result(result)
            except Exception as e:  # diagnostic boundary: report and keep going
                import traceback

                print(f"\n ❌ 错误: {e}")
                traceback.print_exc()
    finally:
        # Release pooled connections while the event loop is still running.
        await engine.dispose()


def _print_params(test_params: dict, step_kb_config, config) -> None:
    """Print the search parameters and tool configuration being exercised."""
    print(f"\n{'=' * 80}")
    print("测试: kb_search_dynamic - 课程知识库")
    print(f"{'=' * 80}")
    print(f"参数: query={test_params['query']}")
    print(f" tenant_id={test_params['tenant_id']}")
    print(f" context={test_params['context']}")
    print(f" step_kb_config.allowed_kb_ids={step_kb_config.allowed_kb_ids}")
    print(f"超时设置: {config.timeout_ms}ms")
    print(f"最低分数阈值: {config.min_score_threshold}")


def _print_result(result) -> None:
    """Print a search result's summary fields, optional debug info, trace and hits."""
    # Treat a missing/None hit list as empty instead of crashing on len().
    hits = result.hits or []

    print("\n结果:")
    print(f" success: {result.success}")
    print(f" hits count: {len(hits)}")
    print(f" applied_filter: {result.applied_filter}")
    print(f" fallback_reason_code: {result.fallback_reason_code}")
    print(f" duration_ms: {result.duration_ms}")

    if result.filter_debug:
        print(f" filter_debug: {result.filter_debug}")

    if result.step_kb_binding:
        print(f" step_kb_binding: {result.step_kb_binding}")

    if result.tool_trace:
        _print_tool_trace(result.tool_trace)

    if hits:
        print(f"\n 检索结果 (共 {len(hits)} 条):")
        for i, hit in enumerate(hits, 1):
            _print_hit(i, hit)
    else:
        print("\n ⚠️ 没有命中任何结果")
        print(" 请检查:")
        print(" 1. 知识库是否有数据")
        print(" 2. 向量是否正确生成")
        print(" 3. 过滤条件是否过于严格")


def _print_tool_trace(trace) -> None:
    """Print the fields of a tool trace, including its arguments when present."""
    print("\n Tool Trace:")
    print(f" tool_name: {trace.tool_name}")
    print(f" status: {trace.status}")
    print(f" duration_ms: {trace.duration_ms}")
    print(f" args_digest: {trace.args_digest}")
    print(f" result_digest: {trace.result_digest}")
    # `arguments` is optional on the trace object; print it only if set.
    arguments = getattr(trace, "arguments", None)
    if arguments:
        print(f" arguments: {arguments}")


def _print_hit(index: int, hit: dict, preview_chars: int = 200) -> None:
    """Print one retrieval hit: score, collection, kb_id, metadata, text preview."""
    # `or` fallbacks also cover keys that are present but set to None, which
    # would otherwise break len()/slicing and the `:.4f` score format.
    text = hit.get("text") or ""
    if len(text) > preview_chars:
        text_preview = text[:preview_chars] + "..."
    else:
        text_preview = text
    score = hit.get("score") or 0
    metadata = hit.get("metadata", {})
    collection = hit.get("collection", "unknown")
    kb_id = hit.get("kb_id", "unknown")

    print(f"\n [{index}] score={score:.4f}")
    print(f" collection: {collection}")
    print(f" kb_id: {kb_id}")
    print(f" metadata: {metadata}")
    print(f" text: {text_preview}")
|
||
|
|
|
||
|
|
|
||
|
|
if __name__ == "__main__":
|
||
|
|
asyncio.run(test_kb_search())
|