# ai-robot-core/ai-service/tests/test_slot_definition_servic...
# (334 lines, 12 KiB, Python)
"""
Unit tests for SlotDefinitionService.
[AC-MRS-07,08,16] 验证槽位定义管理功能
"""
import uuid
import pytest
from unittest.mock import AsyncMock, MagicMock
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.slot_definition_service import SlotDefinitionService
from app.models.entities import (
SlotDefinition,
SlotDefinitionCreate,
SlotDefinitionUpdate,
MetadataFieldDefinition,
)
class TestSlotDefinitionService:
    """[AC-MRS-07,08,16] Unit tests for SlotDefinitionService.

    All database access goes through a mocked AsyncSession, so these tests
    exercise the service's control flow (validation, lookups and persistence
    calls) without a real database.
    """

    @pytest.fixture
    def mock_session(self):
        """Return a MagicMock standing in for an AsyncSession.

        ``execute``, ``flush`` and ``delete`` are AsyncMocks because they are
        coroutine methods on SQLAlchemy's AsyncSession; ``add`` is synchronous.
        """
        session = MagicMock(spec=AsyncSession)
        session.execute = AsyncMock()
        session.add = MagicMock()
        session.flush = AsyncMock()
        session.delete = AsyncMock()
        return session

    @pytest.fixture
    def service(self, mock_session):
        """Return a SlotDefinitionService bound to the mocked session."""
        return SlotDefinitionService(mock_session)

    @pytest.mark.asyncio
    async def test_list_slot_definitions(self, service, mock_session):
        """Listing slot definitions returns the rows produced by the query."""
        mock_slot = MagicMock(spec=SlotDefinition)
        mock_slot.id = uuid.uuid4()
        mock_slot.slot_key = "grade"
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [mock_slot]
        mock_session.execute.return_value = mock_result

        slots = await service.list_slot_definitions("test-tenant")

        assert len(slots) == 1
        assert slots[0].slot_key == "grade"

    @pytest.mark.asyncio
    async def test_list_slot_definitions_filter_required(self, service, mock_session):
        """Listing with ``required=True`` returns the rows produced by the query.

        The mocked result is fixed, so this only checks that the ``required``
        keyword is accepted and results pass through. It does not verify that
        the filter is applied to the SQL statement.
        """
        mock_slot = MagicMock(spec=SlotDefinition)
        mock_slot.required = True
        mock_result = MagicMock()
        mock_result.scalars.return_value.all.return_value = [mock_slot]
        mock_session.execute.return_value = mock_result

        slots = await service.list_slot_definitions("test-tenant", required=True)

        assert len(slots) == 1
        assert slots[0].required is True

    @pytest.mark.asyncio
    async def test_get_slot_definition(self, service, mock_session):
        """Fetching a slot definition by id returns the matching row."""
        slot_id = uuid.uuid4()
        mock_slot = MagicMock(spec=SlotDefinition)
        mock_slot.id = slot_id
        mock_slot.slot_key = "grade"
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = mock_slot
        mock_session.execute.return_value = mock_result

        slot = await service.get_slot_definition("test-tenant", str(slot_id))

        assert slot is not None
        assert slot.slot_key == "grade"

    @pytest.mark.asyncio
    async def test_get_slot_definition_not_found(self, service, mock_session):
        """Fetching a nonexistent slot definition returns None."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = mock_result

        slot = await service.get_slot_definition("test-tenant", str(uuid.uuid4()))

        assert slot is None

    @pytest.mark.asyncio
    async def test_get_slot_definition_by_key(self, service, mock_session):
        """Fetching a slot definition by ``slot_key`` returns the matching row."""
        mock_slot = MagicMock(spec=SlotDefinition)
        mock_slot.slot_key = "grade"
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = mock_slot
        mock_session.execute.return_value = mock_result

        slot = await service.get_slot_definition_by_key("test-tenant", "grade")

        assert slot is not None
        assert slot.slot_key == "grade"

    @pytest.mark.asyncio
    async def test_create_slot_definition(self, service, mock_session):
        """[AC-MRS-07] Creating a valid slot definition adds and flushes it."""
        slot_create = SlotDefinitionCreate(
            slot_key="grade",
            type="string",
            required=True,
            extract_strategy="llm",
            ask_back_prompt="请输入年级",
        )
        # execute() yields no row, so the (presumed) duplicate-key lookup finds nothing.
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = mock_result

        slot = await service.create_slot_definition("test-tenant", slot_create)

        assert slot is not None
        mock_session.add.assert_called_once()
        # assert_awaited_* (not assert_called_*) also catches a missing `await`.
        mock_session.flush.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_create_slot_definition_invalid_key(self, service):
        """[AC-MRS-07] A malformed ``slot_key`` raises ValueError."""
        slot_create = SlotDefinitionCreate(
            slot_key="InvalidKey",
            type="string",
            required=True,
        )

        with pytest.raises(ValueError) as exc_info:
            await service.create_slot_definition("test-tenant", slot_create)

        assert "格式不正确" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_create_slot_definition_duplicate_key(self, service, mock_session):
        """[AC-MRS-07] Creating a slot with an existing ``slot_key`` raises ValueError."""
        existing_slot = MagicMock(spec=SlotDefinition)
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = existing_slot
        mock_session.execute.return_value = mock_result
        slot_create = SlotDefinitionCreate(
            slot_key="grade",
            type="string",
            required=True,
        )

        with pytest.raises(ValueError) as exc_info:
            await service.create_slot_definition("test-tenant", slot_create)

        assert "已存在" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_create_slot_definition_invalid_type(self, service, mock_session):
        """[AC-MRS-07] An unknown slot ``type`` raises ValueError."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = mock_result
        slot_create = SlotDefinitionCreate(
            slot_key="grade",
            type="invalid_type",
            required=True,
        )

        with pytest.raises(ValueError) as exc_info:
            await service.create_slot_definition("test-tenant", slot_create)

        assert "无效的槽位类型" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_create_slot_definition_with_linked_field(self, service, mock_session):
        """[AC-MRS-08] A slot linked to an existing metadata field is created."""
        field_id = uuid.uuid4()
        mock_field = MagicMock(spec=MetadataFieldDefinition)
        mock_field.id = field_id
        slot_result = MagicMock()
        slot_result.scalar_one_or_none.return_value = None
        field_result = MagicMock()
        field_result.scalar_one_or_none.return_value = mock_field
        # Results are consumed in order; assumes the service looks up the
        # slot_key first and the linked field second.
        mock_session.execute.side_effect = [slot_result, field_result]
        slot_create = SlotDefinitionCreate(
            slot_key="grade",
            type="string",
            required=True,
            linked_field_id=str(field_id),
        )

        slot = await service.create_slot_definition("test-tenant", slot_create)

        assert slot is not None
        # Both lookups must have run, i.e. the linked field was actually validated.
        assert mock_session.execute.await_count == 2
        mock_session.add.assert_called_once()

    @pytest.mark.asyncio
    async def test_create_slot_definition_linked_field_not_found(self, service, mock_session):
        """[AC-MRS-08] Linking to a nonexistent metadata field raises ValueError."""
        field_id = uuid.uuid4()
        slot_result = MagicMock()
        slot_result.scalar_one_or_none.return_value = None
        field_result = MagicMock()
        field_result.scalar_one_or_none.return_value = None
        mock_session.execute.side_effect = [slot_result, field_result]
        slot_create = SlotDefinitionCreate(
            slot_key="grade",
            type="string",
            required=True,
            linked_field_id=str(field_id),
        )

        with pytest.raises(ValueError) as exc_info:
            await service.create_slot_definition("test-tenant", slot_create)

        assert "关联的元数据字段" in str(exc_info.value)

    @pytest.mark.asyncio
    async def test_update_slot_definition(self, service, mock_session):
        """Updating an existing slot definition applies the changes and flushes."""
        slot_id = uuid.uuid4()
        mock_slot = MagicMock(spec=SlotDefinition)
        mock_slot.id = slot_id
        mock_slot.slot_key = "grade"
        mock_slot.type = "string"
        mock_slot.required = False
        mock_slot.extract_strategy = None
        mock_slot.validation_rule = None
        mock_slot.ask_back_prompt = None
        mock_slot.default_value = None
        mock_slot.linked_field_id = None
        mock_slot.updated_at = None
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = mock_slot
        mock_session.execute.return_value = mock_result
        slot_update = SlotDefinitionUpdate(
            required=True,
            ask_back_prompt="请输入年级",
        )

        slot = await service.update_slot_definition("test-tenant", str(slot_id), slot_update)

        assert slot is not None
        # The update is expected to be applied in place to the fetched entity.
        assert mock_slot.required is True
        assert mock_slot.ask_back_prompt == "请输入年级"
        mock_session.flush.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_update_slot_definition_not_found(self, service, mock_session):
        """Updating a nonexistent slot definition returns None."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = mock_result
        slot_update = SlotDefinitionUpdate(required=True)

        slot = await service.update_slot_definition("test-tenant", str(uuid.uuid4()), slot_update)

        assert slot is None

    @pytest.mark.asyncio
    async def test_delete_slot_definition(self, service, mock_session):
        """[AC-MRS-16] Deleting an existing slot definition returns True."""
        slot_id = uuid.uuid4()
        mock_slot = MagicMock(spec=SlotDefinition)
        mock_slot.id = slot_id
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = mock_slot
        mock_session.execute.return_value = mock_result

        success = await service.delete_slot_definition("test-tenant", str(slot_id))

        assert success is True
        # The fetched entity itself must be passed to, and awaited on, session.delete.
        mock_session.delete.assert_awaited_once_with(mock_slot)
        mock_session.flush.assert_awaited_once()

    @pytest.mark.asyncio
    async def test_delete_slot_definition_not_found(self, service, mock_session):
        """[AC-MRS-16] Deleting a nonexistent slot definition returns False."""
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = None
        mock_session.execute.return_value = mock_result

        success = await service.delete_slot_definition("test-tenant", str(uuid.uuid4()))

        assert success is False
        mock_session.delete.assert_not_awaited()

    @pytest.mark.asyncio
    async def test_get_slot_definition_with_field(self, service, mock_session):
        """Fetching a slot with field info returns a dict with ``linked_field`` None when unlinked."""
        slot_id = uuid.uuid4()
        mock_slot = MagicMock(spec=SlotDefinition)
        mock_slot.id = slot_id
        mock_slot.tenant_id = "test-tenant"
        mock_slot.slot_key = "grade"
        mock_slot.type = "string"
        mock_slot.required = True
        mock_slot.extract_strategy = "llm"
        mock_slot.validation_rule = None
        mock_slot.ask_back_prompt = "请输入年级"
        mock_slot.default_value = None
        mock_slot.linked_field_id = None
        mock_slot.created_at = None
        mock_slot.updated_at = None
        mock_result = MagicMock()
        mock_result.scalar_one_or_none.return_value = mock_slot
        mock_session.execute.return_value = mock_result

        result = await service.get_slot_definition_with_field("test-tenant", str(slot_id))

        assert result is not None
        assert result["slot_key"] == "grade"
        assert result["linked_field"] is None