"""
Tests for Slot Backfill Service.

[AC-MRS-SLOT-BACKFILL-01] Slot value backfill and confirmation tests.
"""
|
||
|
|
|
||
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||
|
|
|
||
|
|
import pytest
|
||
|
|
|
||
|
|
from app.models.mid.schemas import SlotSource
|
||
|
|
from app.services.mid.slot_backfill_service import (
|
||
|
|
BackfillResult,
|
||
|
|
BackfillStatus,
|
||
|
|
BatchBackfillResult,
|
||
|
|
SlotBackfillService,
|
||
|
|
create_slot_backfill_service,
|
||
|
|
)
|
||
|
|
from app.services.mid.slot_manager import SlotWriteResult
|
||
|
|
from app.services.mid.slot_strategy_executor import (
|
||
|
|
StrategyChainResult,
|
||
|
|
StrategyStepResult,
|
||
|
|
)
|
||
|
|
|
||
|
|
|
||
|
|
class TestBackfillResult:
    """Unit tests for the ``BackfillResult`` status helpers and ``to_dict``."""

    def test_is_success(self):
        """Check ``is_success`` for a SUCCESS and a VALIDATION_FAILED result."""
        cases = (
            (BackfillStatus.SUCCESS, True),
            (BackfillStatus.VALIDATION_FAILED, False),
        )
        for status, expected in cases:
            assert BackfillResult(status=status).is_success() is expected

    def test_needs_ask_back(self):
        """Check ``needs_ask_back`` for both failure statuses and for SUCCESS."""
        cases = (
            (BackfillStatus.VALIDATION_FAILED, True),
            (BackfillStatus.EXTRACTION_FAILED, True),
            (BackfillStatus.SUCCESS, False),
        )
        for status, expected in cases:
            assert BackfillResult(status=status).needs_ask_back() is expected

    def test_needs_confirmation(self):
        """Check ``needs_confirmation`` for NEEDS_CONFIRMATION and for SUCCESS."""
        cases = (
            (BackfillStatus.NEEDS_CONFIRMATION, True),
            (BackfillStatus.SUCCESS, False),
        )
        for status, expected in cases:
            assert BackfillResult(status=status).needs_confirmation() is expected

    def test_to_dict(self):
        """Check selected keys of the dictionary produced by ``to_dict``."""
        payload = BackfillResult(
            status=BackfillStatus.SUCCESS,
            slot_key="region",
            value="北京",
            normalized_value="北京",
            source="user_confirmed",
            confidence=1.0,
        ).to_dict()

        expected_items = (
            ("status", "success"),
            ("slot_key", "region"),
            ("value", "北京"),
            ("source", "user_confirmed"),
        )
        for key, expected in expected_items:
            assert payload[key] == expected
|
||
|
|
|
||
|
|
|
||
|
|
class TestBatchBackfillResult:
    """Unit tests for the counters and prompt accessors of ``BatchBackfillResult``."""

    def test_add_result(self):
        """Adding one result per status bumps each of the three counters once."""
        batch = BatchBackfillResult()
        entries = (
            (BackfillStatus.SUCCESS, "region"),
            (BackfillStatus.VALIDATION_FAILED, "product"),
            (BackfillStatus.NEEDS_CONFIRMATION, "grade"),
        )
        for status, slot_key in entries:
            batch.add_result(BackfillResult(status=status, slot_key=slot_key))

        assert batch.success_count == 1
        assert batch.failed_count == 1
        assert batch.confirmation_needed_count == 1

    def test_get_ask_back_prompts(self):
        """Only the two failed results contribute ask-back prompts."""
        batch = BatchBackfillResult()
        items = (
            BackfillResult(
                status=BackfillStatus.VALIDATION_FAILED,
                ask_back_prompt="请重新输入",
            ),
            BackfillResult(
                status=BackfillStatus.SUCCESS,
            ),
            BackfillResult(
                status=BackfillStatus.EXTRACTION_FAILED,
                ask_back_prompt="无法识别,请重试",
            ),
        )
        for item in items:
            batch.add_result(item)

        ask_backs = batch.get_ask_back_prompts()
        assert len(ask_backs) == 2
        for expected in ("请重新输入", "无法识别,请重试"):
            assert expected in ask_backs

    def test_get_confirmation_prompts(self):
        """Only the NEEDS_CONFIRMATION result contributes a confirmation prompt."""
        batch = BatchBackfillResult()
        items = (
            BackfillResult(
                status=BackfillStatus.NEEDS_CONFIRMATION,
                confirmation_prompt="我理解您说的是「北京」,对吗?",
            ),
            BackfillResult(
                status=BackfillStatus.SUCCESS,
            ),
        )
        for item in items:
            batch.add_result(item)

        confirmations = batch.get_confirmation_prompts()
        assert len(confirmations) == 1
        assert "北京" in confirmations[0]
