ai-robot-core/ai-service/app/services/slot_definition_service.py

462 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
Slot Definition Service.
[AC-MRS-07, AC-MRS-08] 槽位定义管理服务
[AC-MRS-07-UPGRADE] 支持提取策略链 extract_strategies
"""
import logging
import re
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import (
SlotDefinition,
SlotDefinitionCreate,
SlotDefinitionUpdate,
MetadataFieldDefinition,
MetadataFieldType,
ExtractStrategy,
)
# Module-scoped logger used by SlotDefinitionService below.
logger = logging.getLogger(name=__name__)
class SlotDefinitionService:
    """
    [AC-MRS-07, AC-MRS-08] Slot definition service.
    [AC-MRS-07-UPGRADE] Supports managing extraction strategy chains.

    Manages a standalone slot definition model. Slots are decoupled from
    metadata field definitions but may reference one through
    ``linked_field_id``.
    """

    # Allowed slot_key format. Always apply it with ``fullmatch``: with
    # ``match``, ``$`` also matches just before a trailing newline, which would
    # let keys such as "abc\n" through.
    SLOT_KEY_PATTERN = re.compile(r"^[a-z][a-z0-9_]*$")
    # Accepted values for a slot's ``type``.
    VALID_TYPES = ["string", "number", "boolean", "enum", "array_enum"]
    # Accepted entries of an extraction strategy chain.
    VALID_EXTRACT_STRATEGIES = ["rule", "llm", "user_input"]

    def __init__(self, session: AsyncSession):
        self._session = session

    @staticmethod
    def _parse_uuid(value: Any) -> uuid.UUID | None:
        """
        Convert ``value`` to a UUID, or return None if it cannot be one.

        An existing ``uuid.UUID`` is returned unchanged. Malformed strings
        (ValueError) and unsupported inputs such as None or ints
        (TypeError / AttributeError) all map to None, so callers can treat
        them as "not found".
        """
        if isinstance(value, uuid.UUID):
            return value
        try:
            return uuid.UUID(value)
        except (ValueError, TypeError, AttributeError):
            return None

    async def list_slot_definitions(
        self,
        tenant_id: str,
        required: bool | None = None,
    ) -> list[SlotDefinition]:
        """
        List a tenant's slot definitions, newest first (by ``created_at``).

        Args:
            tenant_id: Tenant ID.
            required: If not None, only return slots whose ``required`` flag
                equals this value.

        Returns:
            List of SlotDefinition.
        """
        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
        )
        if required is not None:
            stmt = stmt.where(SlotDefinition.required == required)
        stmt = stmt.order_by(SlotDefinition.created_at.desc())
        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    async def get_slot_definition(
        self,
        tenant_id: str,
        slot_id: str,
    ) -> SlotDefinition | None:
        """
        Fetch a single slot definition by ID within a tenant.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID (UUID string or ``uuid.UUID``).

        Returns:
            The SlotDefinition, or None if ``slot_id`` is not a valid UUID or
            no matching row exists for this tenant.
        """
        slot_uuid = self._parse_uuid(slot_id)
        if slot_uuid is None:
            return None
        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
            SlotDefinition.id == slot_uuid,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_slot_definition_by_key(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> SlotDefinition | None:
        """
        Fetch a slot definition by its ``slot_key`` within a tenant.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key name.

        Returns:
            The SlotDefinition, or None if not found.
        """
        stmt = select(SlotDefinition).where(
            SlotDefinition.tenant_id == tenant_id,
            SlotDefinition.slot_key == slot_key,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    def _validate_strategies(self, strategies: list[str] | None) -> tuple[bool, str]:
        """
        [AC-MRS-07-UPGRADE] Check that an extraction strategy chain is valid.

        None is accepted (nothing to validate). Otherwise the value must be a
        non-empty list with no duplicates, and every entry must be in
        ``VALID_EXTRACT_STRATEGIES``.

        Args:
            strategies: The strategy chain.

        Returns:
            Tuple of (is_valid, error_message). The message is "" when valid.
        """
        if strategies is None:
            return True, ""
        if not isinstance(strategies, list):
            return False, "extract_strategies 必须是数组类型"
        if len(strategies) == 0:
            return False, "提取策略链不能为空数组"
        # Reject duplicate strategies. Unhashable entries (e.g. dicts coming
        # from raw JSON) make set() raise TypeError. Such entries can never be
        # valid strategy names, so skip the duplicate check and let the
        # membership check below report them instead of crashing.
        try:
            has_duplicates = len(strategies) != len(set(strategies))
        except TypeError:
            has_duplicates = False
        if has_duplicates:
            return False, "提取策略链中不允许重复的策略"
        # Every entry must be a known strategy.
        invalid = [s for s in strategies if s not in self.VALID_EXTRACT_STRATEGIES]
        if invalid:
            return False, f"无效的提取策略: {invalid},有效值为: {self.VALID_EXTRACT_STRATEGIES}"
        return True, ""

    def _normalize_strategies(
        self,
        extract_strategies: list[str] | None,
        extract_strategy: str | None,
    ) -> list[str] | None:
        """
        [AC-MRS-07-UPGRADE] Normalize the extraction strategy input.

        ``extract_strategies`` takes precedence when it is not None. Otherwise
        a truthy ``extract_strategy`` is wrapped as a one-element chain.

        Args:
            extract_strategies: Strategy chain (new field).
            extract_strategy: Single strategy (legacy field, kept for
                compatibility).

        Returns:
            The normalized strategy chain, or None if neither was supplied.
        """
        if extract_strategies is not None:
            return extract_strategies
        if extract_strategy:
            return [extract_strategy]
        return None

    def _ensure_valid_type(self, slot_type: str) -> None:
        """
        Raise ValueError if ``slot_type`` is not one of ``VALID_TYPES``.
        """
        if slot_type not in self.VALID_TYPES:
            raise ValueError(
                f"无效的槽位类型 '{slot_type}',"
                f"有效类型为: {self.VALID_TYPES}"
            )

    def _ensure_valid_strategies(self, strategies: list[str] | None) -> None:
        """
        Raise ValueError if ``strategies`` fails ``_validate_strategies``.

        None passes, because there is nothing to validate.
        """
        is_valid, error_msg = self._validate_strategies(strategies)
        if not is_valid:
            raise ValueError(f"提取策略链校验失败: {error_msg}")

    def _ensure_valid_legacy_strategy(
        self,
        extract_strategies: list[str] | None,
        extract_strategy: str | None,
    ) -> None:
        """
        Validate the legacy ``extract_strategy`` when a chain is also supplied.

        ``_normalize_strategies`` ignores the legacy value whenever
        ``extract_strategies`` is given. Without this check, an unknown value
        could be written to the legacy column unvalidated.
        """
        if extract_strategy and extract_strategies is not None:
            self._ensure_valid_strategies([extract_strategy])

    @staticmethod
    def _resolve_legacy_strategy(
        extract_strategy: str | None,
        strategies: list[str] | None,
    ) -> str | None:
        """
        [AC-MRS-07-UPGRADE] Pick the value stored in the legacy column.

        An explicitly supplied (truthy) ``extract_strategy`` wins. Otherwise
        the first entry of the chain is used. If there is no chain, the
        supplied value (None or "") is returned unchanged.
        """
        if extract_strategy:
            return extract_strategy
        if strategies:
            return strategies[0]
        return extract_strategy

    async def _require_linked_field(self, field_id: str) -> MetadataFieldDefinition:
        """
        [AC-MRS-08] Return the linked metadata field, or raise if it is missing.

        Raises:
            ValueError: If ``field_id`` is not a valid UUID or no field exists.
        """
        linked_field = await self._get_linked_field(field_id)
        if not linked_field:
            raise ValueError(
                f"[AC-MRS-08] 关联的元数据字段 '{field_id}' 不存在"
            )
        return linked_field

    async def create_slot_definition(
        self,
        tenant_id: str,
        slot_create: SlotDefinitionCreate,
    ) -> SlotDefinition:
        """
        [AC-MRS-07, AC-MRS-08] Create a slot definition.
        [AC-MRS-07-UPGRADE] Supports extraction strategy chains.

        The new row is added to the session and flushed (not committed).

        Args:
            tenant_id: Tenant ID.
            slot_create: Creation payload.

        Returns:
            The created SlotDefinition.

        Raises:
            ValueError: If slot_key is malformed or already exists for the
                tenant, or if the type, strategies or linked field are invalid.
        """
        if not self.SLOT_KEY_PATTERN.fullmatch(slot_create.slot_key):
            raise ValueError(
                f"slot_key '{slot_create.slot_key}' 格式不正确,"
                "必须以小写字母开头,仅允许小写字母、数字和下划线"
            )

        existing = await self.get_slot_definition_by_key(tenant_id, slot_create.slot_key)
        if existing:
            raise ValueError(f"slot_key '{slot_create.slot_key}' 已存在")

        self._ensure_valid_type(slot_create.type)

        # [AC-MRS-07-UPGRADE] Normalize and validate the strategy chain, and
        # the legacy single-strategy value when it is not folded into the chain.
        strategies = self._normalize_strategies(
            slot_create.extract_strategies,
            slot_create.extract_strategy,
        )
        self._ensure_valid_strategies(strategies)
        self._ensure_valid_legacy_strategy(
            slot_create.extract_strategies,
            slot_create.extract_strategy,
        )

        linked_field = None
        if slot_create.linked_field_id:
            linked_field = await self._require_linked_field(slot_create.linked_field_id)

        slot = SlotDefinition(
            tenant_id=tenant_id,
            slot_key=slot_create.slot_key,
            display_name=slot_create.display_name,
            description=slot_create.description,
            type=slot_create.type,
            required=slot_create.required,
            # [AC-MRS-07-UPGRADE] Persist both the legacy and the new field.
            extract_strategy=self._resolve_legacy_strategy(
                slot_create.extract_strategy, strategies
            ),
            extract_strategies=strategies,
            validation_rule=slot_create.validation_rule,
            ask_back_prompt=slot_create.ask_back_prompt,
            default_value=slot_create.default_value,
            # Reuse the ID of the row we just looked up instead of reparsing.
            linked_field_id=linked_field.id if linked_field else None,
        )
        self._session.add(slot)
        await self._session.flush()
        logger.info(
            f"[AC-MRS-07] Created slot definition: tenant={tenant_id}, "
            f"slot_key={slot.slot_key}, required={slot.required}, "
            f"strategies={strategies}, "
            f"linked_field_id={slot.linked_field_id}"
        )
        return slot

    async def update_slot_definition(
        self,
        tenant_id: str,
        slot_id: str,
        slot_update: SlotDefinitionUpdate,
    ) -> SlotDefinition | None:
        """
        Update a slot definition.
        [AC-MRS-07-UPGRADE] Supports updating the extraction strategy chain.

        Only fields whose value in ``slot_update`` is not None are changed. An
        empty ``linked_field_id`` clears the link. The changes are flushed
        (not committed).

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID.
            slot_update: Fields to update.

        Returns:
            The updated SlotDefinition, or None if it does not exist.

        Raises:
            ValueError: If the type, strategies or linked field are invalid.
                All validation runs before any attribute of the slot is
                modified.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return None

        # --- Validate first, so a rejected request leaves ``slot`` untouched.
        # The linked-field lookup runs a query, which may autoflush pending
        # attribute changes on the session if they were already applied.
        if slot_update.type is not None:
            self._ensure_valid_type(slot_update.type)

        # [AC-MRS-07-UPGRADE] Strategies change when either the chain or the
        # legacy single-strategy field is supplied.
        strategies_given = (
            slot_update.extract_strategies is not None
            or slot_update.extract_strategy is not None
        )
        strategies: list[str] | None = None
        if strategies_given:
            strategies = self._normalize_strategies(
                slot_update.extract_strategies,
                slot_update.extract_strategy,
            )
            self._ensure_valid_strategies(strategies)
            self._ensure_valid_legacy_strategy(
                slot_update.extract_strategies,
                slot_update.extract_strategy,
            )

        linked_field = None
        if slot_update.linked_field_id:
            linked_field = await self._require_linked_field(slot_update.linked_field_id)

        # --- Apply the changes.
        if slot_update.display_name is not None:
            slot.display_name = slot_update.display_name
        if slot_update.description is not None:
            slot.description = slot_update.description
        if slot_update.type is not None:
            slot.type = slot_update.type
        if slot_update.required is not None:
            slot.required = slot_update.required
        if strategies_given:
            # [AC-MRS-07-UPGRADE] Keep the legacy and new fields in sync.
            slot.extract_strategies = strategies
            slot.extract_strategy = self._resolve_legacy_strategy(
                slot_update.extract_strategy, strategies
            )
        if slot_update.validation_rule is not None:
            slot.validation_rule = slot_update.validation_rule
        if slot_update.ask_back_prompt is not None:
            slot.ask_back_prompt = slot_update.ask_back_prompt
        if slot_update.default_value is not None:
            slot.default_value = slot_update.default_value
        if slot_update.linked_field_id is not None:
            # A truthy ID was resolved above; an empty value unlinks the field.
            slot.linked_field_id = linked_field.id if linked_field else None

        # NOTE(review): naive UTC timestamp, kept as in the original. Presumably
        # this matches how created_at is stored; confirm before switching to a
        # timezone-aware datetime.
        slot.updated_at = datetime.utcnow()
        await self._session.flush()
        logger.info(
            f"[AC-MRS-07] Updated slot definition: tenant={tenant_id}, "
            f"slot_id={slot_id}, strategies={slot.extract_strategies}"
        )
        return slot

    async def delete_slot_definition(
        self,
        tenant_id: str,
        slot_id: str,
    ) -> bool:
        """
        [AC-MRS-16] Delete a slot definition.

        The deletion is flushed (not committed).

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID.

        Returns:
            True if the slot was found and deleted, False otherwise.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return False
        await self._session.delete(slot)
        await self._session.flush()
        logger.info(
            f"[AC-MRS-16] Deleted slot definition: tenant={tenant_id}, "
            f"slot_id={slot_id}"
        )
        return True

    async def _get_linked_field(
        self,
        field_id: str,
    ) -> MetadataFieldDefinition | None:
        """
        Fetch a metadata field definition by ID.

        Args:
            field_id: Field ID (UUID string or ``uuid.UUID``).

        Returns:
            The MetadataFieldDefinition, or None if ``field_id`` is not a
            valid UUID or no such field exists.
        """
        field_uuid = self._parse_uuid(field_id)
        if field_uuid is None:
            return None
        # NOTE(review): this lookup is not scoped by tenant, so a slot can link
        # to, and get_slot_definition_with_field will expose, a field owned by
        # another tenant. If MetadataFieldDefinition has a tenant_id column,
        # filter on it here (TODO: confirm the entity schema).
        stmt = select(MetadataFieldDefinition).where(
            MetadataFieldDefinition.id == field_uuid,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_slot_definition_with_field(
        self,
        tenant_id: str,
        slot_id: str,
    ) -> dict[str, Any] | None:
        """
        Fetch a slot definition together with its linked field, as a dict.

        IDs are stringified and timestamps ISO-formatted. ``linked_field`` is
        None when the slot has no link or the linked field no longer exists.

        Args:
            tenant_id: Tenant ID.
            slot_id: Slot definition ID.

        Returns:
            Dict with the slot's fields plus ``linked_field``, or None if the
            slot does not exist.
        """
        slot = await self.get_slot_definition(tenant_id, slot_id)
        if not slot:
            return None
        result = {
            "id": str(slot.id),
            "tenant_id": slot.tenant_id,
            "slot_key": slot.slot_key,
            "display_name": slot.display_name,
            "description": slot.description,
            "type": slot.type,
            "required": slot.required,
            # [AC-MRS-07-UPGRADE] Return both the legacy and the new field.
            "extract_strategy": slot.extract_strategy,
            "extract_strategies": slot.extract_strategies,
            "validation_rule": slot.validation_rule,
            "ask_back_prompt": slot.ask_back_prompt,
            "default_value": slot.default_value,
            "linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
            "created_at": slot.created_at.isoformat() if slot.created_at else None,
            "updated_at": slot.updated_at.isoformat() if slot.updated_at else None,
            "linked_field": None,
        }
        if slot.linked_field_id:
            linked_field = await self._get_linked_field(str(slot.linked_field_id))
            if linked_field:
                result["linked_field"] = {
                    "id": str(linked_field.id),
                    "field_key": linked_field.field_key,
                    "label": linked_field.label,
                    "type": linked_field.type,
                    "required": linked_field.required,
                    "options": linked_field.options,
                    "default_value": linked_field.default_value,
                    "scope": linked_field.scope,
                    "is_filterable": linked_field.is_filterable,
                    "is_rank_feature": linked_field.is_rank_feature,
                    "field_roles": linked_field.field_roles,
                    "status": linked_field.status,
                }
        return result