380 lines
11 KiB
Python
380 lines
11 KiB
Python
"""
|
|
Slot Manager Service.
|
|
槽位管理服务 - 统一槽位写入入口,集成校验逻辑
|
|
|
|
职责:
|
|
1. 在槽位值写入前执行校验
|
|
2. 管理槽位值的来源和置信度
|
|
3. 提供槽位写入的统一接口
|
|
4. 返回校验失败时的追问提示
|
|
"""
|
|
|
|
import logging
|
|
from typing import Any
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from app.models.entities import SlotDefinition
|
|
from app.models.mid.schemas import SlotSource
|
|
from app.services.mid.slot_validation_service import (
|
|
BatchValidationResult,
|
|
SlotValidationError,
|
|
SlotValidationService,
|
|
)
|
|
from app.services.slot_definition_service import SlotDefinitionService
|
|
|
|
# Module-level logger named after this module (standard logging convention).
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class SlotWriteResult:
    """Outcome of a single slot write attempt.

    Attributes:
        success: Whether the write is accepted (validation passed, or was
            not required).
        slot_key: Key of the slot being written.
        value: Final value to be written.
        error: Validation error details, set when validation failed.
        ask_back_prompt: Follow-up prompt for the user, set when validation
            failed.
    """

    def __init__(
        self,
        success: bool,
        slot_key: str,
        value: Any | None = None,
        error: SlotValidationError | None = None,
        ask_back_prompt: str | None = None,
    ):
        self.success = success
        self.slot_key = slot_key
        self.value = value
        self.error = error
        self.ask_back_prompt = ask_back_prompt

    def to_dict(self) -> dict[str, Any]:
        """Serialize to a plain dict.

        ``success``, ``slot_key`` and ``value`` are always present; ``error``
        and ``ask_back_prompt`` are included only when they are truthy.
        """
        payload: dict[str, Any] = {
            "success": self.success,
            "slot_key": self.slot_key,
            "value": self.value,
        }
        if err := self.error:
            payload["error"] = {
                "error_code": err.error_code,
                "error_message": err.error_message,
            }
        if prompt := self.ask_back_prompt:
            payload["ask_back_prompt"] = prompt
        return payload