|
||
|
|
|
||
|
|
|
||
|
|
class TestSlotBackfillService:
    """Tests for ``SlotBackfillService`` built on a mocked session and slot manager."""

    @pytest.fixture
    def mock_session(self):
        """Provide an ``AsyncMock`` standing in for the session."""
        return AsyncMock()

    @pytest.fixture
    def mock_slot_manager(self):
        """Provide a slot manager mock with async ``write_slot`` / ``get_ask_back_prompt``."""
        return MagicMock(
            write_slot=AsyncMock(),
            get_ask_back_prompt=AsyncMock(return_value="请提供信息"),
        )

    @pytest.fixture
    def service(self, mock_session, mock_slot_manager):
        """Build the service under test from the mocked collaborators."""
        backfill_service = SlotBackfillService(
            session=mock_session,
            tenant_id="tenant_1",
            session_id="session_1",
            slot_manager=mock_slot_manager,
        )
        return backfill_service

    def test_confidence_thresholds(self, service):
        """The low/high confidence thresholds are 0.5 and 0.8."""
        low = service.CONFIDENCE_THRESHOLD_LOW
        high = service.CONFIDENCE_THRESHOLD_HIGH
        assert low == 0.5
        assert high == 0.8

    def test_get_source_for_strategy(self, service):
        """Strategy names map to slot sources; an unknown name maps to itself."""
        cases = (
            ("rule", SlotSource.RULE_EXTRACTED.value),
            ("llm", SlotSource.LLM_INFERRED.value),
            ("user_input", SlotSource.USER_CONFIRMED.value),
            ("unknown", "unknown"),
        )
        for strategy, expected_source in cases:
            assert service._get_source_for_strategy(strategy) == expected_source

    def test_get_confidence_for_strategy(self, service):
        """Each slot source maps to its expected confidence value."""
        cases = (
            (SlotSource.USER_CONFIRMED.value, 1.0),
            (SlotSource.RULE_EXTRACTED.value, 0.9),
            (SlotSource.LLM_INFERRED.value, 0.7),
            ("context", 0.5),
            (SlotSource.DEFAULT.value, 0.3),
        )
        for source, expected_confidence in cases:
            assert service._get_confidence_for_strategy(source) == expected_confidence

    def test_generate_confirmation_prompt(self, service):
        """The confirmation prompt contains the value and the question suffix."""
        text = service._generate_confirmation_prompt("region", "北京")
        for fragment in ("北京", "对吗"):
            assert fragment in text

    @pytest.mark.asyncio
    async def test_backfill_single_slot_success(self, service, mock_slot_manager):
        """A successful slot write at confidence 1.0 produces a SUCCESS result."""
        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=True,
            slot_key="region",
            value="北京",
        )
        aggregator = AsyncMock(update_slot=AsyncMock())

        with patch.object(service, "_get_state_aggregator", return_value=aggregator):
            outcome = await service.backfill_single_slot(
                slot_key="region",
                candidate_value="北京",
                source="user_confirmed",
                confidence=1.0,
            )

        assert outcome.status == BackfillStatus.SUCCESS
        assert outcome.slot_key == "region"
        assert outcome.normalized_value == "北京"

    @pytest.mark.asyncio
    async def test_backfill_single_slot_validation_failed(self, service, mock_slot_manager):
        """A failed slot write produces VALIDATION_FAILED with the write's ask-back prompt."""
        from app.services.mid.slot_validation_service import SlotValidationError

        validation_error = SlotValidationError(
            slot_key="region",
            error_code="INVALID_VALUE",
            error_message="无效的地区",
        )
        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=False,
            slot_key="region",
            error=validation_error,
            ask_back_prompt="请提供有效的地区",
        )

        outcome = await service.backfill_single_slot(
            slot_key="region",
            candidate_value="无效地区",
            source="user_confirmed",
            confidence=1.0,
        )

        assert outcome.status == BackfillStatus.VALIDATION_FAILED
        assert outcome.ask_back_prompt == "请提供有效的地区"

    @pytest.mark.asyncio
    async def test_backfill_single_slot_low_confidence(self, service, mock_slot_manager):
        """A successful write at confidence 0.4 asks the user for confirmation."""
        mock_slot_manager.write_slot.return_value = SlotWriteResult(
            success=True,
            slot_key="region",
            value="北京",
        )
        aggregator = AsyncMock(update_slot=AsyncMock())

        with patch.object(service, "_get_state_aggregator", return_value=aggregator):
            outcome = await service.backfill_single_slot(
                slot_key="region",
                candidate_value="北京",
                source="llm_inferred",
                confidence=0.4,
            )

        assert outcome.status == BackfillStatus.NEEDS_CONFIRMATION
        assert outcome.confirmation_prompt is not None
        assert "北京" in outcome.confirmation_prompt

    @pytest.mark.asyncio
    async def test_backfill_multiple_slots(self, service, mock_slot_manager):
        """Two successful writes and one failed write are counted separately."""
        # One write result per candidate, consumed in the order of the dict below.
        mock_slot_manager.write_slot.side_effect = [
            SlotWriteResult(success=True, slot_key="region", value="北京"),
            SlotWriteResult(success=True, slot_key="product", value="手机"),
            SlotWriteResult(success=False, slot_key="grade", error=MagicMock()),
        ]
        aggregator = AsyncMock(update_slot=AsyncMock())
        slot_candidates = {
            "region": "北京",
            "product": "手机",
            "grade": "无效等级",
        }

        with patch.object(service, "_get_state_aggregator", return_value=aggregator):
            batch = await service.backfill_multiple_slots(
                candidates=slot_candidates,
                source="user_confirmed",
            )

        assert batch.success_count == 2
        assert batch.failed_count == 1

    @pytest.mark.asyncio
    async def test_confirm_low_confidence_slot_confirmed(self, service):
        """A user confirmation yields SUCCESS with a user-confirmed source at 1.0."""
        aggregator = AsyncMock(update_slot=AsyncMock())

        with patch.object(service, "_get_state_aggregator", return_value=aggregator):
            outcome = await service.confirm_low_confidence_slot(
                slot_key="region",
                confirmed=True,
            )

        assert outcome.status == BackfillStatus.SUCCESS
        assert outcome.source == SlotSource.USER_CONFIRMED.value
        assert outcome.confidence == 1.0

    @pytest.mark.asyncio
    async def test_confirm_low_confidence_slot_rejected(self, service, mock_slot_manager):
        """A user rejection yields VALIDATION_FAILED together with an ask-back prompt."""
        aggregator = AsyncMock(clear_slot=AsyncMock())

        with patch.object(service, "_get_state_aggregator", return_value=aggregator):
            outcome = await service.confirm_low_confidence_slot(
                slot_key="region",
                confirmed=False,
            )

        assert outcome.status == BackfillStatus.VALIDATION_FAILED
        assert outcome.ask_back_prompt is not None
