ai-robot-core/ai-service/app/services/mid/slot_manager.py

380 lines
11 KiB
Python

"""
Slot Manager Service.
槽位管理服务 - 统一槽位写入入口,集成校验逻辑
职责:
1. 在槽位值写入前执行校验
2. 管理槽位值的来源和置信度
3. 提供槽位写入的统一接口
4. 返回校验失败时的追问提示
"""
import logging
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import SlotDefinition
from app.models.mid.schemas import SlotSource
from app.services.mid.slot_validation_service import (
BatchValidationResult,
SlotValidationError,
SlotValidationService,
)
from app.services.slot_definition_service import SlotDefinitionService
logger = logging.getLogger(__name__)
class SlotWriteResult:
    """
    Outcome of a single slot write attempt.

    Attributes:
        success: Whether the value was accepted (validation passed, or was
            not required) and may be written.
        slot_key: Slot key the result refers to.
        value: Final value to write.
        error: Validation error details (set on validation failure).
        ask_back_prompt: Follow-up prompt to ask the user (set on
            validation failure).
    """

    def __init__(
        self,
        success: bool,
        slot_key: str,
        value: Any | None = None,
        error: SlotValidationError | None = None,
        ask_back_prompt: str | None = None,
    ):
        self.success, self.slot_key = success, slot_key
        self.value = value
        self.error, self.ask_back_prompt = error, ask_back_prompt

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize to a plain dict.

        ``success``, ``slot_key`` and ``value`` are always present; ``error``
        and ``ask_back_prompt`` are added only when the corresponding
        attribute is truthy.
        """
        payload: dict[str, Any] = dict(
            success=self.success,
            slot_key=self.slot_key,
            value=self.value,
        )
        err = self.error
        if err:
            payload["error"] = dict(
                error_code=err.error_code,
                error_message=err.error_message,
            )
        prompt = self.ask_back_prompt
        if prompt:
            payload["ask_back_prompt"] = prompt
        return payload
class SlotManager:
    """
    Slot manager.

    Single entry point for writing slot values: before a value is accepted it
    is validated against the tenant's SlotDefinition (when one exists).
    Definitions are looked up through a SlotDefinitionService and cached per
    (tenant_id, slot_key) for the lifetime of this instance; misses are cached
    too. Call ``clear_cache`` to force a reload.
    """

    def __init__(
        self,
        session: AsyncSession | None = None,
        validation_service: SlotValidationService | None = None,
        slot_def_service: SlotDefinitionService | None = None,
    ):
        """
        Initialize the slot manager.

        Args:
            session: Database session. Used to build a SlotDefinitionService
                (once, on first lookup) when ``slot_def_service`` is not given.
            validation_service: Validation service; a default
                SlotValidationService is created when omitted.
            slot_def_service: Slot definition service. Takes precedence over
                ``session`` for definition lookups.
        """
        self._session = session
        self._validation_service = (
            validation_service
            if validation_service is not None
            else SlotValidationService()
        )
        self._slot_def_service = slot_def_service
        # Keyed by (tenant_id, slot_key). A tuple key avoids the collisions a
        # "tenant:slot" string key allows when either part contains ":".
        # None (definition not found) is cached as well.
        self._slot_def_cache: dict[
            tuple[str, str], SlotDefinition | dict[str, Any] | None
        ] = {}

    async def write_slot(
        self,
        tenant_id: str,
        slot_key: str,
        value: Any,
        source: SlotSource = SlotSource.USER_CONFIRMED,
        confidence: float = 1.0,
        skip_validation: bool = False,
    ) -> SlotWriteResult:
        """
        Validate a single slot value and report whether it may be written.

        Flow:
        1. Load the slot definition (cached).
        2. If no definition exists, accept the value as-is (dynamic slot).
        3. Otherwise, unless ``skip_validation`` is set, validate the value and
           return either the normalized value or the validation error.

        This method does not persist anything itself; it returns the value
        the caller should store.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.
            value: Candidate slot value.
            source: Value source. Accepted for API compatibility; currently
                not used by this method.
            confidence: Confidence score. Accepted for API compatibility;
                currently not used by this method.
            skip_validation: Accept the value without validation (special
                cases).

        Returns:
            SlotWriteResult: ``success=True`` with the (possibly normalized)
            value, or ``success=False`` with the error and ask-back prompt.
        """
        slot_def = await self._get_slot_definition(tenant_id, slot_key)

        if slot_def is None:
            # Undefined (dynamic) slot: always accepted. Only the log level
            # depends on whether validation was requested.
            if skip_validation:
                logger.debug(
                    f"[SlotManager] Writing slot without definition: "
                    f"tenant_id={tenant_id}, slot_key={slot_key}"
                )
            else:
                logger.info(
                    f"[SlotManager] Slot definition not found, allowing write: "
                    f"tenant_id={tenant_id}, slot_key={slot_key}"
                )
            return SlotWriteResult(success=True, slot_key=slot_key, value=value)

        if skip_validation:
            return SlotWriteResult(success=True, slot_key=slot_key, value=value)

        validation_result = self._validation_service.validate_slot_value(
            slot_def, value, tenant_id
        )
        if not validation_result.ok:
            logger.info(
                f"[SlotManager] Slot validation failed: "
                f"tenant_id={tenant_id}, slot_key={slot_key}, "
                f"error_code={validation_result.error_code}"
            )
            return SlotWriteResult(
                success=False,
                slot_key=slot_key,
                error=self._to_validation_error(slot_key, validation_result),
                ask_back_prompt=validation_result.ask_back_prompt,
            )

        logger.debug(
            f"[SlotManager] Slot validation passed: "
            f"tenant_id={tenant_id}, slot_key={slot_key}"
        )
        # Hand back the validator's normalized form rather than the raw input.
        return SlotWriteResult(
            success=True,
            slot_key=slot_key,
            value=validation_result.normalized_value,
        )

    async def write_slots(
        self,
        tenant_id: str,
        values: dict[str, Any],
        source: SlotSource = SlotSource.USER_CONFIRMED,
        confidence: float = 1.0,
        skip_validation: bool = False,
    ) -> BatchValidationResult:
        """
        Validate several slot values at once.

        Only keys that have a slot definition contribute a definition to the
        batch validation; how keys without a definition are treated is up to
        ``SlotValidationService.validate_slots``.

        Args:
            tenant_id: Tenant ID.
            values: Mapping of slot_key -> candidate value.
            source: Value source. Accepted for API compatibility; currently
                not used by this method.
            confidence: Confidence score. Accepted for API compatibility;
                currently not used by this method.
            skip_validation: Accept all values without validation.

        Returns:
            BatchValidationResult: Batch validation result.
        """
        if skip_validation:
            # Copy so later mutation of the caller's dict does not leak into
            # the returned result (and vice versa).
            return BatchValidationResult(
                ok=True,
                validated_values=dict(values),
            )

        slot_defs = await self._get_slot_definitions(tenant_id, list(values))

        result = self._validation_service.validate_slots(
            slot_defs, values, tenant_id
        )

        if not result.ok:
            logger.info(
                f"[SlotManager] Batch slot validation failed: "
                f"tenant_id={tenant_id}, errors={[e.slot_key for e in result.errors]}"
            )
        else:
            logger.debug(
                f"[SlotManager] Batch slot validation passed: "
                f"tenant_id={tenant_id}, slots={list(values.keys())}"
            )

        return result

    async def validate_before_write(
        self,
        tenant_id: str,
        slot_key: str,
        value: Any,
    ) -> tuple[bool, SlotValidationError | None]:
        """
        Pre-validate a slot value without producing a write result.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.
            value: Candidate slot value.

        Returns:
            Tuple of (passed, error). Slots without a definition always pass.
        """
        slot_def = await self._get_slot_definition(tenant_id, slot_key)
        if slot_def is None:
            return True, None

        result = self._validation_service.validate_slot_value(
            slot_def, value, tenant_id
        )
        if result.ok:
            return True, None
        return False, self._to_validation_error(slot_key, result)

    async def get_ask_back_prompt(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> str | None:
        """
        Return the slot's configured ask-back (follow-up) prompt.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.

        Returns:
            The prompt, or None when the slot has no definition or no prompt.
        """
        slot_def = await self._get_slot_definition(tenant_id, slot_key)
        if slot_def is None:
            return None
        if isinstance(slot_def, SlotDefinition):
            return slot_def.ask_back_prompt
        # Definitions may also come back as plain dicts.
        return slot_def.get("ask_back_prompt")

    @staticmethod
    def _to_validation_error(slot_key: str, result: Any) -> SlotValidationError:
        """
        Build a SlotValidationError from a failed single-slot validation result.

        ``result`` is whatever ``validate_slot_value`` returns; only its
        ``error_code``, ``error_message`` and ``ask_back_prompt`` attributes
        are read, with fallbacks for missing code/message.
        """
        return SlotValidationError(
            slot_key=slot_key,
            error_code=result.error_code or "VALIDATION_FAILED",
            error_message=result.error_message or "校验失败",
            ask_back_prompt=result.ask_back_prompt,
        )

    def _definition_service(self) -> SlotDefinitionService | None:
        """
        Return the service used for definition lookups, or None if neither a
        service nor a session was provided. A session-backed service is built
        once and reused instead of being recreated on every cache miss.
        """
        if self._slot_def_service is None and self._session is not None:
            self._slot_def_service = SlotDefinitionService(self._session)
        return self._slot_def_service

    async def _get_slot_definition(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> SlotDefinition | dict[str, Any] | None:
        """
        Fetch a slot definition, using the per-instance cache.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.

        Returns:
            The slot definition, or None when not found or when no lookup
            source (service or session) is configured.
        """
        cache_key = (tenant_id, slot_key)
        if cache_key in self._slot_def_cache:
            return self._slot_def_cache[cache_key]

        slot_def = None
        service = self._definition_service()
        if service is not None:
            slot_def = await service.get_slot_definition_by_key(tenant_id, slot_key)

        self._slot_def_cache[cache_key] = slot_def
        return slot_def

    async def _get_slot_definitions(
        self,
        tenant_id: str,
        slot_keys: list[str],
    ) -> list[SlotDefinition | dict[str, Any]]:
        """
        Fetch definitions for several keys, skipping keys with no definition.

        Lookups are awaited one at a time (not gathered); sharing one database
        session across concurrent awaits is generally unsafe.

        Args:
            tenant_id: Tenant ID.
            slot_keys: Slot keys to look up.

        Returns:
            Definitions found, in ``slot_keys`` order.
        """
        slot_defs: list[SlotDefinition | dict[str, Any]] = []
        for key in slot_keys:
            slot_def = await self._get_slot_definition(tenant_id, key)
            # `is not None` so an (empty) dict definition is not dropped.
            if slot_def is not None:
                slot_defs.append(slot_def)
        return slot_defs

    def clear_cache(self) -> None:
        """Drop all cached slot definitions (including cached misses)."""
        self._slot_def_cache.clear()
def create_slot_manager(
    session: AsyncSession | None = None,
    validation_service: SlotValidationService | None = None,
    slot_def_service: SlotDefinitionService | None = None,
) -> SlotManager:
    """
    Factory for SlotManager.

    Args:
        session: Database session, forwarded unchanged.
        validation_service: Validation service, forwarded unchanged.
        slot_def_service: Slot definition service, forwarded unchanged.

    Returns:
        SlotManager: A new manager built from the given dependencies.
    """
    manager = SlotManager(session, validation_service, slot_def_service)
    return manager