|
|
|
|
|
|
class SlotManager:
    """Single entry point for slot writes, running validation before writing.

    Validation rules are taken from the slot's ``SlotDefinition``. Definitions
    are looked up through ``SlotDefinitionService`` and cached per
    ``(tenant_id, slot_key)`` for the lifetime of this manager; call
    ``clear_cache`` to drop them. Slots with no definition are treated as
    dynamic and accepted without validation.
    """

    def __init__(
        self,
        session: AsyncSession | None = None,
        validation_service: SlotValidationService | None = None,
        slot_def_service: SlotDefinitionService | None = None,
    ):
        """Create a slot manager.

        Args:
            session: Database session. Used to build a ``SlotDefinitionService``
                on demand when ``slot_def_service`` is not provided.
            validation_service: Validation service. A default
                ``SlotValidationService()`` is created when omitted.
            slot_def_service: Slot definition service. Takes precedence over
                ``session`` for definition lookups.
        """
        self._session = session
        self._validation_service = validation_service or SlotValidationService()
        self._slot_def_service = slot_def_service
        # Keyed by (tenant_id, slot_key). A tuple cannot collide across
        # tenants the way a joined "tenant:key" string can when either part
        # contains ":". Lookups that found nothing are cached as None too.
        self._slot_def_cache: dict[
            tuple[str, str], SlotDefinition | dict[str, Any] | None
        ] = {}

    async def write_slot(
        self,
        tenant_id: str,
        slot_key: str,
        value: Any,
        source: SlotSource = SlotSource.USER_CONFIRMED,
        confidence: float = 1.0,
        skip_validation: bool = False,
    ) -> SlotWriteResult:
        """Write a single slot value, validating it first.

        Flow:
            1. Load the slot definition.
            2. Run validation (unless ``skip_validation`` is set).
            3. Return the outcome.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.
            value: Candidate slot value.
            source: Origin of the value. NOTE(review): currently neither used
                here nor carried on the result — confirm intent.
            confidence: Confidence of the value. NOTE(review): currently
                unused, like ``source``.
            skip_validation: Bypass validation (for special cases).

        Returns:
            SlotWriteResult: ``success=True`` with the value (the normalized
            value when validation ran), or ``success=False`` with the
            validation error and its ask-back prompt.
        """
        slot_def = await self._get_slot_definition(tenant_id, slot_key)

        if slot_def is None:
            # Undefined (dynamic) slot: always accepted. Only the log level
            # depends on whether validation was requested.
            if skip_validation:
                logger.debug(
                    "[SlotManager] Writing slot without definition: "
                    "tenant_id=%s, slot_key=%s",
                    tenant_id,
                    slot_key,
                )
            else:
                logger.info(
                    "[SlotManager] Slot definition not found, allowing write: "
                    "tenant_id=%s, slot_key=%s",
                    tenant_id,
                    slot_key,
                )
            return SlotWriteResult(success=True, slot_key=slot_key, value=value)

        if not skip_validation:
            validation_result = self._validation_service.validate_slot_value(
                slot_def, value, tenant_id
            )

            if not validation_result.ok:
                logger.info(
                    "[SlotManager] Slot validation failed: "
                    "tenant_id=%s, slot_key=%s, error_code=%s",
                    tenant_id,
                    slot_key,
                    validation_result.error_code,
                )
                return SlotWriteResult(
                    success=False,
                    slot_key=slot_key,
                    error=self._build_validation_error(slot_key, validation_result),
                    ask_back_prompt=validation_result.ask_back_prompt,
                )

            # Write the normalized value produced by validation, not the raw input.
            value = validation_result.normalized_value

            logger.debug(
                "[SlotManager] Slot validation passed: "
                "tenant_id=%s, slot_key=%s",
                tenant_id,
                slot_key,
            )

        return SlotWriteResult(success=True, slot_key=slot_key, value=value)

    async def write_slots(
        self,
        tenant_id: str,
        values: dict[str, Any],
        source: SlotSource = SlotSource.USER_CONFIRMED,
        confidence: float = 1.0,
        skip_validation: bool = False,
    ) -> BatchValidationResult:
        """Write several slot values at once, validating them first.

        Args:
            tenant_id: Tenant ID.
            values: Mapping of ``{slot_key: value}``.
            source: Origin of the values. NOTE(review): currently unused.
            confidence: Confidence of the values. NOTE(review): currently unused.
            skip_validation: Bypass validation and accept ``values`` as-is.

        Returns:
            BatchValidationResult: The batch validation result.
        """
        if skip_validation:
            return BatchValidationResult(
                ok=True,
                validated_values=values,
            )

        # Only slots that have a definition are passed to the validator.
        slot_defs = await self._get_slot_definitions(tenant_id, list(values.keys()))

        result = self._validation_service.validate_slots(
            slot_defs, values, tenant_id
        )

        if not result.ok:
            logger.info(
                "[SlotManager] Batch slot validation failed: "
                "tenant_id=%s, errors=%s",
                tenant_id,
                [e.slot_key for e in result.errors],
            )
        else:
            logger.debug(
                "[SlotManager] Batch slot validation passed: "
                "tenant_id=%s, slots=%s",
                tenant_id,
                list(values.keys()),
            )

        return result

    async def validate_before_write(
        self,
        tenant_id: str,
        slot_key: str,
        value: Any,
    ) -> tuple[bool, SlotValidationError | None]:
        """Pre-validate a slot value without producing a write result.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.
            value: Candidate slot value.

        Returns:
            ``(True, None)`` when the value passes, or the slot has no
            definition; otherwise ``(False, error)``.
        """
        slot_def = await self._get_slot_definition(tenant_id, slot_key)

        if slot_def is None:
            # Undefined slot: treated as passing.
            return True, None

        result = self._validation_service.validate_slot_value(
            slot_def, value, tenant_id
        )

        if result.ok:
            return True, None

        return False, self._build_validation_error(slot_key, result)

    async def get_ask_back_prompt(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> str | None:
        """Return the slot's ask-back (follow-up) prompt.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.

        Returns:
            The prompt, or None when the slot has no definition or no prompt.
        """
        slot_def = await self._get_slot_definition(tenant_id, slot_key)
        if slot_def is None:
            return None

        # Definitions may be ORM entities or plain dicts.
        if isinstance(slot_def, SlotDefinition):
            return slot_def.ask_back_prompt
        return slot_def.get("ask_back_prompt")

    @staticmethod
    def _build_validation_error(slot_key: str, result: Any) -> SlotValidationError:
        """Convert a failed single-slot validation result into a SlotValidationError.

        Missing error codes and messages fall back to generic defaults.
        """
        return SlotValidationError(
            slot_key=slot_key,
            error_code=result.error_code or "VALIDATION_FAILED",
            error_message=result.error_message or "校验失败",
            ask_back_prompt=result.ask_back_prompt,
        )

    async def _get_slot_definition(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> SlotDefinition | dict[str, Any] | None:
        """Fetch a slot definition, memoized per ``(tenant_id, slot_key)``.

        Uses ``slot_def_service`` if set, otherwise a ``SlotDefinitionService``
        built from ``session``. With neither available, the result is None.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key.

        Returns:
            The slot definition, or None.
        """
        cache_key = (tenant_id, slot_key)

        if cache_key in self._slot_def_cache:
            return self._slot_def_cache[cache_key]

        slot_def = None
        if self._slot_def_service:
            slot_def = await self._slot_def_service.get_slot_definition_by_key(
                tenant_id, slot_key
            )
        elif self._session:
            service = SlotDefinitionService(self._session)
            slot_def = await service.get_slot_definition_by_key(tenant_id, slot_key)

        # Misses are cached as well, so repeated lookups avoid the database.
        self._slot_def_cache[cache_key] = slot_def
        return slot_def

    async def _get_slot_definitions(
        self,
        tenant_id: str,
        slot_keys: list[str],
    ) -> list[SlotDefinition | dict[str, Any]]:
        """Fetch definitions for several slot keys, skipping undefined ones.

        Args:
            tenant_id: Tenant ID.
            slot_keys: Slot keys to look up.

        Returns:
            Definitions found, in the order of ``slot_keys``.
        """
        slot_defs = []
        for key in slot_keys:
            slot_def = await self._get_slot_definition(tenant_id, key)
            # Explicit None check (consistent with the rest of the class) so a
            # definition that happens to be falsy, e.g. an empty dict, is kept.
            if slot_def is not None:
                slot_defs.append(slot_def)
        return slot_defs

    def clear_cache(self) -> None:
        """Drop all cached slot definitions."""
        self._slot_def_cache.clear()
|
|
|
|
|
|
def create_slot_manager(
    session: AsyncSession | None = None,
    validation_service: SlotValidationService | None = None,
    slot_def_service: SlotDefinitionService | None = None,
) -> SlotManager:
    """Build a SlotManager from optional collaborators.

    Args:
        session: Database session forwarded to the manager.
        validation_service: Validation service forwarded to the manager.
        slot_def_service: Slot definition service forwarded to the manager.

    Returns:
        SlotManager: A newly constructed manager.
    """
    manager = SlotManager(
        session=session,
        slot_def_service=slot_def_service,
        validation_service=validation_service,
    )
    return manager
|