ai-robot-core/ai-service/app/services/slot_definition_service.py

365 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Slot Definition Service.
[AC-MRS-07, AC-MRS-08] 槽位定义管理服务
"""
import logging
import re
import uuid
from datetime import datetime, timezone
from typing import Any

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.entities import (
    SlotDefinition,
    SlotDefinitionCreate,
    SlotDefinitionUpdate,
    MetadataFieldDefinition,
    MetadataFieldType,
    ExtractStrategy,
)
logger = logging.getLogger(__name__)


class SlotDefinitionService:
    """
    [AC-MRS-07, AC-MRS-08] Slot definition service.

    Manages standalone slot definitions per tenant. Slots are decoupled from
    metadata field definitions but may optionally reference one through
    ``linked_field_id`` so its configuration can be reused.

    Write operations only ``flush()`` the session; this service never
    commits.
    """

    # slot_key must start with a lowercase letter, followed by [a-z0-9_].
    SLOT_KEY_PATTERN = re.compile(r"^[a-z][a-z0-9_]*$")
    # NOTE(review): MetadataFieldType / ExtractStrategy are imported by this
    # module but unused; presumably these lists mirror those enums — verify.
    VALID_TYPES = ["string", "number", "boolean", "enum", "array_enum"]
    VALID_EXTRACT_STRATEGIES = ["rule", "llm", "user_input"]

    def __init__(self, session: AsyncSession):
        self._session = session

    @staticmethod
    def _parse_uuid(value: str | uuid.UUID | None) -> uuid.UUID | None:
        """
        Coerce an identifier to ``uuid.UUID``.

        ``uuid.UUID`` instances are returned unchanged; empty values and
        strings that are not valid UUIDs yield None instead of raising.
        """
        if isinstance(value, uuid.UUID):
            return value
        if not value:
            return None
        try:
            return uuid.UUID(str(value))
        except ValueError:
            return None

    def _validate_type(self, slot_type: str) -> None:
        """Raise ValueError if ``slot_type`` is not in VALID_TYPES."""
        if slot_type not in self.VALID_TYPES:
            raise ValueError(
                f"无效的槽位类型 '{slot_type}',"
                f"有效类型为: {self.VALID_TYPES}"
            )

    def _validate_extract_strategy(self, strategy: str) -> None:
        """Raise ValueError if ``strategy`` is not in VALID_EXTRACT_STRATEGIES."""
        if strategy not in self.VALID_EXTRACT_STRATEGIES:
            raise ValueError(
                f"无效的提取策略 '{strategy}',"
                f"有效策略为: {self.VALID_EXTRACT_STRATEGIES}"
            )

    async def _resolve_linked_field_id(self, field_id: str) -> uuid.UUID:
        """
        [AC-MRS-08] Verify a metadata field exists and return its UUID.

        Raises:
            ValueError: if ``field_id`` is not a valid UUID or no metadata
                field with that id exists.
        """
        field_uuid = self._parse_uuid(field_id)
        if field_uuid is None or await self._get_linked_field(field_uuid) is None:
            raise ValueError(
                f"[AC-MRS-08] 关联的元数据字段 '{field_id}' 不存在"
            )
        return field_uuid

    async def list_slot_definitions(
        self,
        tenant_id: str,
        required: bool | None = None,
    ) -> list[SlotDefinition]:
        """
        List all slot definitions of a tenant, newest first.

        Args:
            tenant_id: Tenant ID.
            required: If not None, only return slots whose ``required``
                flag equals this value.

        Returns:
            List of SlotDefinition ordered by ``created_at`` descending.
        """
        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
        )
        if required is not None:
            stmt = stmt.where(SlotDefinition.required == required)
        stmt = stmt.order_by(SlotDefinition.created_at.desc())
        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    async def get_slot_definition(
        self,
        tenant_id: str,
        slot_id: str | uuid.UUID,
    ) -> SlotDefinition | None:
        """
        Fetch a single slot definition by id within a tenant.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID (UUID string or ``uuid.UUID``).

        Returns:
            The SlotDefinition, or None if ``slot_id`` is not a valid UUID
            or no slot with that id belongs to the tenant.
        """
        slot_uuid = self._parse_uuid(slot_id)
        if slot_uuid is None:
            return None
        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
            SlotDefinition.id == slot_uuid,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_slot_definition_by_key(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> SlotDefinition | None:
        """
        Fetch a slot definition by its ``slot_key`` within a tenant.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key name.

        Returns:
            The SlotDefinition, or None if not found.
        """
        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
            SlotDefinition.slot_key == slot_key,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def create_slot_definition(
        self,
        tenant_id: str,
        slot_create: SlotDefinitionCreate,
    ) -> SlotDefinition:
        """
        [AC-MRS-07, AC-MRS-08] Create a slot definition.

        Args:
            tenant_id: Tenant ID.
            slot_create: Creation payload.

        Returns:
            The newly created (flushed, not committed) SlotDefinition.

        Raises:
            ValueError: if slot_key is malformed or already exists for the
                tenant, type / extract_strategy is invalid, or the linked
                metadata field does not exist.
        """
        # In-memory checks first so invalid input never costs a DB round trip.
        if not self.SLOT_KEY_PATTERN.match(slot_create.slot_key):
            raise ValueError(
                f"slot_key '{slot_create.slot_key}' 格式不正确,"
                "必须以小写字母开头,仅允许小写字母、数字和下划线"
            )
        self._validate_type(slot_create.type)
        # A falsy extract_strategy means "not specified" and is not validated.
        if slot_create.extract_strategy:
            self._validate_extract_strategy(slot_create.extract_strategy)

        existing = await self.get_slot_definition_by_key(tenant_id, slot_create.slot_key)
        if existing is not None:
            raise ValueError(f"slot_key '{slot_create.slot_key}' 已存在")

        linked_field_uuid = None
        if slot_create.linked_field_id:
            linked_field_uuid = await self._resolve_linked_field_id(
                slot_create.linked_field_id
            )

        slot = SlotDefinition(
            tenant_id=tenant_id,
            slot_key=slot_create.slot_key,
            type=slot_create.type,
            required=slot_create.required,
            extract_strategy=slot_create.extract_strategy,
            validation_rule=slot_create.validation_rule,
            ask_back_prompt=slot_create.ask_back_prompt,
            default_value=slot_create.default_value,
            linked_field_id=linked_field_uuid,
        )
        self._session.add(slot)
        await self._session.flush()
        logger.info(
            "[AC-MRS-07] Created slot definition: tenant=%s, "
            "slot_key=%s, required=%s, linked_field_id=%s",
            tenant_id,
            slot.slot_key,
            slot.required,
            slot.linked_field_id,
        )
        return slot

    async def update_slot_definition(
        self,
        tenant_id: str,
        slot_id: str | uuid.UUID,
        slot_update: SlotDefinitionUpdate,
    ) -> SlotDefinition | None:
        """
        Partially update a slot definition.

        Attributes of ``slot_update`` that are None are left unchanged; for
        ``linked_field_id`` an empty string removes the link.
        NOTE(review): because None means "unchanged", validation_rule,
        ask_back_prompt and default_value cannot be reset to None here.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID.
            slot_update: Update payload.

        Returns:
            The updated SlotDefinition, or None if the slot was not found.

        Raises:
            ValueError: if type / extract_strategy is invalid or the linked
                metadata field does not exist. No attribute is modified in
                that case.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return None

        # Validate everything before touching the session-attached object:
        # mutating it first would leave a partial update pending if a later
        # check fails (and the linked-field query may autoflush it).
        if slot_update.type is not None:
            self._validate_type(slot_update.type)
        if slot_update.extract_strategy is not None:
            self._validate_extract_strategy(slot_update.extract_strategy)
        new_linked_field_uuid = None
        if slot_update.linked_field_id:
            new_linked_field_uuid = await self._resolve_linked_field_id(
                slot_update.linked_field_id
            )

        if slot_update.type is not None:
            slot.type = slot_update.type
        if slot_update.required is not None:
            slot.required = slot_update.required
        if slot_update.extract_strategy is not None:
            slot.extract_strategy = slot_update.extract_strategy
        if slot_update.validation_rule is not None:
            slot.validation_rule = slot_update.validation_rule
        if slot_update.ask_back_prompt is not None:
            slot.ask_back_prompt = slot_update.ask_back_prompt
        if slot_update.default_value is not None:
            slot.default_value = slot_update.default_value
        if slot_update.linked_field_id is not None:
            # Empty string -> new_linked_field_uuid is None -> unlink.
            slot.linked_field_id = new_linked_field_uuid

        # Naive UTC timestamp, same value datetime.utcnow() produced.
        slot.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        await self._session.flush()
        logger.info(
            "[AC-MRS-07] Updated slot definition: tenant=%s, slot_id=%s",
            tenant_id,
            slot_id,
        )
        return slot

    async def delete_slot_definition(
        self,
        tenant_id: str,
        slot_id: str | uuid.UUID,
    ) -> bool:
        """
        [AC-MRS-16] Delete a slot definition.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID.

        Returns:
            True if the slot existed and was deleted, False otherwise.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return False
        await self._session.delete(slot)
        await self._session.flush()
        logger.info(
            "[AC-MRS-16] Deleted slot definition: tenant=%s, slot_id=%s",
            tenant_id,
            slot_id,
        )
        return True

    async def _get_linked_field(
        self,
        field_id: str | uuid.UUID,
    ) -> MetadataFieldDefinition | None:
        """
        Look up a metadata field definition by id.

        Args:
            field_id: Field ID (UUID string or ``uuid.UUID``).

        Returns:
            The MetadataFieldDefinition, or None if ``field_id`` is not a
            valid UUID or no such field exists.
        """
        # NOTE(review): this lookup is not tenant-scoped, so a slot can be
        # linked to (and get_slot_definition_with_field can expose) another
        # tenant's field. If MetadataFieldDefinition has a tenant_id column,
        # consider filtering on it here — confirm against the entity model.
        field_uuid = self._parse_uuid(field_id)
        if field_uuid is None:
            return None
        stmt = select(MetadataFieldDefinition).where(
            MetadataFieldDefinition.id == field_uuid,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_slot_definition_with_field(
        self,
        tenant_id: str,
        slot_id: str | uuid.UUID,
    ) -> dict[str, Any] | None:
        """
        Return a slot definition and its linked metadata field as a dict.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID.

        Returns:
            A JSON-friendly dict (ids as str, timestamps as ISO strings), with
            ``linked_field`` set to the field's dict or None when there is no
            link or the linked field no longer exists. None if the slot is
            not found.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return None

        result: dict[str, Any] = {
            "id": str(slot.id),
            "tenant_id": slot.tenant_id,
            "slot_key": slot.slot_key,
            "type": slot.type,
            "required": slot.required,
            "extract_strategy": slot.extract_strategy,
            "validation_rule": slot.validation_rule,
            "ask_back_prompt": slot.ask_back_prompt,
            "default_value": slot.default_value,
            "linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
            "created_at": slot.created_at.isoformat() if slot.created_at else None,
            "updated_at": slot.updated_at.isoformat() if slot.updated_at else None,
            "linked_field": None,
        }

        if slot.linked_field_id:
            linked_field = await self._get_linked_field(slot.linked_field_id)
            if linked_field:
                result["linked_field"] = {
                    "id": str(linked_field.id),
                    "field_key": linked_field.field_key,
                    "label": linked_field.label,
                    "type": linked_field.type,
                    "required": linked_field.required,
                    "options": linked_field.options,
                    "default_value": linked_field.default_value,
                    "scope": linked_field.scope,
                    "is_filterable": linked_field.is_filterable,
                    "is_rank_feature": linked_field.is_rank_feature,
                    "field_roles": linked_field.field_roles,
                    "status": linked_field.status,
                }

        return result