ai-robot-core/ai-service/tests/test_step_kb_binding.py

329 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Test cases for the Step-KB Binding feature.

[Step-KB-Binding] Tests for binding flow steps to knowledge bases (KBs).

Coverage:
1. Create/read/update/delete and parameter validation of step configuration.
2. When a step KB scope is configured, retrieval happens only within that scope.
3. When no step KB scope is configured, the original retrieval logic is used
   as a fallback.
4. When several knowledge bases contain same-named content, the step
   constraint takes effect.
5. Completeness of the trace fields.

NOTE(review): no test dedicated to item 4 is visible in this chunk of the
file -- confirm it is covered elsewhere.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from dataclasses import dataclass
from typing import Any
class TestStepKbBindingModel:
    """Tests for the KB-binding fields on the FlowStep data model."""

    def test_flow_step_with_kb_binding_fields(self):
        """FlowStep accepts and stores every KB-binding field."""
        from app.models.entities import FlowStep

        step = FlowStep(
            step_no=1,
            content="测试步骤",
            allowed_kb_ids=["kb-1", "kb-2"],
            preferred_kb_ids=["kb-1"],
            kb_query_hint="查找产品相关信息",
            max_kb_calls_per_step=2,
        )
        assert step.allowed_kb_ids == ["kb-1", "kb-2"]
        assert step.preferred_kb_ids == ["kb-1"]
        assert step.kb_query_hint == "查找产品相关信息"
        assert step.max_kb_calls_per_step == 2

    def test_flow_step_without_kb_binding(self):
        """All KB-binding fields default to None when not configured."""
        from app.models.entities import FlowStep

        step = FlowStep(
            step_no=1,
            content="测试步骤",
        )
        assert step.allowed_kb_ids is None
        assert step.preferred_kb_ids is None
        assert step.kb_query_hint is None
        assert step.max_kb_calls_per_step is None

    def test_max_kb_calls_validation(self):
        """max_kb_calls_per_step is range-checked; the valid range is 1-5."""
        from app.models.entities import FlowStep
        from pydantic import ValidationError

        # In-range values, including both documented boundaries, are accepted.
        for value in (1, 3, 5):
            step = FlowStep(step_no=1, content="test", max_kb_calls_per_step=value)
            assert step.max_kb_calls_per_step == value

        # Out-of-range values must fail pydantic validation specifically.
        # A bare `Exception` would also accept unrelated failures such as a
        # TypeError from a wrong constructor signature.
        for value in (0, 6, 10):
            with pytest.raises(ValidationError):
                FlowStep(step_no=1, content="test", max_kb_calls_per_step=value)
class TestStepKbConfig:
    """Tests for the StepKbConfig dataclass."""

    def test_step_kb_config_creation(self):
        """Each constructor argument is exposed on the attribute of the same name."""
        from app.services.mid.kb_search_dynamic_tool import StepKbConfig

        expected = {
            "allowed_kb_ids": ["kb-1", "kb-2"],
            "preferred_kb_ids": ["kb-1"],
            "kb_query_hint": "查找产品信息",
            "max_kb_calls": 2,
            "step_id": "flow-1_step_1",
        }
        config = StepKbConfig(**expected)

        for attr_name, attr_value in expected.items():
            assert getattr(config, attr_name) == attr_value

    def test_step_kb_config_defaults(self):
        """A bare StepKbConfig allows one KB call and leaves every other field unset."""
        from app.services.mid.kb_search_dynamic_tool import StepKbConfig

        config = StepKbConfig()

        unset_fields = ("allowed_kb_ids", "preferred_kb_ids", "kb_query_hint", "step_id")
        for attr_name in unset_fields:
            assert getattr(config, attr_name) is None
        assert config.max_kb_calls == 1
