462 lines
15 KiB
Python
462 lines
15 KiB
Python
"""
Slot Definition Service.

[AC-MRS-07, AC-MRS-08] 槽位定义管理服务 (slot definition management service)

[AC-MRS-07-UPGRADE] 支持提取策略链 extract_strategies (supports extraction strategy chains)
"""
|
||
|
||
import logging
|
||
import re
|
||
import uuid
|
||
from datetime import datetime
|
||
from typing import Any
|
||
|
||
from sqlalchemy import select
|
||
from sqlalchemy.ext.asyncio import AsyncSession
|
||
|
||
from app.models.entities import (
|
||
SlotDefinition,
|
||
SlotDefinitionCreate,
|
||
SlotDefinitionUpdate,
|
||
MetadataFieldDefinition,
|
||
MetadataFieldType,
|
||
ExtractStrategy,
|
||
)
|
||
|
||
# Module-level logger used by SlotDefinitionService for create/update/delete audit messages.
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class SlotDefinitionService:
    """
    [AC-MRS-07, AC-MRS-08] Slot definition service.
    [AC-MRS-07-UPGRADE] Supports managing extraction strategy chains.

    Manages standalone slot definitions. Slots are decoupled from metadata
    field definitions but may reference one through ``linked_field_id``.

    Slot definitions are always looked up within a tenant. Writes are only
    flushed, never committed: committing or rolling back is left to whoever
    owns the session.
    """

    # A slot_key starts with a lowercase letter followed by lowercase letters,
    # digits or underscores. Apply it with ``fullmatch``: with ``match`` the
    # trailing ``$`` also matches right before a final "\n", so "phone\n"
    # would be accepted.
    SLOT_KEY_PATTERN = re.compile(r"^[a-z][a-z0-9_]*$")

    VALID_TYPES = ["string", "number", "boolean", "enum", "array_enum"]

    VALID_EXTRACT_STRATEGIES = ["rule", "llm", "user_input"]

    def __init__(self, session: AsyncSession):
        # Caller-owned async session used for every query and write below.
        self._session = session

    async def list_slot_definitions(
        self,
        tenant_id: str,
        required: bool | None = None,
    ) -> list[SlotDefinition]:
        """
        List all slot definitions of a tenant, newest first.

        Args:
            tenant_id: Tenant ID.
            required: When not None, only return slots whose ``required``
                flag equals this value.

        Returns:
            List of SlotDefinition ordered by ``created_at`` descending.
        """
        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
        )

        if required is not None:
            stmt = stmt.where(SlotDefinition.required == required)

        stmt = stmt.order_by(SlotDefinition.created_at.desc())

        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    async def get_slot_definition(
        self,
        tenant_id: str,
        slot_id: str,
    ) -> SlotDefinition | None:
        """
        Fetch a single slot definition by ID within a tenant.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID (UUID string).

        Returns:
            The SlotDefinition, or None if ``slot_id`` is not a valid UUID
            string or no row with that ID exists for the tenant.
        """
        try:
            slot_uuid = uuid.UUID(slot_id)
        except ValueError:
            # A malformed ID can never match a row: report "not found".
            return None

        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
            SlotDefinition.id == slot_uuid,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_slot_definition_by_key(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> SlotDefinition | None:
        """
        Fetch a slot definition by its ``slot_key`` within a tenant.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key name.

        Returns:
            The SlotDefinition, or None if no such key exists for the tenant.
        """
        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
            SlotDefinition.slot_key == slot_key,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    def _validate_strategies(self, strategies: list[str] | None) -> tuple[bool, str]:
        """
        [AC-MRS-07-UPGRADE] Check whether an extraction strategy chain is valid.

        ``None`` means "nothing to validate" and is accepted. Otherwise the
        chain must be a non-empty list of known strategies with no repeats.

        Args:
            strategies: The strategy chain.

        Returns:
            Tuple of (is_valid, error_message); the message is "" when valid.
        """
        if strategies is None:
            return True, ""

        if not isinstance(strategies, list):
            return False, "extract_strategies 必须是数组类型"

        if len(strategies) == 0:
            return False, "提取策略链不能为空数组"

        # Check membership before duplicates: the set() below then only sees
        # known strategy values, so an unhashable element yields a validation
        # error instead of an unexpected TypeError.
        invalid = [s for s in strategies if s not in self.VALID_EXTRACT_STRATEGIES]
        if invalid:
            return False, f"无效的提取策略: {invalid},有效值为: {self.VALID_EXTRACT_STRATEGIES}"

        if len(strategies) != len(set(strategies)):
            return False, "提取策略链中不允许重复的策略"

        return True, ""

    def _ensure_valid_strategies(self, strategies: list[str] | None) -> None:
        """
        Raise if ``strategies`` fails ``_validate_strategies``.

        Raises:
            ValueError: With the validation message prefixed by the
                chain-validation-failure label.
        """
        is_valid, error_msg = self._validate_strategies(strategies)
        if not is_valid:
            raise ValueError(f"提取策略链校验失败: {error_msg}")

    def _normalize_strategies(
        self,
        extract_strategies: list[str] | None,
        extract_strategy: str | None,
    ) -> list[str] | None:
        """
        [AC-MRS-07-UPGRADE] Normalize the submitted strategy fields into a chain.

        ``extract_strategies`` takes precedence; otherwise a truthy legacy
        ``extract_strategy`` is wrapped into a single-element chain.

        Args:
            extract_strategies: Strategy chain (new field).
            extract_strategy: Single strategy (legacy field, kept for compatibility).

        Returns:
            The normalized strategy chain, or None if neither field is set.
        """
        if extract_strategies is not None:
            return extract_strategies
        if extract_strategy:
            return [extract_strategy]
        return None

    def _resolve_legacy_strategy(
        self,
        extract_strategy: str | None,
        strategies: list[str] | None,
    ) -> str | None:
        """
        [AC-MRS-07-UPGRADE] Choose the value stored in the legacy ``extract_strategy`` column.

        An explicitly submitted (truthy) legacy value wins. It is validated
        on its own because, when a chain is also submitted, only the chain
        goes through ``_normalize_strategies``/``_validate_strategies``.
        Otherwise the first entry of the chain is used. If neither is
        available, the submitted value (None or "") is passed through.

        Args:
            extract_strategy: Legacy single strategy as submitted.
            strategies: Already-validated normalized chain, or None.

        Returns:
            The value to persist in the legacy column.

        Raises:
            ValueError: If the explicit legacy value is not a known strategy.
        """
        if extract_strategy:
            self._ensure_valid_strategies([extract_strategy])
            return extract_strategy
        if strategies:
            return strategies[0]
        return extract_strategy

    async def create_slot_definition(
        self,
        tenant_id: str,
        slot_create: SlotDefinitionCreate,
    ) -> SlotDefinition:
        """
        [AC-MRS-07, AC-MRS-08] Create a slot definition.
        [AC-MRS-07-UPGRADE] Supports extraction strategy chains.

        Args:
            tenant_id: Tenant ID.
            slot_create: Creation payload.

        Returns:
            The created SlotDefinition (flushed, not committed).

        Raises:
            ValueError: If the slot_key is malformed or already exists, the
                type or a strategy is invalid, or the linked metadata field
                does not exist.
        """
        if not self.SLOT_KEY_PATTERN.fullmatch(slot_create.slot_key):
            raise ValueError(
                f"slot_key '{slot_create.slot_key}' 格式不正确,"
                "必须以小写字母开头,仅允许小写字母、数字和下划线"
            )

        existing = await self.get_slot_definition_by_key(tenant_id, slot_create.slot_key)
        if existing:
            raise ValueError(f"slot_key '{slot_create.slot_key}' 已存在")

        if slot_create.type not in self.VALID_TYPES:
            raise ValueError(
                f"无效的槽位类型 '{slot_create.type}',"
                f"有效类型为: {self.VALID_TYPES}"
            )

        # [AC-MRS-07-UPGRADE] Normalize and validate the strategy chain, then
        # derive the legacy single-strategy value from the same inputs.
        strategies = self._normalize_strategies(
            slot_create.extract_strategies,
            slot_create.extract_strategy,
        )
        self._ensure_valid_strategies(strategies)
        legacy_strategy = self._resolve_legacy_strategy(
            slot_create.extract_strategy,
            strategies,
        )

        if slot_create.linked_field_id:
            linked_field = await self._get_linked_field(slot_create.linked_field_id)
            if not linked_field:
                raise ValueError(
                    f"[AC-MRS-08] 关联的元数据字段 '{slot_create.linked_field_id}' 不存在"
                )

        slot = SlotDefinition(
            tenant_id=tenant_id,
            slot_key=slot_create.slot_key,
            display_name=slot_create.display_name,
            description=slot_create.description,
            type=slot_create.type,
            required=slot_create.required,
            # [AC-MRS-07-UPGRADE] Persist both the legacy and the new field.
            extract_strategy=legacy_strategy,
            extract_strategies=strategies,
            validation_rule=slot_create.validation_rule,
            ask_back_prompt=slot_create.ask_back_prompt,
            default_value=slot_create.default_value,
            # Safe: a non-empty ID was validated as a UUID by _get_linked_field above.
            linked_field_id=uuid.UUID(slot_create.linked_field_id) if slot_create.linked_field_id else None,
        )

        self._session.add(slot)
        await self._session.flush()

        logger.info(
            "[AC-MRS-07] Created slot definition: tenant=%s, "
            "slot_key=%s, required=%s, "
            "strategies=%s, "
            "linked_field_id=%s",
            tenant_id,
            slot.slot_key,
            slot.required,
            strategies,
            slot.linked_field_id,
        )

        return slot

    async def update_slot_definition(
        self,
        tenant_id: str,
        slot_id: str,
        slot_update: SlotDefinitionUpdate,
    ) -> SlotDefinition | None:
        """
        Update a slot definition; fields left as None are not changed.
        [AC-MRS-07-UPGRADE] Supports updating the extraction strategy chain.

        All incoming values are validated before the tracked ORM instance is
        modified. A rejected update therefore never leaves a partially
        modified object in the session.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID.
            slot_update: Update payload. An empty-string ``linked_field_id``
                clears the link.

        Returns:
            The updated SlotDefinition, or None if it does not exist.

        Raises:
            ValueError: If the type or a strategy is invalid, or the linked
                metadata field does not exist.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return None

        # ---- Validation phase: no mutation of `slot` until everything passes.
        if slot_update.type is not None and slot_update.type not in self.VALID_TYPES:
            raise ValueError(
                f"无效的槽位类型 '{slot_update.type}',"
                f"有效类型为: {self.VALID_TYPES}"
            )

        # [AC-MRS-07-UPGRADE] The strategies change when either field is submitted.
        strategies_submitted = (
            slot_update.extract_strategies is not None
            or slot_update.extract_strategy is not None
        )
        strategies: list[str] | None = None
        legacy_strategy: str | None = None
        if strategies_submitted:
            strategies = self._normalize_strategies(
                slot_update.extract_strategies,
                slot_update.extract_strategy,
            )
            self._ensure_valid_strategies(strategies)
            legacy_strategy = self._resolve_legacy_strategy(
                slot_update.extract_strategy,
                strategies,
            )

        new_linked_field_id: uuid.UUID | None = None
        if slot_update.linked_field_id:
            linked_field = await self._get_linked_field(slot_update.linked_field_id)
            if not linked_field:
                raise ValueError(
                    f"[AC-MRS-08] 关联的元数据字段 '{slot_update.linked_field_id}' 不存在"
                )
            new_linked_field_id = uuid.UUID(slot_update.linked_field_id)

        # ---- Apply phase.
        if slot_update.display_name is not None:
            slot.display_name = slot_update.display_name

        if slot_update.description is not None:
            slot.description = slot_update.description

        if slot_update.type is not None:
            slot.type = slot_update.type

        if slot_update.required is not None:
            slot.required = slot_update.required

        if strategies_submitted:
            # [AC-MRS-07-UPGRADE] Keep the new and legacy fields in sync.
            slot.extract_strategies = strategies
            slot.extract_strategy = legacy_strategy

        if slot_update.validation_rule is not None:
            slot.validation_rule = slot_update.validation_rule

        if slot_update.ask_back_prompt is not None:
            slot.ask_back_prompt = slot_update.ask_back_prompt

        if slot_update.default_value is not None:
            slot.default_value = slot_update.default_value

        if slot_update.linked_field_id is not None:
            # None here means the caller sent "" to clear the link.
            slot.linked_field_id = new_linked_field_id

        # Naive UTC timestamp, as in the original code; presumably matches the
        # column's convention - confirm before switching to aware datetimes.
        slot.updated_at = datetime.utcnow()
        await self._session.flush()

        logger.info(
            "[AC-MRS-07] Updated slot definition: tenant=%s, "
            "slot_id=%s, strategies=%s",
            tenant_id,
            slot_id,
            slot.extract_strategies,
        )

        return slot

    async def delete_slot_definition(
        self,
        tenant_id: str,
        slot_id: str,
    ) -> bool:
        """
        [AC-MRS-16] Delete a slot definition.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID.

        Returns:
            True if a slot was deleted (flushed, not committed), False if it
            was not found.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return False

        await self._session.delete(slot)
        await self._session.flush()

        logger.info(
            "[AC-MRS-16] Deleted slot definition: tenant=%s, "
            "slot_id=%s",
            tenant_id,
            slot_id,
        )

        return True

    async def _get_linked_field(
        self,
        field_id: str,
    ) -> MetadataFieldDefinition | None:
        """
        Fetch the metadata field a slot links to.

        NOTE(review): this lookup is by ID only and is not scoped to a
        tenant, so a slot could be linked to (and expose details of) another
        tenant's field. If MetadataFieldDefinition carries a tenant column,
        it should be added to this filter - confirm against the model.

        Args:
            field_id: Metadata field ID (UUID string).

        Returns:
            The MetadataFieldDefinition, or None if ``field_id`` is not a
            valid UUID string or no such field exists.
        """
        try:
            field_uuid = uuid.UUID(field_id)
        except ValueError:
            return None

        stmt = select(MetadataFieldDefinition).where(
            MetadataFieldDefinition.id == field_uuid,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_slot_definition_with_field(
        self,
        tenant_id: str,
        slot_id: str,
    ) -> dict[str, Any] | None:
        """
        Fetch a slot definition together with its linked metadata field.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID.

        Returns:
            A dict with the slot's attributes (IDs as strings, timestamps as
            ISO-8601 strings) plus a ``linked_field`` sub-dict. ``linked_field``
            is None when the slot has no link or the linked field no longer
            exists. Returns None if the slot does not exist.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return None

        result = {
            "id": str(slot.id),
            "tenant_id": slot.tenant_id,
            "slot_key": slot.slot_key,
            "display_name": slot.display_name,
            "description": slot.description,
            "type": slot.type,
            "required": slot.required,
            # [AC-MRS-07-UPGRADE] Return both the legacy and the new field.
            "extract_strategy": slot.extract_strategy,
            "extract_strategies": slot.extract_strategies,
            "validation_rule": slot.validation_rule,
            "ask_back_prompt": slot.ask_back_prompt,
            "default_value": slot.default_value,
            "linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
            "created_at": slot.created_at.isoformat() if slot.created_at else None,
            "updated_at": slot.updated_at.isoformat() if slot.updated_at else None,
            "linked_field": None,
        }

        if slot.linked_field_id:
            linked_field = await self._get_linked_field(str(slot.linked_field_id))
            if linked_field:
                result["linked_field"] = {
                    "id": str(linked_field.id),
                    "field_key": linked_field.field_key,
                    "label": linked_field.label,
                    "type": linked_field.type,
                    "required": linked_field.required,
                    "options": linked_field.options,
                    "default_value": linked_field.default_value,
                    "scope": linked_field.scope,
                    "is_filterable": linked_field.is_filterable,
                    "is_rank_feature": linked_field.is_rank_feature,
                    "field_roles": linked_field.field_roles,
                    "status": linked_field.status,
                }

        return result
|