"""
Slot Manager Service.

Slot management service: the single entry point for writing slot values,
with validation integrated.

Responsibilities:
1. Run validation before a slot value is written.
2. Carry the source and confidence of a slot value (NOTE(review): both are
   accepted by the write APIs below but are not yet used by them).
3. Provide a unified interface for writing slots.
4. Return an ask-back (follow-up question) prompt when validation fails.
"""

import logging
from typing import Any

from sqlalchemy.ext.asyncio import AsyncSession

from app.models.entities import SlotDefinition
from app.models.mid.schemas import SlotSource
from app.services.mid.slot_validation_service import (
    BatchValidationResult,
    SlotValidationError,
    SlotValidationService,
)
from app.services.slot_definition_service import SlotDefinitionService

logger = logging.getLogger(__name__)


class SlotWriteResult:
    """
    Result of a single slot write.

    Attributes:
        success: Whether the value was accepted (validation passed, was
            skipped, or the slot has no definition).
        slot_key: Slot key name.
        value: The final value to write (the normalized value when validation
            ran); None on failure.
        error: Validation error details (set only on failure).
        ask_back_prompt: Follow-up prompt for the user (set only on failure).
    """

    def __init__(
        self,
        success: bool,
        slot_key: str,
        value: Any | None = None,
        error: SlotValidationError | None = None,
        ask_back_prompt: str | None = None,
    ):
        self.success = success
        self.slot_key = slot_key
        self.value = value
        self.error = error
        self.ask_back_prompt = ask_back_prompt

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(success={self.success!r}, "
            f"slot_key={self.slot_key!r}, value={self.value!r}, "
            f"error={self.error!r}, ask_back_prompt={self.ask_back_prompt!r})"
        )

    def to_dict(self) -> dict[str, Any]:
        """
        Convert to a plain dict.

        ``error`` is included only when set, and ``ask_back_prompt`` only when
        it is non-empty.
        """
        result: dict[str, Any] = {
            "success": self.success,
            "slot_key": self.slot_key,
            "value": self.value,
        }
        if self.error is not None:
            result["error"] = {
                "error_code": self.error.error_code,
                "error_message": self.error.error_message,
            }
        if self.ask_back_prompt:
            result["ask_back_prompt"] = self.ask_back_prompt
        return result