class TestKbSearchDynamicToolWithStepConfig:
    """Integration tests: KbSearchDynamicTool combined with a per-step KB config."""

    @staticmethod
    def _make_tool():
        """Build an enabled KbSearchDynamicTool with a mocked session and timeout governor."""
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicTool,
            KbSearchDynamicConfig,
        )

        return KbSearchDynamicTool(
            session=MagicMock(),
            timeout_governor=MagicMock(),
            config=KbSearchDynamicConfig(enabled=True),
        )

    @pytest.mark.asyncio
    async def test_kb_search_with_allowed_kb_ids(self):
        """With allowed_kb_ids configured, the step config reaches retrieval and the result reports it."""
        from app.services.mid.kb_search_dynamic_tool import StepKbConfig

        tool = self._make_tool()
        step_config = StepKbConfig(
            allowed_kb_ids=["kb-allowed-1", "kb-allowed-2"],
            step_id="test_step",
        )
        with patch.object(tool, "_do_retrieve", new_callable=AsyncMock) as mock_retrieve:
            mock_retrieve.return_value = [
                {"id": "1", "content": "test", "score": 0.8, "metadata": {"kb_id": "kb-allowed-1"}}
            ]
            result = await tool.execute(
                query="测试查询",
                tenant_id="tenant-1",
                step_kb_config=step_config,
            )

        # Fail clearly if retrieval never ran, instead of a TypeError from
        # subscripting a None call_args.
        mock_retrieve.assert_awaited()
        # The step config, and with it the allowed KB scope, is forwarded to retrieval.
        assert mock_retrieve.call_args.kwargs["step_kb_config"] == step_config
        # The result carries the step KB binding information.
        assert result.step_kb_binding is not None
        assert result.step_kb_binding["allowed_kb_ids"] == ["kb-allowed-1", "kb-allowed-2"]

    @pytest.mark.asyncio
    async def test_kb_search_without_step_config(self):
        """Without a step KB config, retrieval gets None and no binding is reported (fallback path)."""
        tool = self._make_tool()
        with patch.object(tool, "_do_retrieve", new_callable=AsyncMock) as mock_retrieve:
            mock_retrieve.return_value = [
                {"id": "1", "content": "test", "score": 0.8, "metadata": {}}
            ]
            result = await tool.execute(
                query="测试查询",
                tenant_id="tenant-1",
            )

        mock_retrieve.assert_awaited()
        # Retrieval is called with an explicit step_kb_config=None.
        assert mock_retrieve.call_args.kwargs["step_kb_config"] is None
        # No step KB binding is attached to the result.
        assert result.step_kb_binding is None

    @pytest.mark.asyncio
    async def test_kb_search_result_includes_used_kb_ids(self):
        """The binding reports the distinct KB ids actually hit by retrieval."""
        from app.services.mid.kb_search_dynamic_tool import StepKbConfig

        tool = self._make_tool()
        step_config = StepKbConfig(
            allowed_kb_ids=["kb-1", "kb-2"],
            step_id="test_step",
        )
        with patch.object(tool, "_do_retrieve", new_callable=AsyncMock) as mock_retrieve:
            mock_retrieve.return_value = [
                {"id": "1", "content": "test1", "score": 0.9, "metadata": {"kb_id": "kb-1"}},
                {"id": "2", "content": "test2", "score": 0.8, "metadata": {"kb_id": "kb-1"}},
                {"id": "3", "content": "test3", "score": 0.7, "metadata": {"kb_id": "kb-2"}},
            ]
            result = await tool.execute(
                query="测试查询",
                tenant_id="tenant-1",
                step_kb_config=step_config,
            )

        # used_kb_ids covers every knowledge base that produced a hit.
        assert result.step_kb_binding is not None
        assert set(result.step_kb_binding["used_kb_ids"]) == {"kb-1", "kb-2"}
        assert result.step_kb_binding["kb_hit"] is True