|
||
|
|
|
||
|
|
|
||
|
|
class TestCreateSlotBackfillService:
    """Tests for the ``create_slot_backfill_service`` factory function."""

    def test_create(self):
        """The factory returns a ``SlotBackfillService`` carrying the given ids."""
        factory_kwargs = {
            "session": AsyncMock(),
            "tenant_id": "tenant_1",
            "session_id": "session_1",
        }

        created = create_slot_backfill_service(**factory_kwargs)

        assert isinstance(created, SlotBackfillService)
        assert created._tenant_id == "tenant_1"
        assert created._session_id == "session_1"
|
||
|
|
|
||
|
|
|
||
|
|
class TestBackfillFromUserResponse:
    """Tests for ``SlotBackfillService.backfill_from_user_response``."""

    @staticmethod
    def _make_slot_definition():
        """Return a mocked string slot definition with no validation rule."""
        return MagicMock(
            type="string",
            validation_rule=None,
            ask_back_prompt="请提供地区",
        )

    @pytest.fixture
    def service(self):
        """Build a service whose slot-definition service is replaced by an ``AsyncMock``."""
        backfill_service = SlotBackfillService(
            session=AsyncMock(),
            tenant_id="tenant_1",
            session_id="session_1",
        )
        backfill_service._slot_def_service = AsyncMock()
        return backfill_service

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_success(self, service):
        """A successful extraction followed by a successful backfill is counted once."""
        service._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=self._make_slot_definition()
        )
        extraction = StrategyChainResult(
            slot_key="region",
            success=True,
            final_value="北京",
            final_strategy="rule",
        )
        backfilled = BackfillResult(
            status=BackfillStatus.SUCCESS,
            slot_key="region",
            value="北京",
        )

        # Both collaborators are stubbed so only the orchestration is exercised.
        with patch.object(service, "_extract_value", return_value=extraction):
            with patch.object(service, "backfill_single_slot", return_value=backfilled):
                batch = await service.backfill_from_user_response(
                    user_response="我想查询北京的产品",
                    expected_slots=["region"],
                )

        assert batch.success_count == 1

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_no_definition(self, service):
        """With no slot definition found, neither counter is incremented."""
        service._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=None
        )

        batch = await service.backfill_from_user_response(
            user_response="我想查询北京的产品",
            expected_slots=["unknown_slot"],
        )

        assert batch.success_count == 0
        assert batch.failed_count == 0

    @pytest.mark.asyncio
    async def test_backfill_from_user_response_extraction_failed(self, service):
        """A failed extraction is recorded as a single EXTRACTION_FAILED result."""
        service._slot_def_service.get_slot_definition_by_key = AsyncMock(
            return_value=self._make_slot_definition()
        )
        extraction = StrategyChainResult(
            slot_key="region",
            success=False,
        )

        with patch.object(service, "_extract_value", return_value=extraction):
            batch = await service.backfill_from_user_response(
                user_response="我想查询产品",
                expected_slots=["region"],
            )

        assert batch.failed_count == 1
        assert batch.results[0].status == BackfillStatus.EXTRACTION_FAILED
|