"""
Slot Definition Service.

[AC-MRS-07, AC-MRS-08] Slot definition management service.
"""

import logging
import re
import uuid
from datetime import datetime, timezone
from typing import Any

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.entities import (
    SlotDefinition,
    SlotDefinitionCreate,
    SlotDefinitionUpdate,
    MetadataFieldDefinition,
    MetadataFieldType,
    ExtractStrategy,
)

# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


class SlotDefinitionService:
    """
    [AC-MRS-07, AC-MRS-08] Slot definition service.

    Manages standalone slot definitions. Slots are decoupled from metadata
    field definitions but may optionally link to one via ``linked_field_id``.

    Methods operate on the injected ``AsyncSession`` and only ever ``flush()``;
    this class never commits, so transaction control stays with the caller.
    """

    # A slot_key must start with a lowercase letter, followed by lowercase
    # letters, digits or underscores. Always test it with ``fullmatch``.
    SLOT_KEY_PATTERN = re.compile(r"^[a-z][a-z0-9_]*$")

    VALID_TYPES = ["string", "number", "boolean", "enum", "array_enum"]

    VALID_EXTRACT_STRATEGIES = ["rule", "llm", "user_input"]

    def __init__(self, session: AsyncSession) -> None:
        self._session = session

    async def list_slot_definitions(
        self,
        tenant_id: str,
        required: bool | None = None,
    ) -> list[SlotDefinition]:
        """
        List all slot definitions of a tenant, newest first.

        Args:
            tenant_id: Tenant ID.
            required: When not None, only return slots whose ``required``
                flag equals this value.

        Returns:
            List of SlotDefinition ordered by ``created_at`` descending.
        """
        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
        )

        if required is not None:
            stmt = stmt.where(SlotDefinition.required == required)

        stmt = stmt.order_by(SlotDefinition.created_at.desc())

        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    async def get_slot_definition(
        self,
        tenant_id: str,
        slot_id: str,
    ) -> SlotDefinition | None:
        """
        Fetch a single slot definition of a tenant by id.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID (UUID string).

        Returns:
            The SlotDefinition, or None if ``slot_id`` is not a valid UUID or
            no slot with that id exists for the tenant.
        """
        slot_uuid = self._parse_uuid(slot_id)
        if slot_uuid is None:
            return None

        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
            SlotDefinition.id == slot_uuid,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_slot_definition_by_key(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> SlotDefinition | None:
        """
        Fetch a tenant's slot definition by its ``slot_key``.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key name.

        Returns:
            The SlotDefinition, or None if not found.
        """
        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
            SlotDefinition.slot_key == slot_key,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def create_slot_definition(
        self,
        tenant_id: str,
        slot_create: SlotDefinitionCreate,
    ) -> SlotDefinition:
        """
        [AC-MRS-07, AC-MRS-08] Create a slot definition.

        Args:
            tenant_id: Tenant ID.
            slot_create: Creation payload.

        Returns:
            The created (flushed, not committed) SlotDefinition.

        Raises:
            ValueError: If ``slot_key`` is malformed or already exists, the
                type or extract strategy is invalid, or the linked metadata
                field does not exist.
        """
        # fullmatch, not match: with match, the pattern's "$" also accepts a
        # key ending in a newline (e.g. "order_id\n").
        if not self.SLOT_KEY_PATTERN.fullmatch(slot_create.slot_key):
            raise ValueError(
                f"slot_key '{slot_create.slot_key}' 格式不正确,"
                "必须以小写字母开头,仅允许小写字母、数字和下划线"
            )

        # DB-free validation first, so malformed requests never hit the database.
        self._validate_type(slot_create.type)
        # A falsy extract strategy (None / "") is not validated and is stored as-is.
        if slot_create.extract_strategy:
            self._validate_extract_strategy(slot_create.extract_strategy)

        # NOTE(review): check-then-insert; concurrent creates of the same key can
        # race unless the DB enforces a unique (tenant_id, slot_key) constraint.
        existing = await self.get_slot_definition_by_key(tenant_id, slot_create.slot_key)
        if existing:
            raise ValueError(f"slot_key '{slot_create.slot_key}' 已存在")

        linked_field_uuid = None
        if slot_create.linked_field_id:
            linked_field = await self._get_linked_field(slot_create.linked_field_id)
            if not linked_field:
                raise ValueError(
                    f"[AC-MRS-08] 关联的元数据字段 '{slot_create.linked_field_id}' 不存在"
                )
            linked_field_uuid = self._parse_uuid(slot_create.linked_field_id)

        slot = SlotDefinition(
            tenant_id=tenant_id,
            slot_key=slot_create.slot_key,
            type=slot_create.type,
            required=slot_create.required,
            extract_strategy=slot_create.extract_strategy,
            validation_rule=slot_create.validation_rule,
            ask_back_prompt=slot_create.ask_back_prompt,
            default_value=slot_create.default_value,
            linked_field_id=linked_field_uuid,
        )

        self._session.add(slot)
        await self._session.flush()

        logger.info(
            "[AC-MRS-07] Created slot definition: tenant=%s, "
            "slot_key=%s, required=%s, linked_field_id=%s",
            tenant_id,
            slot.slot_key,
            slot.required,
            slot.linked_field_id,
        )

        return slot

    async def update_slot_definition(
        self,
        tenant_id: str,
        slot_id: str,
        slot_update: SlotDefinitionUpdate,
    ) -> SlotDefinition | None:
        """
        Partially update a slot definition.

        Only fields that are not None on ``slot_update`` are applied, so most
        fields cannot be reset to None here. The exception is
        ``linked_field_id``: an empty string unlinks the metadata field.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID (UUID string).
            slot_update: Update payload.

        Returns:
            The updated SlotDefinition, or None if the slot was not found.

        Raises:
            ValueError: If the type or extract strategy is invalid, or the
                linked metadata field does not exist. In that case the slot
                is left unmodified.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return None

        # Validate everything BEFORE mutating the ORM object. The slot is
        # attached to the session, so mutating it and then raising would leave
        # a half-applied update that a later flush (including an autoflush
        # triggered by the linked-field query below) could persist.
        if slot_update.type is not None:
            self._validate_type(slot_update.type)

        if slot_update.extract_strategy is not None:
            self._validate_extract_strategy(slot_update.extract_strategy)

        new_linked_field_uuid = None
        if slot_update.linked_field_id:
            linked_field = await self._get_linked_field(slot_update.linked_field_id)
            if not linked_field:
                raise ValueError(
                    f"[AC-MRS-08] 关联的元数据字段 '{slot_update.linked_field_id}' 不存在"
                )
            new_linked_field_uuid = self._parse_uuid(slot_update.linked_field_id)

        if slot_update.type is not None:
            slot.type = slot_update.type

        if slot_update.required is not None:
            slot.required = slot_update.required

        if slot_update.extract_strategy is not None:
            slot.extract_strategy = slot_update.extract_strategy

        if slot_update.validation_rule is not None:
            slot.validation_rule = slot_update.validation_rule

        if slot_update.ask_back_prompt is not None:
            slot.ask_back_prompt = slot_update.ask_back_prompt

        if slot_update.default_value is not None:
            slot.default_value = slot_update.default_value

        if slot_update.linked_field_id is not None:
            # A falsy (empty) id unlinks: new_linked_field_uuid stays None.
            slot.linked_field_id = new_linked_field_uuid

        # Naive UTC timestamp: the same value datetime.utcnow() returned,
        # without the deprecated call.
        slot.updated_at = datetime.now(timezone.utc).replace(tzinfo=None)
        await self._session.flush()

        logger.info(
            "[AC-MRS-07] Updated slot definition: tenant=%s, slot_id=%s",
            tenant_id,
            slot_id,
        )

        return slot

    async def delete_slot_definition(
        self,
        tenant_id: str,
        slot_id: str,
    ) -> bool:
        """
        [AC-MRS-16] Delete a slot definition.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID (UUID string).

        Returns:
            True if the slot was found and deleted, False if it was not found.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return False

        await self._session.delete(slot)
        await self._session.flush()

        logger.info(
            "[AC-MRS-16] Deleted slot definition: tenant=%s, slot_id=%s",
            tenant_id,
            slot_id,
        )

        return True

    async def _get_linked_field(
        self,
        field_id: str,
    ) -> MetadataFieldDefinition | None:
        """
        Fetch a metadata field definition by id.

        NOTE(review): this lookup filters on id only, not on tenant, so a slot
        may link to (and ``get_slot_definition_with_field`` may expose) another
        tenant's field. If MetadataFieldDefinition has a tenant column, scope
        this query by it. Confirm against the model.

        Args:
            field_id: Field ID (UUID string).

        Returns:
            The MetadataFieldDefinition, or None if ``field_id`` is not a
            valid UUID or no such field exists.
        """
        field_uuid = self._parse_uuid(field_id)
        if field_uuid is None:
            return None

        stmt = select(MetadataFieldDefinition).where(
            MetadataFieldDefinition.id == field_uuid,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_slot_definition_with_field(
        self,
        tenant_id: str,
        slot_id: str,
    ) -> dict[str, Any] | None:
        """
        Fetch a slot definition together with its linked metadata field.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID (UUID string).

        Returns:
            A dict of the slot's columns. Ids are stringified and timestamps
            are ISO-8601. ``"linked_field"`` holds a dict of the linked field,
            or None when there is no link or the field no longer exists.
            Returns None if the slot is not found.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return None

        result = {
            "id": str(slot.id),
            "tenant_id": slot.tenant_id,
            "slot_key": slot.slot_key,
            "type": slot.type,
            "required": slot.required,
            "extract_strategy": slot.extract_strategy,
            "validation_rule": slot.validation_rule,
            "ask_back_prompt": slot.ask_back_prompt,
            "default_value": slot.default_value,
            "linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
            "created_at": slot.created_at.isoformat() if slot.created_at else None,
            "updated_at": slot.updated_at.isoformat() if slot.updated_at else None,
            "linked_field": None,
        }

        if slot.linked_field_id:
            linked_field = await self._get_linked_field(str(slot.linked_field_id))
            if linked_field:
                result["linked_field"] = {
                    "id": str(linked_field.id),
                    "field_key": linked_field.field_key,
                    "label": linked_field.label,
                    "type": linked_field.type,
                    "required": linked_field.required,
                    "options": linked_field.options,
                    "default_value": linked_field.default_value,
                    "scope": linked_field.scope,
                    "is_filterable": linked_field.is_filterable,
                    "is_rank_feature": linked_field.is_rank_feature,
                    "field_roles": linked_field.field_roles,
                    "status": linked_field.status,
                }

        return result

    @staticmethod
    def _parse_uuid(value: str | uuid.UUID) -> uuid.UUID | None:
        """Return ``value`` as a UUID, or None if it is not a valid UUID string/object."""
        if isinstance(value, uuid.UUID):
            return value
        if not isinstance(value, str):
            return None
        try:
            return uuid.UUID(value)
        except ValueError:
            return None

    @classmethod
    def _validate_type(cls, slot_type: str) -> None:
        """Raise ValueError if ``slot_type`` is not one of VALID_TYPES."""
        if slot_type not in cls.VALID_TYPES:
            raise ValueError(
                f"无效的槽位类型 '{slot_type}',"
                f"有效类型为: {cls.VALID_TYPES}"
            )

    @classmethod
    def _validate_extract_strategy(cls, strategy: str) -> None:
        """Raise ValueError if ``strategy`` is not one of VALID_EXTRACT_STRATEGIES."""
        if strategy not in cls.VALID_EXTRACT_STRATEGIES:
            raise ValueError(
                f"无效的提取策略 '{strategy}',"
                f"有效策略为: {cls.VALID_EXTRACT_STRATEGIES}"
            )