ai-robot-core/ai-service/app/services/slot_definition_service.py

462 lines
15 KiB
Python
Raw Normal View History

"""
Slot Definition Service.
[AC-MRS-07, AC-MRS-08] 槽位定义管理服务
[AC-MRS-07-UPGRADE] 支持提取策略链 extract_strategies
"""
import logging
import re
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import (
SlotDefinition,
SlotDefinitionCreate,
SlotDefinitionUpdate,
MetadataFieldDefinition,
MetadataFieldType,
ExtractStrategy,
)
# Module-level logger named after this module (`__name__`).
logger = logging.getLogger(__name__)
class SlotDefinitionService:
    """
    [AC-MRS-07, AC-MRS-08] Slot definition service.
    [AC-MRS-07-UPGRADE] Supports managing extraction strategy chains.

    Manages standalone slot definitions. A slot may optionally reference a
    metadata field definition via `linked_field_id`, but the two models are
    stored and managed independently.
    """

    # slot_key must start with a lowercase letter, followed only by lowercase
    # letters, digits and underscores.
    SLOT_KEY_PATTERN = re.compile(r"^[a-z][a-z0-9_]*$")
    VALID_TYPES = ["string", "number", "boolean", "enum", "array_enum"]
    VALID_EXTRACT_STRATEGIES = ["rule", "llm", "user_input"]

    def __init__(self, session: AsyncSession):
        self._session = session

    async def list_slot_definitions(
        self,
        tenant_id: str,
        required: bool | None = None,
    ) -> list[SlotDefinition]:
        """
        List all slot definitions of a tenant, newest first.

        Args:
            tenant_id: Tenant ID.
            required: When not None, only return slots whose `required`
                flag equals this value.

        Returns:
            List of SlotDefinition ordered by `created_at` descending.
        """
        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
        )
        if required is not None:
            stmt = stmt.where(SlotDefinition.required == required)
        stmt = stmt.order_by(SlotDefinition.created_at.desc())
        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    async def get_slot_definition(
        self,
        tenant_id: str,
        slot_id: str,
    ) -> SlotDefinition | None:
        """
        Fetch a single slot definition of a tenant by its ID.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID as a UUID string.

        Returns:
            The SlotDefinition, or None if `slot_id` is not a valid UUID or
            no matching row exists for this tenant.
        """
        try:
            slot_uuid = uuid.UUID(slot_id)
        except ValueError:
            return None
        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
            SlotDefinition.id == slot_uuid,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_slot_definition_by_key(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> SlotDefinition | None:
        """
        Fetch a slot definition of a tenant by its `slot_key`.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key name.

        Returns:
            The SlotDefinition, or None if not found.
        """
        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
            SlotDefinition.slot_key == slot_key,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    def _validate_strategies(self, strategies: list[str] | None) -> tuple[bool, str]:
        """
        [AC-MRS-07-UPGRADE] Check whether an extraction strategy chain is valid.

        None is accepted (meaning "no chain"). A chain must be a non-empty
        list without duplicates whose entries are all in
        VALID_EXTRACT_STRATEGIES.

        Args:
            strategies: Strategy chain to check.

        Returns:
            Tuple of (is_valid, error_message); error_message is "" when valid.
        """
        if strategies is None:
            return True, ""
        if not isinstance(strategies, list):
            return False, "extract_strategies 必须是数组类型"
        if not strategies:
            return False, "提取策略链不能为空数组"
        # Duplicate strategies are not allowed.
        if len(strategies) != len(set(strategies)):
            return False, "提取策略链中不允许重复的策略"
        # Every entry must be a known strategy.
        invalid = [s for s in strategies if s not in self.VALID_EXTRACT_STRATEGIES]
        if invalid:
            return False, f"无效的提取策略: {invalid},有效值为: {self.VALID_EXTRACT_STRATEGIES}"
        return True, ""

    def _normalize_strategies(
        self,
        extract_strategies: list[str] | None,
        extract_strategy: str | None,
    ) -> list[str] | None:
        """
        [AC-MRS-07-UPGRADE] Normalize the submitted extraction strategy input.

        `extract_strategies` takes precedence when it is not None. Otherwise
        a truthy legacy `extract_strategy` is wrapped into a one-element chain.

        Args:
            extract_strategies: Strategy chain (new field).
            extract_strategy: Single strategy (legacy field, kept for compatibility).

        Returns:
            The normalized strategy chain, or None if neither was provided.
        """
        if extract_strategies is not None:
            return extract_strategies
        if extract_strategy:
            return [extract_strategy]
        return None

    def _ensure_valid_strategies(self, strategies: list[str] | None) -> None:
        """
        Raise if `strategies` is not a valid strategy chain (None passes).

        Raises:
            ValueError: With the message produced by `_validate_strategies`.
        """
        is_valid, error_msg = self._validate_strategies(strategies)
        if not is_valid:
            raise ValueError(f"提取策略链校验失败: {error_msg}")

    def _resolve_legacy_strategy(
        self,
        extract_strategy: str | None,
        strategies: list[str] | None,
    ) -> str | None:
        """
        Choose the value stored in the legacy `extract_strategy` column.

        A submitted non-empty `extract_strategy` wins. Otherwise the first
        entry of the strategy chain is used. Otherwise None.

        Raises:
            ValueError: If the submitted `extract_strategy` is not a valid strategy.
        """
        if extract_strategy:
            # The legacy value is persisted as-is even when a chain is also
            # supplied, so it must be validated on its own; otherwise an
            # arbitrary string could be stored.
            self._ensure_valid_strategies([extract_strategy])
            return extract_strategy
        if strategies:
            return strategies[0]
        return None

    async def create_slot_definition(
        self,
        tenant_id: str,
        slot_create: SlotDefinitionCreate,
    ) -> SlotDefinition:
        """
        [AC-MRS-07, AC-MRS-08] Create a slot definition.
        [AC-MRS-07-UPGRADE] Supports extraction strategy chains.

        Both the new `extract_strategies` field and the legacy
        `extract_strategy` field are persisted, so that readers of either
        field see a consistent value.

        Args:
            tenant_id: Tenant ID.
            slot_create: Creation payload.

        Returns:
            The created SlotDefinition. The session is flushed, not committed.

        Raises:
            ValueError: If `slot_key` is malformed or already exists, the type
                or strategies are invalid, or the linked field does not exist.
        """
        if not self.SLOT_KEY_PATTERN.match(slot_create.slot_key):
            raise ValueError(
                f"slot_key '{slot_create.slot_key}' 格式不正确,"
                "必须以小写字母开头,仅允许小写字母、数字和下划线"
            )

        existing = await self.get_slot_definition_by_key(tenant_id, slot_create.slot_key)
        if existing:
            raise ValueError(f"slot_key '{slot_create.slot_key}' 已存在")

        if slot_create.type not in self.VALID_TYPES:
            raise ValueError(
                f"无效的槽位类型 '{slot_create.type}',"
                f"有效类型为: {self.VALID_TYPES}"
            )

        # [AC-MRS-07-UPGRADE] Normalize and validate the strategy chain, then
        # derive the legacy single-strategy value from the same input.
        strategies = self._normalize_strategies(
            slot_create.extract_strategies,
            slot_create.extract_strategy,
        )
        self._ensure_valid_strategies(strategies)
        legacy_strategy = self._resolve_legacy_strategy(
            slot_create.extract_strategy,
            strategies,
        )

        if slot_create.linked_field_id:
            linked_field = await self._get_linked_field(slot_create.linked_field_id)
            if not linked_field:
                raise ValueError(
                    f"[AC-MRS-08] 关联的元数据字段 '{slot_create.linked_field_id}' 不存在"
                )

        slot = SlotDefinition(
            tenant_id=tenant_id,
            slot_key=slot_create.slot_key,
            display_name=slot_create.display_name,
            description=slot_create.description,
            type=slot_create.type,
            required=slot_create.required,
            # [AC-MRS-07-UPGRADE] Persist both the legacy and the new field.
            extract_strategy=legacy_strategy,
            extract_strategies=strategies,
            validation_rule=slot_create.validation_rule,
            ask_back_prompt=slot_create.ask_back_prompt,
            default_value=slot_create.default_value,
            # Safe to parse: a malformed ID was already rejected above because
            # _get_linked_field returns None for it.
            linked_field_id=uuid.UUID(slot_create.linked_field_id) if slot_create.linked_field_id else None,
        )
        self._session.add(slot)
        await self._session.flush()

        logger.info(
            "[AC-MRS-07] Created slot definition: tenant=%s, "
            "slot_key=%s, required=%s, "
            "strategies=%s, "
            "linked_field_id=%s",
            tenant_id,
            slot.slot_key,
            slot.required,
            strategies,
            slot.linked_field_id,
        )
        return slot

    async def update_slot_definition(
        self,
        tenant_id: str,
        slot_id: str,
        slot_update: SlotDefinitionUpdate,
    ) -> SlotDefinition | None:
        """
        Partially update a slot definition.
        [AC-MRS-07-UPGRADE] Supports updating the extraction strategy chain.

        Only fields that are not None in `slot_update` are changed. For
        `linked_field_id`, an empty string unlinks the field. All input is
        validated before anything on the ORM object is modified. A ValueError
        therefore never leaves a half-updated slot tracked by the session.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID.
            slot_update: Update payload.

        Returns:
            The updated SlotDefinition, or None if it does not exist.

        Raises:
            ValueError: If the type or strategies are invalid, or the linked
                field does not exist.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return None

        # ---- Validation phase: `slot` is not touched here. ----
        if slot_update.type is not None and slot_update.type not in self.VALID_TYPES:
            raise ValueError(
                f"无效的槽位类型 '{slot_update.type}',"
                f"有效类型为: {self.VALID_TYPES}"
            )

        # [AC-MRS-07-UPGRADE] Strategies are updated when either field is sent.
        strategies_submitted = (
            slot_update.extract_strategies is not None
            or slot_update.extract_strategy is not None
        )
        new_strategies: list[str] | None = None
        new_legacy_strategy: str | None = None
        if strategies_submitted:
            new_strategies = self._normalize_strategies(
                slot_update.extract_strategies,
                slot_update.extract_strategy,
            )
            self._ensure_valid_strategies(new_strategies)
            new_legacy_strategy = self._resolve_legacy_strategy(
                slot_update.extract_strategy,
                new_strategies,
            )

        new_linked_field_id: uuid.UUID | None = None
        if slot_update.linked_field_id:
            linked_field = await self._get_linked_field(slot_update.linked_field_id)
            if not linked_field:
                raise ValueError(
                    f"[AC-MRS-08] 关联的元数据字段 '{slot_update.linked_field_id}' 不存在"
                )
            new_linked_field_id = uuid.UUID(slot_update.linked_field_id)

        # ---- Apply phase: all checks passed. ----
        if slot_update.display_name is not None:
            slot.display_name = slot_update.display_name
        if slot_update.description is not None:
            slot.description = slot_update.description
        if slot_update.type is not None:
            slot.type = slot_update.type
        if slot_update.required is not None:
            slot.required = slot_update.required
        if strategies_submitted:
            # [AC-MRS-07-UPGRADE] Keep the new and legacy fields in sync.
            slot.extract_strategies = new_strategies
            slot.extract_strategy = new_legacy_strategy
        if slot_update.validation_rule is not None:
            slot.validation_rule = slot_update.validation_rule
        if slot_update.ask_back_prompt is not None:
            slot.ask_back_prompt = slot_update.ask_back_prompt
        if slot_update.default_value is not None:
            slot.default_value = slot_update.default_value
        if slot_update.linked_field_id is not None:
            # "" unlinks (new_linked_field_id stays None); otherwise the
            # validated UUID is stored.
            slot.linked_field_id = new_linked_field_id

        # NOTE(review): datetime.utcnow() returns a naive datetime and is
        # deprecated since Python 3.12. It is kept because switching to an
        # aware datetime may not match how `updated_at` is stored; confirm
        # before changing.
        slot.updated_at = datetime.utcnow()
        await self._session.flush()

        logger.info(
            "[AC-MRS-07] Updated slot definition: tenant=%s, "
            "slot_id=%s, strategies=%s",
            tenant_id,
            slot_id,
            slot.extract_strategies,
        )
        return slot

    async def delete_slot_definition(
        self,
        tenant_id: str,
        slot_id: str,
    ) -> bool:
        """
        [AC-MRS-16] Delete a slot definition.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID.

        Returns:
            True if a slot was deleted, False if it was not found.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return False
        await self._session.delete(slot)
        await self._session.flush()
        logger.info(
            "[AC-MRS-16] Deleted slot definition: tenant=%s, "
            "slot_id=%s",
            tenant_id,
            slot_id,
        )
        return True

    async def _get_linked_field(
        self,
        field_id: str,
    ) -> MetadataFieldDefinition | None:
        """
        Fetch the metadata field definition a slot links to.

        Args:
            field_id: Field ID as a UUID string.

        Returns:
            The MetadataFieldDefinition, or None if `field_id` is not a valid
            UUID or no matching row exists.
        """
        # NOTE(review): the lookup is by ID only, with no tenant filter, so a
        # slot could link to, and expose, another tenant's field if
        # MetadataFieldDefinition is tenant-scoped. Confirm and add a
        # tenant_id condition if so.
        try:
            field_uuid = uuid.UUID(field_id)
        except ValueError:
            return None
        stmt = select(MetadataFieldDefinition).where(
            MetadataFieldDefinition.id == field_uuid,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_slot_definition_with_field(
        self,
        tenant_id: str,
        slot_id: str,
    ) -> dict[str, Any] | None:
        """
        Return a slot definition together with its linked field, as a dict.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID.

        Returns:
            Dict of slot attributes. IDs are strings and timestamps are ISO
            strings. The "linked_field" entry is None when there is no linked
            field, or when the linked field no longer exists. Returns None if
            the slot does not exist.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return None

        result: dict[str, Any] = {
            "id": str(slot.id),
            "tenant_id": slot.tenant_id,
            "slot_key": slot.slot_key,
            "display_name": slot.display_name,
            "description": slot.description,
            "type": slot.type,
            "required": slot.required,
            # [AC-MRS-07-UPGRADE] Expose both the legacy and the new field.
            "extract_strategy": slot.extract_strategy,
            "extract_strategies": slot.extract_strategies,
            "validation_rule": slot.validation_rule,
            "ask_back_prompt": slot.ask_back_prompt,
            "default_value": slot.default_value,
            "linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
            "created_at": slot.created_at.isoformat() if slot.created_at else None,
            "updated_at": slot.updated_at.isoformat() if slot.updated_at else None,
            "linked_field": None,
        }

        if slot.linked_field_id:
            linked_field = await self._get_linked_field(str(slot.linked_field_id))
            if linked_field:
                result["linked_field"] = {
                    "id": str(linked_field.id),
                    "field_key": linked_field.field_key,
                    "label": linked_field.label,
                    "type": linked_field.type,
                    "required": linked_field.required,
                    "options": linked_field.options,
                    "default_value": linked_field.default_value,
                    "scope": linked_field.scope,
                    "is_filterable": linked_field.is_filterable,
                    "is_rank_feature": linked_field.is_rank_feature,
                    "field_roles": linked_field.field_roles,
                    "status": linked_field.status,
                }
        return result