class TestTraceInfoStepKbBinding:
    """Tests for the step_kb_binding field on TraceInfo."""

    def test_trace_info_with_step_kb_binding(self):
        """TraceInfo preserves every key of a supplied step_kb_binding."""
        from app.models.mid.schemas import TraceInfo, ExecutionMode

        trace = TraceInfo(
            mode=ExecutionMode.AGENT,
            step_kb_binding={
                "step_id": "flow-1_step_2",
                "allowed_kb_ids": ["kb-1", "kb-2"],
                "used_kb_ids": ["kb-1"],
                "kb_hit": True,
            },
        )
        assert trace.step_kb_binding is not None
        assert trace.step_kb_binding["step_id"] == "flow-1_step_2"
        assert trace.step_kb_binding["allowed_kb_ids"] == ["kb-1", "kb-2"]
        assert trace.step_kb_binding["used_kb_ids"] == ["kb-1"]
        # kb_hit is part of the trace payload too, so check it is not dropped.
        assert trace.step_kb_binding["kb_hit"] is True

    def test_trace_info_without_step_kb_binding(self):
        """step_kb_binding defaults to None when not supplied."""
        from app.models.mid.schemas import TraceInfo, ExecutionMode

        trace = TraceInfo(mode=ExecutionMode.AGENT)
        assert trace.step_kb_binding is None
class TestFlowStepKbBindingIntegration:
    """Integration of script-flow steps with KB binding configuration."""

    def test_script_flow_steps_with_kb_binding(self):
        """ScriptFlowCreate keeps per-step KB binding settings; an unbound step has none."""
        from app.models.entities import ScriptFlowCreate

        bound_step = {
            "step_no": 1,
            "content": "步骤1",
            "allowed_kb_ids": ["kb-1"],
            "preferred_kb_ids": None,
            "kb_query_hint": "查找产品信息",
            "max_kb_calls_per_step": 2,
        }
        # The second step deliberately has no KB binding.
        unbound_step = {
            "step_no": 2,
            "content": "步骤2",
        }

        flow_create = ScriptFlowCreate(name="测试流程", steps=[bound_step, unbound_step])

        first_step = flow_create.steps[0]
        second_step = flow_create.steps[1]
        assert first_step["allowed_kb_ids"] == ["kb-1"]
        assert first_step["kb_query_hint"] == "查找产品信息"
        assert second_step.get("allowed_kb_ids") is None
class TestKbBindingLogging:
    """Tests for logging emitted around step KB bindings."""

    @pytest.mark.asyncio
    async def test_step_kb_config_logging(self, caplog):
        """Executing with a step KB config emits a log message tagged 'Step-KB-Binding'."""
        import logging
        from app.services.mid.kb_search_dynamic_tool import (
            KbSearchDynamicTool,
            KbSearchDynamicConfig,
            StepKbConfig,
        )

        tool = KbSearchDynamicTool(
            session=MagicMock(),
            timeout_governor=MagicMock(),
            config=KbSearchDynamicConfig(enabled=True),
        )
        step_config = StepKbConfig(
            allowed_kb_ids=["kb-1"],
            step_id="flow-1_step_1",
        )
        with patch.object(tool, "_do_retrieve", new_callable=AsyncMock) as mock_retrieve:
            mock_retrieve.return_value = []
            with caplog.at_level(logging.INFO):
                await tool.execute(
                    query="测试",
                    tenant_id="tenant-1",
                    step_kb_config=step_config,
                )

        # caplog.messages is the documented list of interpolated messages.
        # Using it avoids relying on record.message, which only exists after
        # a formatter has run on the record.
        assert any("Step-KB-Binding" in message for message in caplog.messages), (
            f"no 'Step-KB-Binding' log message captured; got: {caplog.messages}"
        )
# Entry point for running this test module directly.
if __name__ == "__main__":
    import sys

    # Propagate pytest's exit status so test failures are visible to the
    # shell/CI; otherwise the script always exits 0.
    sys.exit(pytest.main([__file__, "-v"]))