class SlotManager:
    """
    Slot manager.

    Single entry point for writing slots; validates values before they are
    written. Validation rules come from the slot's ``SlotDefinition``, which
    is looked up per tenant and cached on this instance.
    """

    def __init__(
        self,
        session: AsyncSession | None = None,
        validation_service: SlotValidationService | None = None,
        slot_def_service: SlotDefinitionService | None = None,
    ):
        """
        Initialize the slot manager.

        Args:
            session: Database session. Used to build a SlotDefinitionService
                for lookups when ``slot_def_service`` is not supplied.
            validation_service: Validation service instance. A default
                ``SlotValidationService()`` is created when omitted.
            slot_def_service: Slot definition service instance. Takes
                precedence over ``session``.
        """
        self._session = session
        self._validation_service = validation_service or SlotValidationService()
        self._slot_def_service = slot_def_service
        # Keyed by (tenant_id, slot_key). A tuple key cannot collide the way a
        # "tenant:key" string could when either part contains ":".
        # Misses are cached as None as well, so a definition created after the
        # first lookup is only picked up after clear_cache().
        self._slot_def_cache: dict[
            tuple[str, str], SlotDefinition | dict[str, Any] | None
        ] = {}

    async def write_slot(
        self,
        tenant_id: str,
        slot_key: str,
        value: Any,
        source: SlotSource = SlotSource.USER_CONFIRMED,
        confidence: float = 1.0,
        skip_validation: bool = False,
    ) -> SlotWriteResult:
        """
        Write a single slot value, validating it first.

        Flow:
        1. If ``skip_validation`` is set, accept the raw value without looking
           up the definition.
        2. Load the slot definition. Slots with no definition (dynamic slots)
           are accepted as-is.
        3. Validate the value. Return the normalized value on success, or the
           validation error and its ask-back prompt on failure.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key name.
            value: Slot value.
            source: Value source (not currently used by this method).
            confidence: Confidence (not currently used by this method).
            skip_validation: Skip validation (for special cases).

        Returns:
            SlotWriteResult: The write result.
        """
        if skip_validation:
            # The raw value is accepted whether or not a definition exists,
            # so the (possibly DB-backed) lookup would be wasted work.
            logger.debug(
                "[SlotManager] Writing slot with validation skipped: "
                "tenant_id=%s, slot_key=%s",
                tenant_id,
                slot_key,
            )
            return SlotWriteResult(success=True, slot_key=slot_key, value=value)

        slot_def = await self._get_slot_definition(tenant_id, slot_key)
        if slot_def is None:
            # Undefined (dynamic) slot: allow the write but record it.
            logger.info(
                "[SlotManager] Slot definition not found, allowing write: "
                "tenant_id=%s, slot_key=%s",
                tenant_id,
                slot_key,
            )
            return SlotWriteResult(success=True, slot_key=slot_key, value=value)

        validation_result = self._validation_service.validate_slot_value(
            slot_def, value, tenant_id
        )
        if not validation_result.ok:
            logger.info(
                "[SlotManager] Slot validation failed: "
                "tenant_id=%s, slot_key=%s, error_code=%s",
                tenant_id,
                slot_key,
                validation_result.error_code,
            )
            return SlotWriteResult(
                success=False,
                slot_key=slot_key,
                error=self._to_validation_error(slot_key, validation_result),
                ask_back_prompt=validation_result.ask_back_prompt,
            )

        logger.debug(
            "[SlotManager] Slot validation passed: tenant_id=%s, slot_key=%s",
            tenant_id,
            slot_key,
        )
        # Write the normalized value rather than the raw input.
        # NOTE(review): assumes the validator always populates
        # normalized_value on success; confirm in SlotValidationService.
        return SlotWriteResult(
            success=True,
            slot_key=slot_key,
            value=validation_result.normalized_value,
        )

    async def write_slots(
        self,
        tenant_id: str,
        values: dict[str, Any],
        source: SlotSource = SlotSource.USER_CONFIRMED,
        confidence: float = 1.0,
        skip_validation: bool = False,
    ) -> BatchValidationResult:
        """
        Write several slot values at once, validating them first.

        Args:
            tenant_id: Tenant ID.
            values: Mapping of {slot_key: value}.
            source: Value source (not currently used by this method).
            confidence: Confidence (not currently used by this method).
            skip_validation: Skip validation.

        Returns:
            BatchValidationResult: The batch validation result.
        """
        if skip_validation:
            # Copy so the result does not alias the caller's dict.
            return BatchValidationResult(ok=True, validated_values=dict(values))

        # Only keys that have a definition contribute an entry to slot_defs.
        slot_defs = await self._get_slot_definitions(tenant_id, list(values))

        result = self._validation_service.validate_slots(
            slot_defs, values, tenant_id
        )

        if not result.ok:
            logger.info(
                "[SlotManager] Batch slot validation failed: "
                "tenant_id=%s, errors=%s",
                tenant_id,
                [e.slot_key for e in result.errors],
            )
        else:
            logger.debug(
                "[SlotManager] Batch slot validation passed: "
                "tenant_id=%s, slots=%s",
                tenant_id,
                list(values),
            )

        return result

    async def validate_before_write(
        self,
        tenant_id: str,
        slot_key: str,
        value: Any,
    ) -> tuple[bool, SlotValidationError | None]:
        """
        Pre-validate a slot value without writing it.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key name.
            value: Slot value.

        Returns:
            Tuple of (passed, error). A slot with no definition counts as
            passed.
        """
        slot_def = await self._get_slot_definition(tenant_id, slot_key)
        if slot_def is None:
            return True, None

        result = self._validation_service.validate_slot_value(
            slot_def, value, tenant_id
        )
        if result.ok:
            return True, None
        return False, self._to_validation_error(slot_key, result)

    async def get_ask_back_prompt(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> str | None:
        """
        Get the ask-back (follow-up) prompt configured on a slot definition.

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key name.

        Returns:
            The prompt, or None if the slot has no definition or no prompt.
        """
        slot_def = await self._get_slot_definition(tenant_id, slot_key)
        if slot_def is None:
            return None
        if isinstance(slot_def, SlotDefinition):
            return slot_def.ask_back_prompt
        # Otherwise the definition was returned as a plain dict.
        return slot_def.get("ask_back_prompt")

    @staticmethod
    def _to_validation_error(slot_key: str, result: Any) -> SlotValidationError:
        """Build a SlotValidationError from a failed single-slot validation result."""
        return SlotValidationError(
            slot_key=slot_key,
            error_code=result.error_code or "VALIDATION_FAILED",
            error_message=result.error_message or "校验失败",
            ask_back_prompt=result.ask_back_prompt,
        )

    async def _get_slot_definition(
        self,
        tenant_id: str,
        slot_key: str,
    ) -> SlotDefinition | dict[str, Any] | None:
        """
        Look up a slot definition, caching the result (including misses).

        Uses the injected SlotDefinitionService if present. Otherwise it
        builds one from the session. With neither available, the result is
        None, which callers treat as "no definition".

        Args:
            tenant_id: Tenant ID.
            slot_key: Slot key name.

        Returns:
            The slot definition, or None.
        """
        cache_key = (tenant_id, slot_key)
        if cache_key in self._slot_def_cache:
            return self._slot_def_cache[cache_key]

        slot_def = None
        if self._slot_def_service is not None:
            slot_def = await self._slot_def_service.get_slot_definition_by_key(
                tenant_id, slot_key
            )
        elif self._session is not None:
            service = SlotDefinitionService(self._session)
            slot_def = await service.get_slot_definition_by_key(tenant_id, slot_key)

        self._slot_def_cache[cache_key] = slot_def
        return slot_def

    async def _get_slot_definitions(
        self,
        tenant_id: str,
        slot_keys: list[str],
    ) -> list[SlotDefinition | dict[str, Any]]:
        """
        Look up the definitions for several slot keys.

        Args:
            tenant_id: Tenant ID.
            slot_keys: Slot key names.

        Returns:
            The definitions found, in key order. Keys whose lookup yields a
            falsy result are omitted.
        """
        slot_defs: list[SlotDefinition | dict[str, Any]] = []
        for key in slot_keys:
            slot_def = await self._get_slot_definition(tenant_id, key)
            if slot_def:
                slot_defs.append(slot_def)
        return slot_defs

    def clear_cache(self) -> None:
        """Clear the slot definition cache, including cached misses."""
        self._slot_def_cache.clear()


def create_slot_manager(
    session: AsyncSession | None = None,
    validation_service: SlotValidationService | None = None,
    slot_def_service: SlotDefinitionService | None = None,
) -> SlotManager:
    """
    Create a SlotManager instance.

    Args:
        session: Database session.
        validation_service: Validation service instance.
        slot_def_service: Slot definition service instance.

    Returns:
        SlotManager: The slot manager instance.
    """
    return SlotManager(
        session=session,
        validation_service=validation_service,
        slot_def_service=slot_def_service,
    )