from fastapi import Response  # new dependency: bodiless reply for the 204 delete route

logger = logging.getLogger(__name__)

router = APIRouter(prefix="/admin/scene-slot-bundles", tags=["SceneSlotBundle"])


def get_current_tenant_id() -> str:
    """Return the tenant ID of the current request context.

    Raises:
        MissingTenantIdException: If ``get_tenant_id()`` yields a falsy value.
    """
    tenant_id = get_tenant_id()
    if not tenant_id:
        raise MissingTenantIdException()
    return tenant_id


def _bundle_to_dict(bundle: Any) -> dict[str, Any]:
    """Serialise a bundle object into a JSON-compatible dict.

    ``id`` and ``tenant_id`` are stringified; the two timestamps are rendered
    with ``isoformat()`` or left as ``None`` when unset.
    """
    return {
        "id": str(bundle.id),
        "tenant_id": str(bundle.tenant_id),
        "scene_key": bundle.scene_key,
        "scene_name": bundle.scene_name,
        "description": bundle.description,
        "required_slots": bundle.required_slots,
        "optional_slots": bundle.optional_slots,
        "slot_priority": bundle.slot_priority,
        "completion_threshold": bundle.completion_threshold,
        "ask_back_order": bundle.ask_back_order,
        "status": bundle.status,
        "version": bundle.version,
        "created_at": bundle.created_at.isoformat() if bundle.created_at else None,
        "updated_at": bundle.updated_at.isoformat() if bundle.updated_at else None,
    }


def _error_response(status_code: int, error_code: str, message: str) -> JSONResponse:
    """Build the ``{"error_code", "message"}`` error payload shared by all routes."""
    return JSONResponse(
        status_code=status_code,
        content={
            "error_code": error_code,
            "message": message,
        },
    )


@router.get(
    "",
    operation_id="listSceneSlotBundles",
    summary="List scene slot bundles",
    description="获取场景槽位包列表",
)
async def list_scene_slot_bundles(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    status: Annotated[str | None, Query(
        description="按状态过滤: draft/active/deprecated"
    )] = None,
) -> JSONResponse:
    """List the tenant's scene slot bundles, optionally filtered by status."""
    logger.info(
        f"Listing scene slot bundles: tenant={tenant_id}, status={status}"
    )

    service = SceneSlotBundleService(session)
    bundles = await service.list_bundles(tenant_id, status)

    return JSONResponse(
        content=[_bundle_to_dict(b) for b in bundles]
    )


@router.post(
    "",
    operation_id="createSceneSlotBundle",
    summary="Create scene slot bundle",
    description="[AC-SCENE-SLOT-01] 创建新的场景槽位包",
    status_code=201,
)
async def create_scene_slot_bundle(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    bundle_create: SceneSlotBundleCreate,
) -> JSONResponse:
    """[AC-SCENE-SLOT-01] Create a scene slot bundle and commit it.

    Returns 201 with the serialised bundle, or 400 ``VALIDATION_ERROR`` when
    creating or committing raises ``ValueError``.
    """
    logger.info(
        f"[AC-SCENE-SLOT-01] Creating scene slot bundle: "
        f"tenant={tenant_id}, scene_key={bundle_create.scene_key}"
    )

    service = SceneSlotBundleService(session)

    try:
        bundle = await service.create_bundle(tenant_id, bundle_create)
        await session.commit()
    except ValueError as e:
        return _error_response(400, "VALIDATION_ERROR", str(e))

    return JSONResponse(
        status_code=201,
        content=_bundle_to_dict(bundle)
    )


@router.get(
    "/by-scene/{scene_key}",
    operation_id="getSceneSlotBundleBySceneKey",
    summary="Get scene slot bundle by scene key",
    description="根据场景标识获取槽位包",
)
async def get_scene_slot_bundle_by_scene_key(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    scene_key: str,
) -> JSONResponse:
    """Fetch a bundle by its scene key; 404 ``NOT_FOUND`` when there is none."""
    logger.info(
        f"Getting scene slot bundle by scene_key: tenant={tenant_id}, scene_key={scene_key}"
    )

    service = SceneSlotBundleService(session)
    bundle = await service.get_bundle_by_scene_key(tenant_id, scene_key)

    if not bundle:
        return _error_response(
            404,
            "NOT_FOUND",
            f"Scene slot bundle with scene_key '{scene_key}' not found",
        )

    return JSONResponse(content=_bundle_to_dict(bundle))


@router.get(
    "/{id}",
    operation_id="getSceneSlotBundle",
    summary="Get scene slot bundle by ID",
    description="获取单个场景槽位包详情(含槽位详情)",
)
async def get_scene_slot_bundle(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    id: str,
) -> JSONResponse:
    """Fetch one bundle by ID; 404 ``NOT_FOUND`` when there is none.

    The service result is passed to ``JSONResponse`` unchanged.
    """
    logger.info(
        f"Getting scene slot bundle: tenant={tenant_id}, id={id}"
    )

    service = SceneSlotBundleService(session)
    bundle = await service.get_bundle_with_slot_details(tenant_id, id)

    if not bundle:
        return _error_response(404, "NOT_FOUND", f"Scene slot bundle {id} not found")

    return JSONResponse(content=bundle)


@router.put(
    "/{id}",
    operation_id="updateSceneSlotBundle",
    summary="Update scene slot bundle",
    description="更新场景槽位包",
)
async def update_scene_slot_bundle(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    id: str,
    bundle_update: SceneSlotBundleUpdate,
) -> JSONResponse:
    """Update a bundle and commit.

    Returns 400 ``VALIDATION_ERROR`` when the service raises ``ValueError``
    and 404 ``NOT_FOUND`` when it returns a falsy result; the commit only
    happens after both checks pass.
    """
    logger.info(
        f"Updating scene slot bundle: tenant={tenant_id}, id={id}"
    )

    service = SceneSlotBundleService(session)

    try:
        bundle = await service.update_bundle(tenant_id, id, bundle_update)
    except ValueError as e:
        return _error_response(400, "VALIDATION_ERROR", str(e))

    if not bundle:
        return _error_response(404, "NOT_FOUND", f"Scene slot bundle {id} not found")

    await session.commit()

    return JSONResponse(content=_bundle_to_dict(bundle))


@router.delete(
    "/{id}",
    operation_id="deleteSceneSlotBundle",
    summary="Delete scene slot bundle",
    description="[AC-SCENE-SLOT-01] 删除场景槽位包",
    status_code=204,
)
async def delete_scene_slot_bundle(
    tenant_id: Annotated[str, Depends(get_current_tenant_id)],
    session: Annotated[AsyncSession, Depends(get_session)],
    id: str,
) -> Response:
    """[AC-SCENE-SLOT-01] Delete a bundle and commit.

    Returns 404 ``NOT_FOUND`` when the service reports nothing was deleted,
    otherwise an empty 204 response.
    """
    logger.info(
        f"[AC-SCENE-SLOT-01] Deleting scene slot bundle: tenant={tenant_id}, id={id}"
    )

    service = SceneSlotBundleService(session)
    success = await service.delete_bundle(tenant_id, id)

    if not success:
        return _error_response(404, "NOT_FOUND", f"Scene slot bundle not found: {id}")

    await session.commit()

    # A 204 reply must not carry a body. JSONResponse(content=None) would
    # serialise the JSON literal `null`, so a plain Response is used instead.
    return Response(status_code=204)
+[AC-SCENE-SLOT-03] 场景槽位包缓存服务 + +职责: +1. 缓存场景槽位包配置,减少数据库查询 +2. 支持缓存失效和刷新 +3. 支持租户隔离 +""" + +import json +import logging +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Any + +logger = logging.getLogger(__name__) + +try: + import redis.asyncio as redis + REDIS_AVAILABLE = True +except ImportError: + REDIS_AVAILABLE = False + + +@dataclass +class CachedSceneSlotBundle: + """缓存的场景槽位包""" + scene_key: str + scene_name: str + description: str | None + required_slots: list[str] + optional_slots: list[str] + slot_priority: list[str] | None + completion_threshold: float + ask_back_order: str + status: str + version: int + cached_at: datetime = field(default_factory=datetime.utcnow) + + def to_dict(self) -> dict[str, Any]: + return { + "scene_key": self.scene_key, + "scene_name": self.scene_name, + "description": self.description, + "required_slots": self.required_slots, + "optional_slots": self.optional_slots, + "slot_priority": self.slot_priority, + "completion_threshold": self.completion_threshold, + "ask_back_order": self.ask_back_order, + "status": self.status, + "version": self.version, + "cached_at": self.cached_at.isoformat(), + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> "CachedSceneSlotBundle": + return cls( + scene_key=data["scene_key"], + scene_name=data["scene_name"], + description=data.get("description"), + required_slots=data.get("required_slots", []), + optional_slots=data.get("optional_slots", []), + slot_priority=data.get("slot_priority"), + completion_threshold=data.get("completion_threshold", 1.0), + ask_back_order=data.get("ask_back_order", "priority"), + status=data.get("status", "draft"), + version=data.get("version", 1), + cached_at=datetime.fromisoformat(data["cached_at"]) if data.get("cached_at") else datetime.utcnow(), + ) + + +class SceneSlotBundleCache: + """ + [AC-SCENE-SLOT-03] 场景槽位包缓存 + + 使用 Redis 或内存缓存场景槽位包配置 + """ + + CACHE_PREFIX = "scene_slot_bundle" + 
CACHE_TTL_SECONDS = 300 # 5分钟缓存 + + def __init__( + self, + redis_client: Any = None, + ttl_seconds: int = 300, + ): + self._redis = redis_client + self._ttl = ttl_seconds or self.CACHE_TTL_SECONDS + self._memory_cache: dict[str, tuple[CachedSceneSlotBundle, datetime]] = {} + + def _get_cache_key(self, tenant_id: str, scene_key: str) -> str: + """生成缓存键""" + return f"{self.CACHE_PREFIX}:{tenant_id}:{scene_key}" + + async def get( + self, + tenant_id: str, + scene_key: str, + ) -> CachedSceneSlotBundle | None: + """ + 获取缓存的场景槽位包 + + Args: + tenant_id: 租户 ID + scene_key: 场景标识 + + Returns: + 缓存的场景槽位包或 None + """ + cache_key = self._get_cache_key(tenant_id, scene_key) + + if self._redis and REDIS_AVAILABLE: + try: + cached_data = await self._redis.get(cache_key) + if cached_data: + data = json.loads(cached_data) + return CachedSceneSlotBundle.from_dict(data) + except Exception as e: + logger.warning( + f"[AC-SCENE-SLOT-03] Redis get failed: {e}, falling back to memory cache" + ) + + if cache_key in self._memory_cache: + cached_bundle, cached_at = self._memory_cache[cache_key] + if datetime.utcnow() - cached_at < timedelta(seconds=self._ttl): + return cached_bundle + else: + del self._memory_cache[cache_key] + + return None + + async def set( + self, + tenant_id: str, + scene_key: str, + bundle: CachedSceneSlotBundle, + ) -> bool: + """ + 设置缓存 + + Args: + tenant_id: 租户 ID + scene_key: 场景标识 + bundle: 要缓存的场景槽位包 + + Returns: + 是否成功 + """ + cache_key = self._get_cache_key(tenant_id, scene_key) + + if self._redis and REDIS_AVAILABLE: + try: + await self._redis.setex( + cache_key, + self._ttl, + json.dumps(bundle.to_dict()), + ) + return True + except Exception as e: + logger.warning( + f"[AC-SCENE-SLOT-03] Redis set failed: {e}, falling back to memory cache" + ) + + self._memory_cache[cache_key] = (bundle, datetime.utcnow()) + return True + + async def delete( + self, + tenant_id: str, + scene_key: str, + ) -> bool: + """ + 删除缓存 + + Args: + tenant_id: 租户 ID + scene_key: 场景标识 
+ + Returns: + 是否成功 + """ + cache_key = self._get_cache_key(tenant_id, scene_key) + + if self._redis and REDIS_AVAILABLE: + try: + await self._redis.delete(cache_key) + except Exception as e: + logger.warning( + f"[AC-SCENE-SLOT-03] Redis delete failed: {e}" + ) + + if cache_key in self._memory_cache: + del self._memory_cache[cache_key] + + return True + + async def delete_by_tenant(self, tenant_id: str) -> bool: + """ + 删除租户下所有缓存 + + Args: + tenant_id: 租户 ID + + Returns: + 是否成功 + """ + pattern = f"{self.CACHE_PREFIX}:{tenant_id}:*" + + if self._redis and REDIS_AVAILABLE: + try: + keys = [] + async for key in self._redis.scan_iter(match=pattern): + keys.append(key) + if keys: + await self._redis.delete(*keys) + except Exception as e: + logger.warning( + f"[AC-SCENE-SLOT-03] Redis delete by tenant failed: {e}" + ) + + keys_to_delete = [ + k for k in self._memory_cache + if k.startswith(f"{self.CACHE_PREFIX}:{tenant_id}:") + ] + for key in keys_to_delete: + del self._memory_cache[key] + + return True + + async def invalidate_on_update( + self, + tenant_id: str, + scene_key: str, + ) -> bool: + """ + 当场景槽位包更新时使缓存失效 + + Args: + tenant_id: 租户 ID + scene_key: 场景标识 + + Returns: + 是否成功 + """ + logger.info( + f"[AC-SCENE-SLOT-03] Invalidating cache for scene slot bundle: " + f"tenant={tenant_id}, scene={scene_key}" + ) + return await self.delete(tenant_id, scene_key) + + +_scene_slot_bundle_cache: SceneSlotBundleCache | None = None + + +def get_scene_slot_bundle_cache() -> SceneSlotBundleCache: + """获取场景槽位包缓存实例""" + global _scene_slot_bundle_cache + if _scene_slot_bundle_cache is None: + _scene_slot_bundle_cache = SceneSlotBundleCache() + return _scene_slot_bundle_cache + + +def init_scene_slot_bundle_cache(redis_client: Any = None, ttl_seconds: int = 300) -> None: + """初始化场景槽位包缓存""" + global _scene_slot_bundle_cache + _scene_slot_bundle_cache = SceneSlotBundleCache( + redis_client=redis_client, + ttl_seconds=ttl_seconds, + ) + logger.info( + f"[AC-SCENE-SLOT-03] Scene slot 
@dataclass
class CachedSlotValue:
    """A single cached slot value together with its provenance.

    Attributes:
        value: The slot value.
        source: Where the value came from (user_confirmed, rule_extracted,
            llm_inferred, default, context).
        confidence: Confidence score.
        updated_at: Update time as a ``time.time()`` timestamp.
    """

    value: Any
    source: str
    confidence: float = 1.0
    updated_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict[str, Any]:
        """Return the four attributes as a plain dict."""
        return dict(
            value=self.value,
            source=self.source,
            confidence=self.confidence,
            updated_at=self.updated_at,
        )

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "CachedSlotValue":
        """Rebuild from a dict; ``value`` and ``source`` are mandatory keys."""
        confidence = data.get("confidence", 1.0)
        updated_at = data.get("updated_at", time.time())
        return cls(
            value=data["value"],
            source=data["source"],
            confidence=confidence,
            updated_at=updated_at,
        )


@dataclass
class CachedSlotState:
    """Cached slot state of one session.

    Attributes:
        filled_slots: Filled slots, ``{slot_key: CachedSlotValue}``.
        slot_to_field_map: Mapping from slot key to metadata field.
        created_at: Creation timestamp.
        updated_at: Last-update timestamp.
    """

    filled_slots: dict[str, CachedSlotValue] = field(default_factory=dict)
    slot_to_field_map: dict[str, str] = field(default_factory=dict)
    created_at: float = field(default_factory=time.time)
    updated_at: float = field(default_factory=time.time)

    def to_dict(self) -> dict[str, Any]:
        """Return a JSON-friendly dict with every slot value expanded."""
        serialized_slots = {
            slot_key: slot_value.to_dict()
            for slot_key, slot_value in self.filled_slots.items()
        }
        return {
            "filled_slots": serialized_slots,
            "slot_to_field_map": self.slot_to_field_map,
            "created_at": self.created_at,
            "updated_at": self.updated_at,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "CachedSlotState":
        """Rebuild a state from ``to_dict()`` output.

        A slot entry that is not a dict is wrapped as a bare value with
        source ``"unknown"``.
        """
        restored = {
            slot_key: (
                CachedSlotValue.from_dict(raw)
                if isinstance(raw, dict)
                else CachedSlotValue(value=raw, source="unknown")
            )
            for slot_key, raw in data.get("filled_slots", {}).items()
        }

        return cls(
            filled_slots=restored,
            slot_to_field_map=data.get("slot_to_field_map", {}),
            created_at=data.get("created_at", time.time()),
            updated_at=data.get("updated_at", time.time()),
        )

    def get_simple_filled_slots(self) -> dict[str, Any]:
        """Return ``{slot_key: value}`` without provenance."""
        return {key: slot.value for key, slot in self.filled_slots.items()}

    def get_slot_sources(self) -> dict[str, str]:
        """Return ``{slot_key: source}``."""
        return {key: slot.source for key, slot in self.filled_slots.items()}

    def get_slot_confidence(self) -> dict[str, float]:
        """Return ``{slot_key: confidence}``."""
        return {key: slot.confidence for key, slot in self.filled_slots.items()}
class SlotStateCache:
    """[AC-MRS-SLOT-CACHE-01] Two-level cache for per-session slot state.

    - L1: in-process dict with a 300 second TTL. It is a class attribute, so
      every instance in the process shares it.
    - L2: Redis, written with ``setex`` using ``settings.slot_state_cache_ttl``
      (1800 seconds when the setting is absent).
    - When Redis is disabled or unreachable only L1 is used.
    - ``merge_and_set`` resolves conflicts by value-source priority.

    Redis key format: ``slot_state:{tenant_id}:{session_id}``.
    """

    # local key -> (state, time.time() when it was stored)
    _local_cache: dict[str, tuple[CachedSlotState, float]] = {}
    _local_cache_ttl = 300

    SOURCE_PRIORITY = {
        SlotSource.USER_CONFIRMED.value: 100,
        "user_confirmed": 100,
        SlotSource.RULE_EXTRACTED.value: 80,
        "rule_extracted": 80,
        SlotSource.LLM_INFERRED.value: 60,
        "llm_inferred": 60,
        "context": 40,
        SlotSource.DEFAULT.value: 20,
        "default": 20,
        "unknown": 0,
    }

    def __init__(self, redis_client: redis.Redis | None = None):
        self._redis = redis_client
        self._settings = get_settings()
        self._enabled = self._settings.redis_enabled
        self._cache_ttl = getattr(self._settings, "slot_state_cache_ttl", 1800)

    async def _get_client(self) -> redis.Redis | None:
        """Return the Redis client, creating it lazily.

        Returns ``None`` when Redis is disabled; a failure while building the
        client disables Redis for this instance.
        """
        if not self._enabled:
            return None
        if self._redis is not None:
            return self._redis

        try:
            self._redis = redis.from_url(
                self._settings.redis_url,
                encoding="utf-8",
                decode_responses=True,
            )
        except Exception as e:
            logger.warning(f"[SlotStateCache] Failed to connect to Redis: {e}")
            self._enabled = False
            return None
        return self._redis

    def _make_key(self, tenant_id: str, session_id: str) -> str:
        """Build the Redis (L2) key."""
        return f"slot_state:{tenant_id}:{session_id}"

    def _make_local_key(self, tenant_id: str, session_id: str) -> str:
        """Build the in-process (L1) key."""
        return f"{tenant_id}:{session_id}"

    def _get_source_priority(self, source: str) -> int:
        """Return the priority of a value source; unlisted sources rank 0."""
        return self.SOURCE_PRIORITY.get(source, 0)

    async def get(
        self,
        tenant_id: str,
        session_id: str,
    ) -> CachedSlotState | None:
        """Read the slot state, trying L1 first and then L2.

        An L2 hit is copied into L1. Redis errors are logged and reported as
        a miss.

        Args:
            tenant_id: Tenant ID for isolation.
            session_id: Session ID.

        Returns:
            The cached state, or ``None`` if not found.
        """
        local_key = self._make_local_key(tenant_id, session_id)
        local_entry = self._local_cache.get(local_key)
        if local_entry is not None:
            state, stored_at = local_entry
            if time.time() - stored_at < self._local_cache_ttl:
                logger.debug(f"[SlotStateCache] L1 hit: {local_key}")
                return state
            del self._local_cache[local_key]

        client = await self._get_client()
        if client is None:
            return None

        key = self._make_key(tenant_id, session_id)

        try:
            data = await client.get(key)
            if not data:
                return None

            logger.debug(f"[SlotStateCache] L2 hit: {key}")
            state = CachedSlotState.from_dict(json.loads(data))
            self._local_cache[local_key] = (state, time.time())
            return state
        except Exception as e:
            logger.warning(f"[SlotStateCache] Failed to get from cache: {e}")
            return None

    async def set(
        self,
        tenant_id: str,
        session_id: str,
        state: CachedSlotState,
    ) -> bool:
        """Write the slot state to L1 and, when available, to L2.

        ``state.updated_at`` is refreshed before storing.

        Args:
            tenant_id: Tenant ID for isolation.
            session_id: Session ID.
            state: State to cache.

        Returns:
            ``True`` only when the Redis write succeeded; ``False`` when Redis
            is unavailable or the write raised (L1 is written either way).
        """
        local_key = self._make_local_key(tenant_id, session_id)
        state.updated_at = time.time()
        self._local_cache[local_key] = (state, time.time())

        client = await self._get_client()
        if client is None:
            return False

        key = self._make_key(tenant_id, session_id)

        try:
            serialized = json.dumps(state.to_dict(), default=str)
            await client.setex(key, self._cache_ttl, serialized)
            logger.debug(f"[SlotStateCache] Set cache: {key}")
            return True
        except Exception as e:
            logger.warning(f"[SlotStateCache] Failed to set cache: {e}")
            return False

    async def merge_and_set(
        self,
        tenant_id: str,
        session_id: str,
        new_slots: dict[str, CachedSlotValue],
        slot_to_field_map: dict[str, str] | None = None,
    ) -> CachedSlotState:
        """Merge new slot values into the cached state and store the result.

        Priority: user_confirmed > rule_extracted > llm_inferred > context >
        default. A new value replaces an existing one when its priority is
        equal or higher; otherwise the existing value is kept.

        Args:
            tenant_id: Tenant ID.
            session_id: Session ID.
            new_slots: New slot values to merge.
            slot_to_field_map: Slot-to-field mapping merged into the state.

        Returns:
            The updated state.
        """
        state = await self.get(tenant_id, session_id)
        if state is None:
            state = CachedSlotState()

        for slot_key, new_value in new_slots.items():
            existing = state.filled_slots.get(slot_key)

            if existing is None:
                state.filled_slots[slot_key] = new_value
                logger.debug(
                    f"[SlotStateCache] Slot '{slot_key}' added: "
                    f"source={new_value.source}, value={new_value.value}"
                )
                continue

            existing_priority = self._get_source_priority(existing.source)
            new_priority = self._get_source_priority(new_value.source)
            if new_priority < existing_priority:
                continue

            state.filled_slots[slot_key] = new_value
            logger.debug(
                f"[SlotStateCache] Slot '{slot_key}' updated: "
                f"{existing.source}({existing_priority}) -> "
                f"{new_value.source}({new_priority})"
            )

        if slot_to_field_map:
            state.slot_to_field_map.update(slot_to_field_map)

        await self.set(tenant_id, session_id, state)

        return state

    async def delete(
        self,
        tenant_id: str,
        session_id: str,
    ) -> bool:
        """Remove the slot state from L1 and L2.

        Args:
            tenant_id: Tenant ID for isolation.
            session_id: Session ID.

        Returns:
            ``True`` only when the Redis delete succeeded; ``False`` when
            Redis is unavailable or the delete raised.
        """
        local_key = self._make_local_key(tenant_id, session_id)
        self._local_cache.pop(local_key, None)

        client = await self._get_client()
        if client is None:
            return False

        key = self._make_key(tenant_id, session_id)

        try:
            await client.delete(key)
        except Exception as e:
            logger.warning(f"[SlotStateCache] Failed to delete cache: {e}")
            return False

        logger.debug(f"[SlotStateCache] Deleted cache: {key}")
        return True

    async def clear_slot(
        self,
        tenant_id: str,
        session_id: str,
        slot_key: str,
    ) -> bool:
        """Remove one slot from the cached state, if present.

        Args:
            tenant_id: Tenant ID.
            session_id: Session ID.
            slot_key: Slot key to clear.

        Returns:
            Always ``True``.
        """
        state = await self.get(tenant_id, session_id)
        if state is not None and slot_key in state.filled_slots:
            del state.filled_slots[slot_key]
            await self.set(tenant_id, session_id, state)
            logger.debug(f"[SlotStateCache] Cleared slot: {slot_key}")

        return True

    async def close(self) -> None:
        """Close the Redis connection, if one was created."""
        if self._redis:
            await self._redis.close()


_slot_state_cache: SlotStateCache | None = None


def get_slot_state_cache() -> SlotStateCache:
    """Return the process-wide SlotStateCache, creating it on first use."""
    global _slot_state_cache
    if _slot_state_cache is None:
        _slot_state_cache = SlotStateCache()
    return _slot_state_cache
-0,0 +1,202 @@ +""" +元数据字段定义缓存服务 +使用 Redis 缓存 metadata_field_definitions,减少数据库查询 +""" + +import json +import logging +from typing import Any + +from app.core.config import get_settings + +logger = logging.getLogger(__name__) + + +class MetadataCacheService: + """ + 元数据字段定义缓存服务 + + 缓存策略: + - Key: metadata:fields:{tenant_id} + - Value: JSON 序列化的字段定义列表 + - TTL: 1小时(3600秒) + - 更新策略:写时更新 + 定时刷新 + """ + + CACHE_KEY_PREFIX = "metadata:fields" + DEFAULT_TTL = 3600 # 1小时 + + def __init__(self): + self._settings = get_settings() + self._redis_client = None + self._enabled = self._settings.redis_enabled + + async def _get_redis(self): + """获取 Redis 连接(延迟初始化)""" + if not self._enabled: + return None + + if self._redis_client is None: + try: + import redis.asyncio as redis + self._redis_client = redis.from_url( + self._settings.redis_url, + decode_responses=True + ) + except Exception as e: + logger.error(f"[MetadataCache] Failed to connect Redis: {e}") + self._enabled = False + return None + + return self._redis_client + + def _make_key(self, tenant_id: str) -> str: + """生成缓存 key""" + return f"{self.CACHE_KEY_PREFIX}:{tenant_id}" + + async def get_fields(self, tenant_id: str) -> list[dict[str, Any]] | None: + """ + 获取缓存的字段定义 + + Args: + tenant_id: 租户 ID + + Returns: + 字段定义列表,未缓存返回 None + """ + if not self._enabled: + return None + + try: + redis = await self._get_redis() + if not redis: + return None + + key = self._make_key(tenant_id) + cached_data = await redis.get(key) + + if cached_data: + logger.info(f"[MetadataCache] Cache hit for tenant={tenant_id}") + return json.loads(cached_data) + + logger.info(f"[MetadataCache] Cache miss for tenant={tenant_id}") + return None + + except Exception as e: + logger.error(f"[MetadataCache] Get cache error: {e}") + return None + + async def set_fields( + self, + tenant_id: str, + fields: list[dict[str, Any]], + ttl: int | None = None + ) -> bool: + """ + 缓存字段定义 + + Args: + tenant_id: 租户 ID + fields: 字段定义列表 + ttl: 过期时间(秒),默认 1小时 + + 
Returns: + 是否成功 + """ + if not self._enabled: + return False + + try: + redis = await self._get_redis() + if not redis: + return False + + key = self._make_key(tenant_id) + ttl = ttl or self.DEFAULT_TTL + + await redis.setex( + key, + ttl, + json.dumps(fields, ensure_ascii=False, default=str) + ) + + logger.info( + f"[MetadataCache] Cached {len(fields)} fields for tenant={tenant_id}, " + f"ttl={ttl}s" + ) + return True + + except Exception as e: + logger.error(f"[MetadataCache] Set cache error: {e}") + return False + + async def invalidate(self, tenant_id: str) -> bool: + """ + 使缓存失效(字段定义更新时调用) + + Args: + tenant_id: 租户 ID + + Returns: + 是否成功 + """ + if not self._enabled: + return False + + try: + redis = await self._get_redis() + if not redis: + return False + + key = self._make_key(tenant_id) + result = await redis.delete(key) + + if result: + logger.info(f"[MetadataCache] Invalidated cache for tenant={tenant_id}") + return bool(result) + + except Exception as e: + logger.error(f"[MetadataCache] Invalidate error: {e}") + return False + + async def invalidate_all(self) -> bool: + """ + 使所有元数据缓存失效 + + Returns: + 是否成功 + """ + if not self._enabled: + return False + + try: + redis = await self._get_redis() + if not redis: + return False + + # 查找所有元数据缓存 key + pattern = f"{self.CACHE_KEY_PREFIX}:*" + keys = [] + async for key in redis.scan_iter(match=pattern): + keys.append(key) + + if keys: + await redis.delete(*keys) + logger.info(f"[MetadataCache] Invalidated {len(keys)} cache entries") + return True + + except Exception as e: + logger.error(f"[MetadataCache] Invalidate all error: {e}") + return False + + +# 全局缓存服务实例 +_metadata_cache_service: MetadataCacheService | None = None + + +async def get_metadata_cache_service() -> MetadataCacheService: + """获取元数据缓存服务实例(单例)""" + global _metadata_cache_service + if _metadata_cache_service is None: + _metadata_cache_service = MetadataCacheService() + return _metadata_cache_service diff --git 
a/ai-service/app/services/mid/batch_ask_back_service.py b/ai-service/app/services/mid/batch_ask_back_service.py new file mode 100644 index 0000000..55e972b --- /dev/null +++ b/ai-service/app/services/mid/batch_ask_back_service.py @@ -0,0 +1,342 @@ +""" +Batch Ask-Back Service. +批量追问服务 - 支持一次追问多个缺失槽位 + +[AC-MRS-SLOT-ASKBACK-01] 批量追问 + +职责: +1. 支持一次追问多个缺失槽位 +2. 选择策略:必填优先、场景相关优先、最近未追问过优先 +3. 输出形式:单条自然语言合并提问或分段提问 +""" + +import logging +import time +from dataclasses import dataclass, field +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.services.cache.slot_state_cache import get_slot_state_cache +from app.services.slot_definition_service import SlotDefinitionService + +logger = logging.getLogger(__name__) + + +@dataclass +class AskBackSlot: + """ + 追问槽位信息 + + Attributes: + slot_key: 槽位键名 + label: 显示标签 + ask_back_prompt: 追问提示 + priority: 优先级(数值越大越优先) + last_asked_at: 上次追问时间戳 + is_required: 是否必填 + scene_relevance: 场景相关度 + """ + slot_key: str + label: str + ask_back_prompt: str | None = None + priority: int = 0 + last_asked_at: float | None = None + is_required: bool = False + scene_relevance: float = 0.0 + + +@dataclass +class BatchAskBackConfig: + """ + 批量追问配置 + + Attributes: + max_ask_back_slots_per_turn: 每轮最多追问槽位数 + prefer_required: 是否优先追问必填槽位 + prefer_scene_relevant: 是否优先追问场景相关槽位 + avoid_recent_asked: 是否避免最近追问过的槽位 + recent_asked_threshold_seconds: 最近追问阈值(秒) + merge_prompts: 是否合并追问提示 + merge_template: 合并模板 + """ + max_ask_back_slots_per_turn: int = 2 + prefer_required: bool = True + prefer_scene_relevant: bool = True + avoid_recent_asked: bool = True + recent_asked_threshold_seconds: float = 60.0 + merge_prompts: bool = True + merge_template: str = "为了更好地为您服务,请告诉我:{prompts}" + + +@dataclass +class BatchAskBackResult: + """ + 批量追问结果 + + Attributes: + selected_slots: 选中的追问槽位列表 + prompts: 追问提示列表 + merged_prompt: 合并后的追问提示 + ask_back_count: 追问数量 + """ + selected_slots: list[AskBackSlot] = field(default_factory=list) + prompts: 
list[str] = field(default_factory=list) + merged_prompt: str | None = None + ask_back_count: int = 0 + + def has_ask_back(self) -> bool: + return self.ask_back_count > 0 + + def get_prompt(self) -> str: + """获取最终追问提示""" + if self.merged_prompt: + return self.merged_prompt + if self.prompts: + return self.prompts[0] + return "请提供更多信息以便我更好地帮助您。" + + +class BatchAskBackService: + """ + [AC-MRS-SLOT-ASKBACK-01] 批量追问服务 + + 支持一次追问多个缺失槽位,提高补全效率。 + """ + + ASK_BACK_HISTORY_KEY_PREFIX = "slot_ask_back_history" + ASK_BACK_HISTORY_TTL = 300 + + def __init__( + self, + session: AsyncSession, + tenant_id: str, + session_id: str, + config: BatchAskBackConfig | None = None, + ): + self._session = session + self._tenant_id = tenant_id + self._session_id = session_id + self._config = config or BatchAskBackConfig() + self._slot_def_service = SlotDefinitionService(session) + self._cache = get_slot_state_cache() + + async def generate_batch_ask_back( + self, + missing_slots: list[dict[str, str]], + current_scene: str | None = None, + ) -> BatchAskBackResult: + """ + 生成批量追问 + + Args: + missing_slots: 缺失槽位列表 + current_scene: 当前场景 + + Returns: + BatchAskBackResult: 批量追问结果 + """ + if not missing_slots: + return BatchAskBackResult() + + ask_back_slots = await self._prepare_ask_back_slots(missing_slots, current_scene) + + selected_slots = self._select_slots_for_ask_back(ask_back_slots) + + asked_history = await self._get_asked_history() + selected_slots = self._filter_recently_asked(selected_slots, asked_history) + + if not selected_slots: + selected_slots = ask_back_slots[:self._config.max_ask_back_slots_per_turn] + + prompts = self._generate_prompts(selected_slots) + merged_prompt = self._merge_prompts(prompts) if self._config.merge_prompts else None + + await self._record_ask_back_history([s.slot_key for s in selected_slots]) + + return BatchAskBackResult( + selected_slots=selected_slots, + prompts=prompts, + merged_prompt=merged_prompt, + ask_back_count=len(selected_slots), + ) + + 
async def _prepare_ask_back_slots(
    self,
    missing_slots: list[dict[str, str]],
    current_scene: str | None,
) -> list[AskBackSlot]:
    """
    Turn raw missing-slot descriptors into scored ``AskBackSlot`` objects.

    For each descriptor the slot definition is looked up by key; when found
    it supplies the ``required`` flag, a fallback ask-back prompt and the
    scene scope used for relevance (1.0 when ``current_scene`` is listed in
    the definition's ``scene_scope``, otherwise 0.0).

    Args:
        missing_slots: Descriptors read via ``slot_key``, ``label`` and
            ``ask_back_prompt``.
        current_scene: Current scene key, or ``None``.

    Returns:
        One ``AskBackSlot`` per descriptor, in input order.
    """
    ask_back_slots: list[AskBackSlot] = []

    for missing in missing_slots:
        slot_key = missing.get("slot_key", "")
        # ``or`` (not a .get default) so an explicit None/"" label also
        # falls back to the key instead of rendering "None" in the prompt.
        label = missing.get("label") or slot_key
        ask_back_prompt = missing.get("ask_back_prompt")

        slot_def = await self._slot_def_service.get_slot_definition_by_key(
            self._tenant_id, slot_key
        )

        is_required = False
        scene_relevance = 0.0

        if slot_def:
            is_required = slot_def.required
            if not ask_back_prompt:
                ask_back_prompt = slot_def.ask_back_prompt

            if current_scene and slot_def.scene_scope and current_scene in slot_def.scene_scope:
                scene_relevance = 1.0

        ask_back_slots.append(AskBackSlot(
            slot_key=slot_key,
            label=label,
            ask_back_prompt=ask_back_prompt,
            priority=self._calculate_priority(is_required, scene_relevance),
            is_required=is_required,
            scene_relevance=scene_relevance,
        ))

    return ask_back_slots
Exception as e: + logger.warning(f"[BatchAskBack] Failed to get asked history: {e}") + + return {} + + async def _record_ask_back_history(self, slot_keys: list[str]) -> None: + """记录追问历史""" + history_key = f"{self.ASK_BACK_HISTORY_KEY_PREFIX}:{self._tenant_id}:{self._session_id}" + + try: + client = await self._cache._get_client() + if client is None: + return + + history = await self._get_asked_history() + current_time = time.time() + + for slot_key in slot_keys: + history[slot_key] = current_time + + import json + await client.setex( + history_key, + self.ASK_BACK_HISTORY_TTL, + json.dumps(history), + ) + except Exception as e: + logger.warning(f"[BatchAskBack] Failed to record asked history: {e}") + + def _filter_recently_asked( + self, + slots: list[AskBackSlot], + asked_history: dict[str, float], + ) -> list[AskBackSlot]: + """过滤最近追问过的槽位""" + if not self._config.avoid_recent_asked: + return slots + + current_time = time.time() + threshold = self._config.recent_asked_threshold_seconds + + return [ + slot for slot in slots + if slot.slot_key not in asked_history or + current_time - asked_history[slot.slot_key] > threshold + ] + + def _generate_prompts(self, slots: list[AskBackSlot]) -> list[str]: + """生成追问提示列表""" + prompts = [] + + for slot in slots: + if slot.ask_back_prompt: + prompts.append(slot.ask_back_prompt) + else: + prompts.append(f"请告诉我您的{slot.label}") + + return prompts + + def _merge_prompts(self, prompts: list[str]) -> str | None: + """合并追问提示""" + if not prompts: + return None + + if len(prompts) == 1: + return prompts[0] + + if len(prompts) == 2: + return f"{prompts[0]},以及{prompts[1]}" + + all_but_last = "、".join(prompts[:-1]) + return f"{all_but_last},以及{prompts[-1]}" + + +def create_batch_ask_back_service( + session: AsyncSession, + tenant_id: str, + session_id: str, + config: BatchAskBackConfig | None = None, +) -> BatchAskBackService: + """ + 创建批量追问服务实例 + + Args: + session: 数据库会话 + tenant_id: 租户 ID + session_id: 会话 ID + config: 配置 + + Returns: 
@dataclass
class SlotInfo:
    """Resolved slot definition details carried inside a scene context."""
    slot_key: str
    type: str
    required: bool
    ask_back_prompt: str | None = None
    validation_rule: str | None = None
    linked_field_id: str | None = None
    default_value: Any = None


@dataclass
class SceneSlotContext:
    """
    [AC-SCENE-SLOT-02] Scene slot context.

    Runtime view of a scene slot bundle: the required/optional slots, the
    ask-back ordering policy and the completion threshold.
    """
    scene_key: str
    scene_name: str
    required_slots: list[SlotInfo] = field(default_factory=list)
    optional_slots: list[SlotInfo] = field(default_factory=list)
    slot_priority: list[str] = field(default_factory=list)
    completion_threshold: float = 1.0
    ask_back_order: str = "priority"

    def get_all_slot_keys(self) -> list[str]:
        """Return required slot keys followed by optional slot keys."""
        return self.get_required_slot_keys() + self.get_optional_slot_keys()

    def get_required_slot_keys(self) -> list[str]:
        """Return the keys of the required slots."""
        return [s.slot_key for s in self.required_slots]

    def get_optional_slot_keys(self) -> list[str]:
        """Return the keys of the optional slots."""
        return [s.slot_key for s in self.optional_slots]

    def get_missing_slots(
        self,
        filled_slots: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """
        Describe every required slot whose key is absent from ``filled_slots``.

        Args:
            filled_slots: Mapping of already filled slot keys to values.

        Returns:
            One descriptor dict per missing required slot, in declaration order.
        """
        return [
            {
                "slot_key": slot_info.slot_key,
                "type": slot_info.type,
                "required": True,
                "ask_back_prompt": slot_info.ask_back_prompt,
                "validation_rule": slot_info.validation_rule,
                "linked_field_id": slot_info.linked_field_id,
            }
            for slot_info in self.required_slots
            if slot_info.slot_key not in filled_slots
        ]

    def get_ordered_missing_slots(
        self,
        filled_slots: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """
        Return the missing required slots in ask-back order.

        With ``ask_back_order == "priority"`` and a non-empty
        ``slot_priority`` the result is sorted by position in that list;
        slots not listed there always come last, keeping declaration order
        (the sort is stable). Any other order keeps declaration order.

        Args:
            filled_slots: Mapping of already filled slot keys to values.

        Returns:
            Ordered list of missing-slot descriptors.
        """
        missing = self.get_missing_slots(filled_slots)

        if not missing or self.ask_back_order != "priority" or not self.slot_priority:
            return missing

        priority_map = {slot_key: idx for idx, slot_key in enumerate(self.slot_priority)}
        # One past the last valid index: unlisted slots can never outrank a
        # listed one, however long the priority list is.
        unlisted_rank = len(self.slot_priority)
        missing.sort(key=lambda item: priority_map.get(item["slot_key"], unlisted_rank))
        return missing

    def get_completion_ratio(
        self,
        filled_slots: dict[str, Any],
    ) -> float:
        """
        Return the fraction of required slots present in ``filled_slots``.

        A context without required slots counts as fully complete (1.0).
        """
        if not self.required_slots:
            return 1.0

        filled_count = sum(
            1 for slot_info in self.required_slots
            if slot_info.slot_key in filled_slots
        )
        return filled_count / len(self.required_slots)

    def is_complete(
        self,
        filled_slots: dict[str, Any],
    ) -> bool:
        """Return True when the completion ratio reaches ``completion_threshold``."""
        return self.get_completion_ratio(filled_slots) >= self.completion_threshold
async def load_scene_context(
    self,
    tenant_id: str,
    scene_key: str,
) -> SceneSlotContext | None:
    """
    Load the slot context of a scene, using the cache when enabled.

    A cached entry is used only when its status is active. On a miss the
    active bundle is read through the bundle service, snapshotted once into
    a ``CachedSceneSlotBundle``, written to the cache (when enabled) and the
    same snapshot is used to build the context.

    Args:
        tenant_id: Tenant ID.
        scene_key: Scene key.

    Returns:
        The scene slot context, or ``None`` when no active bundle exists.
    """
    cache_enabled = bool(self._use_cache and self._cache)

    # [AC-SCENE-SLOT-03] try the cache first
    if cache_enabled:
        cached = await self._cache.get(tenant_id, scene_key)
        if cached and cached.status == SceneSlotBundleStatus.ACTIVE.value:
            logger.debug(
                f"[AC-SCENE-SLOT-03] Cache hit for scene: {scene_key}"
            )
            return await self._build_context_from_cached(tenant_id, cached)

    bundle = await self._bundle_service.get_active_bundle_by_scene(
        tenant_id=tenant_id,
        scene_key=scene_key,
    )

    if not bundle:
        logger.debug(
            f"[AC-SCENE-SLOT-02] No active bundle found for scene: {scene_key}"
        )
        return None

    # Snapshot the bundle once; it feeds both the cache write and the
    # context build instead of being constructed twice.
    snapshot = CachedSceneSlotBundle(
        scene_key=bundle.scene_key,
        scene_name=bundle.scene_name,
        description=bundle.description,
        required_slots=bundle.required_slots,
        optional_slots=bundle.optional_slots,
        slot_priority=bundle.slot_priority,
        completion_threshold=bundle.completion_threshold,
        ask_back_order=bundle.ask_back_order,
        status=bundle.status,
        version=bundle.version,
    )

    # [AC-SCENE-SLOT-03] populate the cache
    if cache_enabled:
        await self._cache.set(tenant_id, scene_key, snapshot)

    return await self._build_context_from_cached(tenant_id, snapshot)
def _build_context(
    self,
    cached: CachedSceneSlotBundle,
    slot_map: dict[str, Any],
) -> SceneSlotContext:
    """
    Assemble a ``SceneSlotContext`` from a bundle snapshot and slot definitions.

    Slot keys listed in the bundle but absent from ``slot_map`` are skipped
    with a warning. Required slots are always marked ``required=True``;
    optional slots keep the ``required`` flag of their definition.

    Args:
        cached: Bundle snapshot (slot key lists, priority, threshold, order).
        slot_map: Slot definitions keyed by ``slot_key``.

    Returns:
        The assembled scene slot context.
    """

    def collect(slot_keys: list[str] | None, kind: str, force_required: bool) -> list[SlotInfo]:
        """Resolve slot keys to ``SlotInfo`` objects, warning on unknown keys."""
        infos: list[SlotInfo] = []
        # ``or []`` guards a bundle stored without this list, matching the
        # existing handling of ``slot_priority``.
        for slot_key in slot_keys or []:
            slot_def = slot_map.get(slot_key)
            if slot_def is None:
                logger.warning(
                    f"[AC-SCENE-SLOT-02] {kind} slot not found: {slot_key}"
                )
                continue
            infos.append(SlotInfo(
                slot_key=slot_def.slot_key,
                type=slot_def.type,
                required=True if force_required else slot_def.required,
                ask_back_prompt=slot_def.ask_back_prompt,
                validation_rule=slot_def.validation_rule,
                linked_field_id=str(slot_def.linked_field_id) if slot_def.linked_field_id else None,
                default_value=slot_def.default_value,
            ))
        return infos

    required_slot_infos = collect(cached.required_slots, "Required", True)
    optional_slot_infos = collect(cached.optional_slots, "Optional", False)

    context = SceneSlotContext(
        scene_key=cached.scene_key,
        scene_name=cached.scene_name,
        required_slots=required_slot_infos,
        optional_slots=optional_slot_infos,
        slot_priority=cached.slot_priority or [],
        completion_threshold=cached.completion_threshold,
        ask_back_order=cached.ask_back_order,
    )

    logger.info(
        f"[AC-SCENE-SLOT-02] Loaded scene context: scene={cached.scene_key}, "
        f"required={len(required_slot_infos)}, optional={len(optional_slot_infos)}, "
        f"threshold={cached.completion_threshold}"
    )

    return context
@dataclass
class SceneSlotMetricPoint:
    """A single recorded metric sample."""
    timestamp: datetime
    tenant_id: str
    scene_key: str
    metric_name: str
    value: float
    tags: dict[str, str] = field(default_factory=dict)


@dataclass
class SceneSlotMetricsSummary:
    """Aggregated metric counters."""
    total_requests: int = 0
    cache_hits: int = 0
    cache_misses: int = 0
    missing_slots_triggered: int = 0
    ask_back_triggered: int = 0
    scene_not_configured: int = 0
    slot_not_found: int = 0
    avg_completion_ratio: float = 0.0

    @property
    def cache_hit_rate(self) -> float:
        """Cache hits divided by total requests (0.0 when there are none)."""
        if self.total_requests == 0:
            return 0.0
        return self.cache_hits / self.total_requests


class SceneSlotMetricsCollector:
    """
    [AC-SCENE-SLOT-04] In-memory metrics collector for scene slot bundles.

    Recorded metric names (see the ``record_*`` helpers):
    ``cache_hit``, ``cache_miss``, ``missing_slots_triggered``,
    ``ask_back_triggered``, ``scene_not_configured``, ``slot_not_found``
    and ``completion_ratio``.

    NOTE(review): a "request" is counted as one cache lookup
    (``cache_hit`` + ``cache_miss``). This assumes callers record exactly
    one hit or miss per scene load -- confirm against the call sites.
    """

    # Recorded metric name -> key used in aggregated results.
    _COUNTER_KEYS = {
        "cache_hit": "cache_hits",
        "cache_miss": "cache_misses",
        "missing_slots_triggered": "missing_slots_triggered",
        "ask_back_triggered": "ask_back_triggered",
        "scene_not_configured": "scene_not_configured",
        "slot_not_found": "slot_not_found",
    }

    def __init__(self, max_points: int = 10000):
        self._max_points = max_points
        self._points: list[SceneSlotMetricPoint] = []
        self._counters: dict[str, int] = defaultdict(int)
        self._lock = threading.Lock()
        self._start_time = datetime.utcnow()

    def record(
        self,
        tenant_id: str,
        scene_key: str,
        metric_name: str,
        value: float = 1.0,
        tags: dict[str, str] | None = None,
    ) -> None:
        """
        Record one metric sample.

        Only the most recent ``max_points`` samples are retained.

        Args:
            tenant_id: Tenant ID.
            scene_key: Scene key.
            metric_name: Metric name.
            value: Metric value.
            tags: Extra tags.
        """
        point = SceneSlotMetricPoint(
            timestamp=datetime.utcnow(),
            tenant_id=tenant_id,
            scene_key=scene_key,
            metric_name=metric_name,
            value=value,
            tags=tags or {},
        )

        with self._lock:
            self._points.append(point)
            # Explicit overflow count: a ``[-max_points:]`` slice would keep
            # everything when max_points is 0.
            overflow = len(self._points) - self._max_points
            if overflow > 0:
                del self._points[:overflow]

            counter_key = f"{tenant_id}:{scene_key}:{metric_name}"
            self._counters[counter_key] += 1

    def record_cache_hit(self, tenant_id: str, scene_key: str) -> None:
        """Record a cache hit."""
        self.record(tenant_id, scene_key, "cache_hit")

    def record_cache_miss(self, tenant_id: str, scene_key: str) -> None:
        """Record a cache miss."""
        self.record(tenant_id, scene_key, "cache_miss")

    def record_missing_slots(self, tenant_id: str, scene_key: str, count: int = 1) -> None:
        """Record that ``count`` missing slots were detected."""
        self.record(tenant_id, scene_key, "missing_slots_triggered", float(count))

    def record_ask_back(self, tenant_id: str, scene_key: str) -> None:
        """Record that an ask-back was triggered."""
        self.record(tenant_id, scene_key, "ask_back_triggered")

    def record_scene_not_configured(self, tenant_id: str, scene_key: str) -> None:
        """Record a request for a scene without a bundle configuration."""
        self.record(tenant_id, scene_key, "scene_not_configured")

    def record_slot_not_found(self, tenant_id: str, scene_key: str, slot_key: str) -> None:
        """Record a bundle slot key that has no slot definition."""
        self.record(tenant_id, scene_key, "slot_not_found", tags={"slot_key": slot_key})

    def record_completion_ratio(self, tenant_id: str, scene_key: str, ratio: float) -> None:
        """Record a slot completion ratio sample."""
        self.record(tenant_id, scene_key, "completion_ratio", ratio)

    @classmethod
    def _aggregate(cls, points: list[SceneSlotMetricPoint]) -> dict[str, Any]:
        """
        Fold metric points into counters shared by summary and per-scene views.

        ``missing_slots_triggered`` sums the recorded values (slot counts);
        the other counters count occurrences. ``requests`` is hits + misses
        and ``avg_completion_ratio`` is the mean of ``completion_ratio`` samples.
        """
        totals: dict[str, Any] = {key: 0 for key in cls._COUNTER_KEYS.values()}
        ratio_sum = 0.0
        ratio_count = 0

        for point in points:
            if point.metric_name == "completion_ratio":
                ratio_sum += point.value
                ratio_count += 1
                continue

            key = cls._COUNTER_KEYS.get(point.metric_name)
            if key is None:
                continue  # custom metric recorded via record(); not aggregated
            if point.metric_name == "missing_slots_triggered":
                totals[key] += int(point.value)
            else:
                totals[key] += 1

        totals["requests"] = totals["cache_hits"] + totals["cache_misses"]
        totals["avg_completion_ratio"] = ratio_sum / ratio_count if ratio_count else 0.0
        return totals

    def get_summary(self, tenant_id: str | None = None) -> SceneSlotMetricsSummary:
        """
        Return aggregated metrics.

        Args:
            tenant_id: Tenant ID; ``None`` aggregates over all tenants.

        Returns:
            The metrics summary.
        """
        with self._lock:
            points = [
                p for p in self._points
                if not tenant_id or p.tenant_id == tenant_id
            ]

        totals = self._aggregate(points)
        return SceneSlotMetricsSummary(
            total_requests=totals["requests"],
            cache_hits=totals["cache_hits"],
            cache_misses=totals["cache_misses"],
            missing_slots_triggered=totals["missing_slots_triggered"],
            ask_back_triggered=totals["ask_back_triggered"],
            scene_not_configured=totals["scene_not_configured"],
            slot_not_found=totals["slot_not_found"],
            avg_completion_ratio=totals["avg_completion_ratio"],
        )

    def get_metrics_by_scene(
        self,
        tenant_id: str,
        scene_key: str,
    ) -> dict[str, Any]:
        """
        Return aggregated metrics for one scene of one tenant.

        Args:
            tenant_id: Tenant ID.
            scene_key: Scene key.

        Returns:
            Dict with ``requests``, ``cache_hits``, ``cache_misses``,
            ``missing_slots_triggered``, ``ask_back_triggered``,
            ``slot_not_found``, ``scene_not_configured`` and
            ``avg_completion_ratio`` (all zero when nothing was recorded).
        """
        with self._lock:
            points = [
                p for p in self._points
                if p.tenant_id == tenant_id and p.scene_key == scene_key
            ]

        return self._aggregate(points)

    def check_alerts(
        self,
        tenant_id: str,
        scene_key: str,
        thresholds: dict[str, float] | None = None,
    ) -> list[dict[str, Any]]:
        """
        Evaluate alert conditions for a scene.

        Args:
            tenant_id: Tenant ID.
            scene_key: Scene key.
            thresholds: Overrides for the default alert thresholds.

        Returns:
            List of alert dicts; empty when no request was recorded or no
            threshold is crossed.
        """
        default_thresholds = {
            "cache_hit_rate_low": 0.5,
            "missing_slots_rate_high": 0.3,
            "scene_not_configured_rate_high": 0.1,
        }
        effective_thresholds = {**default_thresholds, **(thresholds or {})}

        metrics = self.get_metrics_by_scene(tenant_id, scene_key)
        alerts: list[dict[str, Any]] = []

        requests = metrics["requests"]
        if requests <= 0:
            return alerts

        cache_hit_rate = metrics["cache_hits"] / requests
        if cache_hit_rate < effective_thresholds["cache_hit_rate_low"]:
            alerts.append({
                "alert_type": "cache_hit_rate_low",
                "severity": "warning",
                "message": f"场景 {scene_key} 的缓存命中率 ({cache_hit_rate:.2%}) 低于阈值 ({effective_thresholds['cache_hit_rate_low']:.0%})",
                "current_value": cache_hit_rate,
                "threshold": effective_thresholds["cache_hit_rate_low"],
                "suggestion": "检查场景槽位包配置是否频繁变更,或增加缓存 TTL",
            })

        missing_slots_rate = metrics["missing_slots_triggered"] / requests
        if missing_slots_rate > effective_thresholds["missing_slots_rate_high"]:
            alerts.append({
                "alert_type": "missing_slots_rate_high",
                "severity": "warning",
                "message": f"场景 {scene_key} 的缺失槽位触发率 ({missing_slots_rate:.2%}) 高于阈值 ({effective_thresholds['missing_slots_rate_high']:.0%})",
                "current_value": missing_slots_rate,
                "threshold": effective_thresholds["missing_slots_rate_high"],
                "suggestion": "检查槽位配置是否合理,或优化槽位提取策略",
            })

        scene_not_configured_rate = metrics["scene_not_configured"] / requests
        if scene_not_configured_rate > effective_thresholds["scene_not_configured_rate_high"]:
            alerts.append({
                "alert_type": "scene_not_configured_rate_high",
                "severity": "error",
                "message": f"场景 {scene_key} 未配置率 ({scene_not_configured_rate:.2%}) 高于阈值 ({effective_thresholds['scene_not_configured_rate_high']:.0%})",
                "current_value": scene_not_configured_rate,
                "threshold": effective_thresholds["scene_not_configured_rate_high"],
                "suggestion": "请为该场景创建场景槽位包配置",
            })

        return alerts

    def export_metrics(self, tenant_id: str | None = None) -> dict[str, Any]:
        """
        Export the retained metric points together with a summary.

        Args:
            tenant_id: Tenant ID; ``None`` exports all tenants.

        Returns:
            Dict with ``start_time``, ``end_time``, ``total_points``,
            ``points`` and ``summary``.
        """
        with self._lock:
            points = [
                {
                    "timestamp": p.timestamp.isoformat(),
                    "tenant_id": p.tenant_id,
                    "scene_key": p.scene_key,
                    "metric_name": p.metric_name,
                    "value": p.value,
                    "tags": p.tags,
                }
                for p in self._points
                if tenant_id is None or p.tenant_id == tenant_id
            ]

        # get_summary takes the lock itself, so it is called after release.
        return {
            "start_time": self._start_time.isoformat(),
            "end_time": datetime.utcnow().isoformat(),
            "total_points": len(points),
            "points": points,
            "summary": self.get_summary(tenant_id).__dict__,
        }

    def reset(self) -> None:
        """Drop all recorded points and counters and restart the time window."""
        with self._lock:
            self._points = []
            self._counters = defaultdict(int)
            self._start_time = datetime.utcnow()


_metrics_collector: SceneSlotMetricsCollector | None = None


def get_scene_slot_metrics_collector() -> SceneSlotMetricsCollector:
    """Return the process-wide metrics collector, creating it on first use."""
    global _metrics_collector
    if _metrics_collector is None:
        _metrics_collector = SceneSlotMetricsCollector()
    return _metrics_collector
校验失败返回 ask_back_prompt 二次追问 +4. 校验通过写入状态并标记 source/confidence +5. 对低置信度值增加确认话术 +""" + +import logging +from dataclasses import dataclass, field +from enum import Enum +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.mid.schemas import SlotSource +from app.services.mid.slot_manager import SlotManager, SlotWriteResult +from app.services.mid.slot_state_aggregator import SlotStateAggregator, SlotState +from app.services.mid.slot_strategy_executor import ( + ExtractContext, + SlotStrategyExecutor, + StrategyChainResult, +) +from app.services.slot_definition_service import SlotDefinitionService + +logger = logging.getLogger(__name__) + + +class BackfillStatus(str, Enum): + """回填状态""" + SUCCESS = "success" + VALIDATION_FAILED = "validation_failed" + EXTRACTION_FAILED = "extraction_failed" + NEEDS_CONFIRMATION = "needs_confirmation" + NO_CANDIDATES = "no_candidates" + + +@dataclass +class BackfillResult: + """ + 回填结果 + + Attributes: + status: 回填状态 + slot_key: 槽位键名 + value: 最终值(校验通过后) + normalized_value: 归一化后的值 + source: 值来源 + confidence: 置信度 + error_message: 错误信息 + ask_back_prompt: 追问提示 + confirmation_prompt: 确认提示(低置信度时) + updated_state: 更新后的槽位状态 + """ + status: BackfillStatus + slot_key: str | None = None + value: Any = None + normalized_value: Any = None + source: str = "unknown" + confidence: float = 0.0 + error_message: str | None = None + ask_back_prompt: str | None = None + confirmation_prompt: str | None = None + updated_state: SlotState | None = None + + def is_success(self) -> bool: + return self.status == BackfillStatus.SUCCESS + + def needs_ask_back(self) -> bool: + return self.status in ( + BackfillStatus.VALIDATION_FAILED, + BackfillStatus.EXTRACTION_FAILED, + ) + + def needs_confirmation(self) -> bool: + return self.status == BackfillStatus.NEEDS_CONFIRMATION + + def to_dict(self) -> dict[str, Any]: + return { + "status": self.status.value, + "slot_key": self.slot_key, + "value": self.value, + "normalized_value": 
self.normalized_value, + "source": self.source, + "confidence": self.confidence, + "error_message": self.error_message, + "ask_back_prompt": self.ask_back_prompt, + "confirmation_prompt": self.confirmation_prompt, + } + + +@dataclass +class BatchBackfillResult: + """批量回填结果""" + results: list[BackfillResult] = field(default_factory=list) + success_count: int = 0 + failed_count: int = 0 + confirmation_needed_count: int = 0 + + def add_result(self, result: BackfillResult) -> None: + self.results.append(result) + if result.is_success(): + self.success_count += 1 + elif result.needs_confirmation(): + self.confirmation_needed_count += 1 + else: + self.failed_count += 1 + + def get_ask_back_prompts(self) -> list[str]: + """获取所有追问提示""" + return [ + r.ask_back_prompt + for r in self.results + if r.ask_back_prompt + ] + + def get_confirmation_prompts(self) -> list[str]: + """获取所有确认提示""" + return [ + r.confirmation_prompt + for r in self.results + if r.confirmation_prompt + ] + + +class SlotBackfillService: + """ + [AC-MRS-SLOT-BACKFILL-01] 槽位回填服务 + + 处理槽位值的提取、校验、确认、写回流程: + 1. 从用户回复提取候选槽位值 + 2. SlotManager 校验并归一化 + 3. 校验失败返回 ask_back_prompt 二次追问 + 4. 校验通过写入状态并标记 source/confidence + 5. 
async def backfill_single_slot(
    self,
    slot_key: str,
    candidate_value: Any,
    source: str = "user_confirmed",
    confidence: float = 1.0,
    strategies: list[str] | None = None,
) -> BackfillResult:
    """
    Backfill a single slot.

    Flow:
    1. When ``strategies`` is given, run the extraction chain on the
       stringified candidate; on failure return ``EXTRACTION_FAILED``
       with the slot's ask-back prompt.
    2. Write the value through ``SlotManager`` (validation and
       normalisation); on failure return ``VALIDATION_FAILED``.
    3. When a session is bound, update the aggregated slot state.
    4. Return ``NEEDS_CONFIRMATION`` with a confirmation prompt when the
       confidence is below ``CONFIDENCE_THRESHOLD_LOW``, else ``SUCCESS``.

    Args:
        slot_key: Slot key.
        candidate_value: Candidate value.
        source: Value source.
        confidence: Initial confidence.
        strategies: Optional extraction strategy chain.

    Returns:
        BackfillResult: Outcome of the backfill.
    """
    final_value = candidate_value
    final_source = source
    final_confidence = confidence

    if strategies:
        extracted_result = await self._extract_value(
            slot_key=slot_key,
            user_input=str(candidate_value),
            strategies=strategies,
        )

        if not extracted_result.success:
            ask_back_prompt = await self._slot_manager.get_ask_back_prompt(
                self._tenant_id, slot_key
            )
            return BackfillResult(
                status=BackfillStatus.EXTRACTION_FAILED,
                slot_key=slot_key,
                value=candidate_value,
                source=source,
                confidence=confidence,
                error_message=extracted_result.steps[-1].failure_reason if extracted_result.steps else "提取失败",
                ask_back_prompt=ask_back_prompt,
            )

        final_value = extracted_result.final_value
        strategy = extracted_result.final_strategy
        if strategy:
            # The confidence table is keyed by source values, not strategy
            # names, so translate first (as backfill_from_user_response
            # does). Unknown strategies keep their raw name as before.
            mapped_source = self._get_source_for_strategy(strategy)
            final_source = mapped_source if mapped_source != "unknown" else strategy
        else:
            final_source = source
        final_confidence = self._get_confidence_for_strategy(final_source)

    known_sources = {s.value for s in SlotSource}
    # NOTE(review): unrecognised sources are still written as USER_CONFIRMED
    # (original behaviour) -- confirm that this labelling is intended.
    slot_source = SlotSource(final_source) if final_source in known_sources else SlotSource.USER_CONFIRMED

    write_result = await self._slot_manager.write_slot(
        tenant_id=self._tenant_id,
        slot_key=slot_key,
        value=final_value,
        source=slot_source,
        confidence=final_confidence,
    )

    if not write_result.success:
        return BackfillResult(
            status=BackfillStatus.VALIDATION_FAILED,
            slot_key=slot_key,
            value=final_value,
            source=final_source,
            confidence=final_confidence,
            error_message=write_result.error.error_message if write_result.error else "校验失败",
            ask_back_prompt=write_result.ask_back_prompt,
        )

    normalized_value = write_result.value
    updated_state = None

    if self._session_id:
        aggregator = await self._get_state_aggregator()
        updated_state = await aggregator.update_slot(
            slot_key=slot_key,
            value=normalized_value,
            source=final_source,
            confidence=final_confidence,
        )

    result_status = BackfillStatus.SUCCESS
    confirmation_prompt = None

    if final_confidence < self.CONFIDENCE_THRESHOLD_LOW:
        result_status = BackfillStatus.NEEDS_CONFIRMATION
        confirmation_prompt = self._generate_confirmation_prompt(
            slot_key, normalized_value
        )

    return BackfillResult(
        status=result_status,
        slot_key=slot_key,
        value=final_value,
        normalized_value=normalized_value,
        source=final_source,
        confidence=final_confidence,
        confirmation_prompt=confirmation_prompt,
        updated_state=updated_state,
    )
status=BackfillStatus.EXTRACTION_FAILED, + slot_key=slot_key, + error_message="无法从回复中提取", + ask_back_prompt=ask_back_prompt, + )) + continue + + source = self._get_source_for_strategy(extracted_result.final_strategy) + confidence = self._get_confidence_for_strategy(source) + + result = await self.backfill_single_slot( + slot_key=slot_key, + candidate_value=extracted_result.final_value, + source=source, + confidence=confidence, + ) + batch_result.add_result(result) + + return batch_result + + async def _extract_value( + self, + slot_key: str, + user_input: str, + strategies: list[str], + slot_type: str = "string", + validation_rule: str | None = None, + ) -> StrategyChainResult: + """ + 执行槽位值提取 + + Args: + slot_key: 槽位键名 + user_input: 用户输入 + strategies: 提取策略链 + slot_type: 槽位类型 + validation_rule: 校验规则 + + Returns: + StrategyChainResult: 提取结果 + """ + context = ExtractContext( + tenant_id=self._tenant_id, + slot_key=slot_key, + user_input=user_input, + slot_type=slot_type, + validation_rule=validation_rule, + ) + + return await self._strategy_executor.execute_chain( + strategies=strategies, + context=context, + ) + + def _get_source_for_strategy(self, strategy: str | None) -> str: + """根据策略获取来源""" + strategy_source_map = { + "rule": SlotSource.RULE_EXTRACTED.value, + "llm": SlotSource.LLM_INFERRED.value, + "user_input": SlotSource.USER_CONFIRMED.value, + } + return strategy_source_map.get(strategy or "", "unknown") + + def _get_confidence_for_strategy(self, source: str) -> float: + """根据来源获取置信度""" + confidence_map = { + SlotSource.USER_CONFIRMED.value: 1.0, + SlotSource.RULE_EXTRACTED.value: 0.9, + SlotSource.LLM_INFERRED.value: 0.7, + "context": 0.5, + SlotSource.DEFAULT.value: 0.3, + } + return confidence_map.get(source, 0.5) + + def _generate_confirmation_prompt( + self, + slot_key: str, + value: Any, + ) -> str: + """生成确认提示""" + return f"我理解您说的是「{value}」,对吗?" 
async def confirm_low_confidence_slot(
    self,
    slot_key: str,
    confirmed: bool,
) -> BackfillResult:
    """
    Apply the user's yes/no answer for a low-confidence slot.

    Without a session id there is no stored state to change, so a bare
    SUCCESS result is returned. When the user rejects the value, the slot is
    cleared and an ask-back prompt is returned. When the user confirms, the
    stored value is re-written with source ``user_confirmed`` and
    confidence 1.0.

    Args:
        slot_key: Key of the slot being confirmed.
        confirmed: True if the user accepted the value, False otherwise.

    Returns:
        BackfillResult: SUCCESS after a confirmation (or when no session is
        attached); VALIDATION_FAILED with an ask-back prompt when the value
        was rejected or no stored value exists to confirm.
    """
    if not self._session_id:
        return BackfillResult(
            status=BackfillStatus.SUCCESS,
            slot_key=slot_key,
        )

    aggregator = await self._get_state_aggregator()

    if not confirmed:
        await aggregator.clear_slot(slot_key)
        ask_back_prompt = await self._slot_manager.get_ask_back_prompt(
            self._tenant_id, slot_key
        )
        return BackfillResult(
            status=BackfillStatus.VALIDATION_FAILED,
            slot_key=slot_key,
            ask_back_prompt=ask_back_prompt or f"请重新提供{slot_key}信息",
        )

    # update_slot() requires the slot value; the previous code omitted it and
    # raised TypeError on every confirmation. Read the stored value back first.
    # NOTE(review): assumes _get_state_aggregator() returns a SlotStateAggregator,
    # whose aggregate() loads the cached session state - confirm.
    current_state = await aggregator.aggregate()
    if slot_key not in current_state.filled_slots:
        # Nothing is stored for this slot (e.g. the state expired): ask again.
        ask_back_prompt = await self._slot_manager.get_ask_back_prompt(
            self._tenant_id, slot_key
        )
        return BackfillResult(
            status=BackfillStatus.VALIDATION_FAILED,
            slot_key=slot_key,
            error_message="没有可确认的槽位值",
            ask_back_prompt=ask_back_prompt or f"请重新提供{slot_key}信息",
        )

    stored_value = current_state.filled_slots[slot_key]
    updated_state = await aggregator.update_slot(
        slot_key=slot_key,
        value=stored_value,
        source=SlotSource.USER_CONFIRMED.value,
        confidence=1.0,
    )
    return BackfillResult(
        status=BackfillStatus.SUCCESS,
        slot_key=slot_key,
        value=stored_value,
        normalized_value=stored_value,
        source=SlotSource.USER_CONFIRMED.value,
        confidence=1.0,
        updated_state=updated_state,
    )
提供 trace:extracted_slots、validation_pass/fail、ask_back_triggered +""" + +import logging +import time +from dataclasses import dataclass, field +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.mid.schemas import SlotSource +from app.services.mid.slot_backfill_service import ( + BackfillResult, + BackfillStatus, + SlotBackfillService, +) +from app.services.mid.slot_manager import SlotManager +from app.services.mid.slot_state_aggregator import SlotState, SlotStateAggregator +from app.services.mid.slot_strategy_executor import ( + ExtractContext, + SlotStrategyExecutor, + StrategyChainResult, +) +from app.services.slot_definition_service import SlotDefinitionService + +logger = logging.getLogger(__name__) + + +@dataclass +class ExtractionTrace: + """ + 提取追踪信息 + + Attributes: + slot_key: 槽位键名 + strategy: 使用的策略 + extracted_value: 提取的值 + validation_passed: 校验是否通过 + final_value: 最终值(校验后) + execution_time_ms: 执行时间 + failure_reason: 失败原因 + """ + slot_key: str + strategy: str | None = None + extracted_value: Any = None + validation_passed: bool = False + final_value: Any = None + execution_time_ms: float = 0.0 + failure_reason: str | None = None + + def to_dict(self) -> dict[str, Any]: + return { + "slot_key": self.slot_key, + "strategy": self.strategy, + "extracted_value": self.extracted_value, + "validation_passed": self.validation_passed, + "final_value": self.final_value, + "execution_time_ms": self.execution_time_ms, + "failure_reason": self.failure_reason, + } + + +@dataclass +class ExtractionResult: + """ + 提取结果 + + Attributes: + success: 是否成功 + extracted_slots: 成功提取的槽位 + failed_slots: 提取失败的槽位 + traces: 提取追踪信息列表 + total_execution_time_ms: 总执行时间 + ask_back_triggered: 是否触发追问 + ask_back_prompts: 追问提示列表 + """ + success: bool = False + extracted_slots: dict[str, Any] = field(default_factory=dict) + failed_slots: list[str] = field(default_factory=list) + traces: list[ExtractionTrace] = field(default_factory=list) + 
class SlotExtractionIntegration:
    """
    [AC-MRS-SLOT-EXTRACT-01] Slot extraction integration service.

    Plugs automatic slot extraction into the main pipeline:
    - entry point: after memory_recall, before KB retrieval
    - runs a strategy chain (default: rule -> llm)
    - hands every extracted value to ``SlotBackfillService.backfill_single_slot``
    - records one ``ExtractionTrace`` per slot
    """

    DEFAULT_STRATEGIES = ["rule", "llm"]

    def __init__(
        self,
        session: AsyncSession,
        tenant_id: str,
        session_id: str | None = None,
        slot_manager: SlotManager | None = None,
        strategy_executor: SlotStrategyExecutor | None = None,
    ):
        """
        Args:
            session: Database session shared by the helper services.
            tenant_id: Tenant whose slot definitions are used.
            session_id: Conversation session id; when absent nothing is
                written to the slot state cache by this class.
            slot_manager: Optional pre-built manager (a new one is created otherwise).
            strategy_executor: Optional pre-built executor (a new one is created otherwise).
        """
        self._session = session
        self._tenant_id = tenant_id
        self._session_id = session_id
        self._slot_manager = slot_manager or SlotManager(session=session)
        self._strategy_executor = strategy_executor or SlotStrategyExecutor()
        self._slot_def_service = SlotDefinitionService(session)
        self._backfill_service: SlotBackfillService | None = None

    async def _get_backfill_service(self) -> SlotBackfillService:
        """Return the backfill service, creating it on first use."""
        if self._backfill_service is None:
            self._backfill_service = SlotBackfillService(
                session=self._session,
                tenant_id=self._tenant_id,
                session_id=self._session_id,
                slot_manager=self._slot_manager,
                strategy_executor=self._strategy_executor,
            )
        return self._backfill_service

    async def extract_and_fill(
        self,
        user_input: str,
        target_slots: list[str] | None = None,
        strategies: list[str] | None = None,
        slot_state: SlotState | None = None,
    ) -> ExtractionResult:
        """
        Extract values for the target slots from ``user_input`` and store them.

        Args:
            user_input: Raw user text.
            target_slots: Slots to extract; ``None`` means "the missing
                required slots" (see ``_get_missing_required_slots``).
            strategies: Strategy chain; defaults to ``DEFAULT_STRATEGIES``.
            slot_state: Current slot state, used to find the missing slots.

        Returns:
            ExtractionResult: Extracted and failed slots, per-slot traces,
            ask-back prompts and the total elapsed time. ``success`` is True
            when at least one slot was extracted, or when there was nothing
            to extract.
        """
        # perf_counter is monotonic, so the duration cannot go negative on clock changes.
        start_time = time.perf_counter()

        # Copy the class-level default so callees can never mutate it.
        strategies = strategies or list(self.DEFAULT_STRATEGIES)

        if target_slots is None:
            target_slots = await self._get_missing_required_slots(slot_state)

        if not target_slots:
            return ExtractionResult(success=True)

        result = ExtractionResult()
        # Remember where each value came from so the cache write keeps the
        # real provenance instead of a hard-coded one.
        slot_sources: dict[str, str] = {}

        for slot_key in target_slots:
            trace = await self._extract_single_slot(
                slot_key=slot_key,
                user_input=user_input,
                strategies=strategies,
            )
            result.traces.append(trace)

            if trace.validation_passed and trace.final_value is not None:
                result.extracted_slots[slot_key] = trace.final_value
                slot_sources[slot_key] = self._get_source_for_strategy(trace.strategy)
                continue

            result.failed_slots.append(slot_key)
            if trace.failure_reason:
                ask_back_prompt = await self._slot_manager.get_ask_back_prompt(
                    self._tenant_id, slot_key
                )
                if ask_back_prompt:
                    result.ask_back_prompts.append(ask_back_prompt)

        result.total_execution_time_ms = (time.perf_counter() - start_time) * 1000
        result.success = len(result.extracted_slots) > 0
        result.ask_back_triggered = len(result.ask_back_prompts) > 0

        if result.extracted_slots and self._session_id:
            await self._save_extracted_slots(result.extracted_slots, slot_sources)

        logger.info(
            f"[AC-MRS-SLOT-EXTRACT-01] Extraction completed: "
            f"tenant={self._tenant_id}, extracted={len(result.extracted_slots)}, "
            f"failed={len(result.failed_slots)}, time_ms={result.total_execution_time_ms:.2f}"
        )

        return result

    async def _extract_single_slot(
        self,
        slot_key: str,
        user_input: str,
        strategies: list[str],
    ) -> ExtractionTrace:
        """
        Extract and backfill one slot, returning its trace.

        The trace's ``execution_time_ms`` covers the definition lookup and
        the strategy chain; it is recorded before the backfill step runs.
        """
        start_time = time.perf_counter()
        trace = ExtractionTrace(slot_key=slot_key)

        slot_def = await self._slot_def_service.get_slot_definition_by_key(
            self._tenant_id, slot_key
        )

        if not slot_def:
            trace.failure_reason = "Slot definition not found"
            trace.execution_time_ms = (time.perf_counter() - start_time) * 1000
            return trace

        context = ExtractContext(
            tenant_id=self._tenant_id,
            slot_key=slot_key,
            user_input=user_input,
            slot_type=slot_def.type,
            validation_rule=slot_def.validation_rule,
            session_id=self._session_id,
        )

        chain_result = await self._strategy_executor.execute_chain(
            strategies=strategies,
            context=context,
            ask_back_prompt=slot_def.ask_back_prompt,
        )

        trace.strategy = chain_result.final_strategy
        trace.extracted_value = chain_result.final_value
        trace.execution_time_ms = (time.perf_counter() - start_time) * 1000

        if not chain_result.success:
            trace.failure_reason = "Extraction failed"
            if chain_result.steps:
                # Keep the generic reason when the last step has none; a None
                # reason would suppress the ask-back prompt in extract_and_fill.
                trace.failure_reason = (
                    chain_result.steps[-1].failure_reason or "Extraction failed"
                )
            return trace

        backfill_service = await self._get_backfill_service()
        source = self._get_source_for_strategy(chain_result.final_strategy)

        backfill_result = await backfill_service.backfill_single_slot(
            slot_key=slot_key,
            candidate_value=chain_result.final_value,
            source=source,
            confidence=self._get_confidence_for_source(source),
        )

        trace.validation_passed = backfill_result.is_success()
        trace.final_value = backfill_result.normalized_value

        if not backfill_result.is_success():
            trace.failure_reason = backfill_result.error_message or "Validation failed"

        return trace

    async def _get_missing_required_slots(
        self,
        slot_state: SlotState | None,
    ) -> list[str]:
        """
        Return the keys of the required slots that still need a value.

        When a slot state is supplied its ``missing_required_slots`` list is
        authoritative, including when it is empty (nothing is missing).
        Only when no state is supplied are all required slot definitions of
        the tenant returned.
        """
        if slot_state is not None:
            # An empty list means "nothing missing"; it must not fall through
            # to "every required slot", which would re-extract filled slots.
            return [
                entry.get("slot_key")
                for entry in slot_state.missing_required_slots
                if entry.get("slot_key")
            ]

        required_defs = await self._slot_def_service.list_slot_definitions(
            tenant_id=self._tenant_id,
            required=True,
        )

        return [definition.slot_key for definition in required_defs]

    async def _save_extracted_slots(
        self,
        extracted_slots: dict[str, Any],
        slot_sources: dict[str, str] | None = None,
    ) -> None:
        """
        Write extracted slot values to the session's slot state cache.

        Args:
            extracted_slots: Values keyed by slot key.
            slot_sources: Optional source label per slot key. Slots without
                an entry are recorded as rule-extracted, which matches the
                behaviour before this argument existed.
        """
        if not self._session_id:
            return

        aggregator = SlotStateAggregator(
            session=self._session,
            tenant_id=self._tenant_id,
            session_id=self._session_id,
        )
        known_sources = slot_sources or {}

        for slot_key, value in extracted_slots.items():
            source = known_sources.get(slot_key, SlotSource.RULE_EXTRACTED.value)
            await aggregator.update_slot(
                slot_key=slot_key,
                value=value,
                source=source,
                # Confidence follows the source (rule 0.9, llm 0.7, ...).
                confidence=self._get_confidence_for_source(source),
            )

    def _get_source_for_strategy(self, strategy: str | None) -> str:
        """Map a strategy name to its source label; unmapped names give ``"unknown"``."""
        strategy_source_map = {
            "rule": SlotSource.RULE_EXTRACTED.value,
            "llm": SlotSource.LLM_INFERRED.value,
            "user_input": SlotSource.USER_CONFIRMED.value,
        }
        return strategy_source_map.get(strategy or "", "unknown")

    def _get_confidence_for_source(self, source: str) -> float:
        """Return the confidence assigned to a source label; unlisted labels get 0.5."""
        confidence_map = {
            SlotSource.USER_CONFIRMED.value: 1.0,
            SlotSource.RULE_EXTRACTED.value: 0.9,
            SlotSource.LLM_INFERRED.value: 0.7,
        }
        return confidence_map.get(source, 0.5)
返回校验失败时的追问提示 +""" + +import logging +from typing import Any + +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.entities import SlotDefinition +from app.models.mid.schemas import SlotSource +from app.services.mid.slot_validation_service import ( + BatchValidationResult, + SlotValidationError, + SlotValidationService, +) +from app.services.slot_definition_service import SlotDefinitionService + +logger = logging.getLogger(__name__) + + +class SlotWriteResult: + """ + 槽位写入结果 + + Attributes: + success: 是否成功(校验通过并写入) + slot_key: 槽位键名 + value: 最终写入的值 + error: 校验错误信息(校验失败时) + ask_back_prompt: 追问提示语(校验失败时) + """ + + def __init__( + self, + success: bool, + slot_key: str, + value: Any | None = None, + error: SlotValidationError | None = None, + ask_back_prompt: str | None = None, + ): + self.success = success + self.slot_key = slot_key + self.value = value + self.error = error + self.ask_back_prompt = ask_back_prompt + + def to_dict(self) -> dict[str, Any]: + """转换为字典""" + result = { + "success": self.success, + "slot_key": self.slot_key, + "value": self.value, + } + if self.error: + result["error"] = { + "error_code": self.error.error_code, + "error_message": self.error.error_message, + } + if self.ask_back_prompt: + result["ask_back_prompt"] = self.ask_back_prompt + return result + + +class SlotManager: + """ + 槽位管理器 + + 统一槽位写入入口,在写入前执行校验。 + 支持从 SlotDefinition 加载校验规则并执行。 + """ + + def __init__( + self, + session: AsyncSession | None = None, + validation_service: SlotValidationService | None = None, + slot_def_service: SlotDefinitionService | None = None, + ): + """ + 初始化槽位管理器 + + Args: + session: 数据库会话 + validation_service: 校验服务实例 + slot_def_service: 槽位定义服务实例 + """ + self._session = session + self._validation_service = validation_service or SlotValidationService() + self._slot_def_service = slot_def_service + self._slot_def_cache: dict[str, SlotDefinition | None] = {} + + async def write_slot( + self, + tenant_id: str, + slot_key: str, + value: Any, + 
source: SlotSource = SlotSource.USER_CONFIRMED, + confidence: float = 1.0, + skip_validation: bool = False, + ) -> SlotWriteResult: + """ + 写入单个槽位值(带校验) + + 执行流程: + 1. 加载槽位定义 + 2. 执行校验(如果未跳过) + 3. 返回校验结果 + + Args: + tenant_id: 租户 ID + slot_key: 槽位键名 + value: 槽位值 + source: 值来源 + confidence: 置信度 + skip_validation: 是否跳过校验(用于特殊场景) + + Returns: + SlotWriteResult: 写入结果 + """ + # 加载槽位定义 + slot_def = await self._get_slot_definition(tenant_id, slot_key) + + # 如果没有定义且非跳过校验,允许写入(动态槽位) + if slot_def is None and skip_validation: + logger.debug( + f"[SlotManager] Writing slot without definition: " + f"tenant_id={tenant_id}, slot_key={slot_key}" + ) + return SlotWriteResult( + success=True, + slot_key=slot_key, + value=value, + ) + + if slot_def is None: + # 未定义槽位,允许写入但记录日志 + logger.info( + f"[SlotManager] Slot definition not found, allowing write: " + f"tenant_id={tenant_id}, slot_key={slot_key}" + ) + return SlotWriteResult( + success=True, + slot_key=slot_key, + value=value, + ) + + # 执行校验 + if not skip_validation: + validation_result = self._validation_service.validate_slot_value( + slot_def, value, tenant_id + ) + + if not validation_result.ok: + logger.info( + f"[SlotManager] Slot validation failed: " + f"tenant_id={tenant_id}, slot_key={slot_key}, " + f"error_code={validation_result.error_code}" + ) + return SlotWriteResult( + success=False, + slot_key=slot_key, + error=SlotValidationError( + slot_key=slot_key, + error_code=validation_result.error_code or "VALIDATION_FAILED", + error_message=validation_result.error_message or "校验失败", + ask_back_prompt=validation_result.ask_back_prompt, + ), + ask_back_prompt=validation_result.ask_back_prompt, + ) + + # 使用归一化后的值 + value = validation_result.normalized_value + + logger.debug( + f"[SlotManager] Slot validation passed: " + f"tenant_id={tenant_id}, slot_key={slot_key}" + ) + + return SlotWriteResult( + success=True, + slot_key=slot_key, + value=value, + ) + + async def write_slots( + self, + tenant_id: str, + values: dict[str, 
Any], + source: SlotSource = SlotSource.USER_CONFIRMED, + confidence: float = 1.0, + skip_validation: bool = False, + ) -> BatchValidationResult: + """ + 批量写入槽位值(带校验) + + Args: + tenant_id: 租户 ID + values: 槽位值字典 {slot_key: value} + source: 值来源 + confidence: 置信度 + skip_validation: 是否跳过校验 + + Returns: + BatchValidationResult: 批量校验结果 + """ + if skip_validation: + return BatchValidationResult( + ok=True, + validated_values=values, + ) + + # 加载所有相关槽位定义 + slot_defs = await self._get_slot_definitions(tenant_id, list(values.keys())) + + # 执行批量校验 + result = self._validation_service.validate_slots( + slot_defs, values, tenant_id + ) + + if not result.ok: + logger.info( + f"[SlotManager] Batch slot validation failed: " + f"tenant_id={tenant_id}, errors={[e.slot_key for e in result.errors]}" + ) + else: + logger.debug( + f"[SlotManager] Batch slot validation passed: " + f"tenant_id={tenant_id}, slots={list(values.keys())}" + ) + + return result + + async def validate_before_write( + self, + tenant_id: str, + slot_key: str, + value: Any, + ) -> tuple[bool, SlotValidationError | None]: + """ + 在写入前预校验槽位值 + + Args: + tenant_id: 租户 ID + slot_key: 槽位键名 + value: 槽位值 + + Returns: + Tuple of (是否通过, 错误信息) + """ + slot_def = await self._get_slot_definition(tenant_id, slot_key) + + if slot_def is None: + # 未定义槽位,视为通过 + return True, None + + result = self._validation_service.validate_slot_value( + slot_def, value, tenant_id + ) + + if result.ok: + return True, None + + return False, SlotValidationError( + slot_key=slot_key, + error_code=result.error_code or "VALIDATION_FAILED", + error_message=result.error_message or "校验失败", + ask_back_prompt=result.ask_back_prompt, + ) + + async def get_ask_back_prompt( + self, + tenant_id: str, + slot_key: str, + ) -> str | None: + """ + 获取槽位的追问提示语 + + Args: + tenant_id: 租户 ID + slot_key: 槽位键名 + + Returns: + 追问提示语或 None + """ + slot_def = await self._get_slot_definition(tenant_id, slot_key) + if slot_def is None: + return None + + if 
isinstance(slot_def, SlotDefinition): + return slot_def.ask_back_prompt + return slot_def.get("ask_back_prompt") + + async def _get_slot_definition( + self, + tenant_id: str, + slot_key: str, + ) -> SlotDefinition | dict[str, Any] | None: + """ + 获取槽位定义(带缓存) + + Args: + tenant_id: 租户 ID + slot_key: 槽位键名 + + Returns: + 槽位定义或 None + """ + cache_key = f"{tenant_id}:{slot_key}" + + if cache_key in self._slot_def_cache: + return self._slot_def_cache[cache_key] + + slot_def = None + if self._slot_def_service: + slot_def = await self._slot_def_service.get_slot_definition_by_key( + tenant_id, slot_key + ) + elif self._session: + service = SlotDefinitionService(self._session) + slot_def = await service.get_slot_definition_by_key(tenant_id, slot_key) + + self._slot_def_cache[cache_key] = slot_def + return slot_def + + async def _get_slot_definitions( + self, + tenant_id: str, + slot_keys: list[str], + ) -> list[SlotDefinition | dict[str, Any]]: + """ + 批量获取槽位定义 + + Args: + tenant_id: 租户 ID + slot_keys: 槽位键名列表 + + Returns: + 槽位定义列表 + """ + slot_defs = [] + for key in slot_keys: + slot_def = await self._get_slot_definition(tenant_id, key) + if slot_def: + slot_defs.append(slot_def) + return slot_defs + + def clear_cache(self) -> None: + """清除槽位定义缓存""" + self._slot_def_cache.clear() + + +def create_slot_manager( + session: AsyncSession | None = None, + validation_service: SlotValidationService | None = None, + slot_def_service: SlotDefinitionService | None = None, +) -> SlotManager: + """ + 创建槽位管理器实例 + + Args: + session: 数据库会话 + validation_service: 校验服务实例 + slot_def_service: 槽位定义服务实例 + + Returns: + SlotManager: 槽位管理器实例 + """ + return SlotManager( + session=session, + validation_service=validation_service, + slot_def_service=slot_def_service, + ) diff --git a/ai-service/app/services/mid/slot_state_aggregator.py b/ai-service/app/services/mid/slot_state_aggregator.py new file mode 100644 index 0000000..188dec6 --- /dev/null +++ 
@dataclass
class SlotState:
    """
    Aggregated slot state for the current turn.

    Holds the filled values together with their provenance, the required
    slots that are still missing, and the slot-to-metadata-field mapping
    used when building KB filters.
    """
    # Filled slot values, {slot_key: value}.
    filled_slots: dict[str, Any] = field(default_factory=dict)
    # Required slots that have no value yet.
    missing_required_slots: list[dict[str, str]] = field(default_factory=list)
    # Source label per slot, {slot_key: source}.
    slot_sources: dict[str, str] = field(default_factory=dict)
    # Confidence per slot, {slot_key: confidence}.
    slot_confidence: dict[str, float] = field(default_factory=dict)
    # Metadata field each slot is linked to, {slot_key: field_key}.
    slot_to_field_map: dict[str, str] = field(default_factory=dict)

    def get_value_for_filter(self, field_key: str) -> Any:
        """
        Return the value to use for a KB filter on ``field_key``.

        A slot whose key equals ``field_key`` is used first. Otherwise the
        first filled slot mapped to ``field_key`` through
        ``slot_to_field_map`` is used. ``None`` is returned when neither
        lookup finds a value.
        """
        if field_key in self.filled_slots:
            return self.filled_slots[field_key]

        # Reverse lookup: find a filled slot that is linked to this field.
        return next(
            (
                self.filled_slots[slot_key]
                for slot_key, linked_field in self.slot_to_field_map.items()
                if linked_field == field_key and slot_key in self.filled_slots
            ),
            None,
        )

    def to_debug_info(self) -> dict[str, Any]:
        """Return the state as a dictionary for debugging (confidence is not included)."""
        exported = (
            "filled_slots",
            "missing_required_slots",
            "slot_sources",
            "slot_to_field_map",
        )
        return {name: getattr(self, name) for name in exported}
从缓存加载已有状态 + cached_state = None + if use_cache and self._session_id: + cached_state = await self._cache.get(self._tenant_id, self._session_id) + if cached_state: + for slot_key, cached_value in cached_state.filled_slots.items(): + state.filled_slots[slot_key] = cached_value.value + state.slot_sources[slot_key] = cached_value.source + state.slot_confidence[slot_key] = cached_value.confidence + state.slot_to_field_map = cached_state.slot_to_field_map.copy() + logger.info( + f"[AC-MRS-SLOT-CACHE-01] Loaded from cache: " + f"tenant={self._tenant_id}, session={self._session_id}, " + f"slots={list(state.filled_slots.keys())}" + ) + + # 2. 从 memory_slots 初始化 + if memory_slots: + for slot_key, memory_slot in memory_slots.items(): + state.filled_slots[slot_key] = memory_slot.value + state.slot_sources[slot_key] = memory_slot.source.value + state.slot_confidence[slot_key] = memory_slot.confidence + logger.info( + f"[AC-MRS-SLOT-META-01] Initialized from memory: " + f"tenant={self._tenant_id}, slots={list(memory_slots.keys())}" + ) + + # 3. 叠加本轮输入(优先级更高) + if current_input_slots: + for slot_key, value in current_input_slots.items(): + if value is not None: + state.filled_slots[slot_key] = value + state.slot_sources[slot_key] = SlotSource.USER_CONFIRMED.value + state.slot_confidence[slot_key] = 1.0 + logger.info( + f"[AC-MRS-SLOT-META-01] Merged current input: " + f"tenant={self._tenant_id}, slots={list(current_input_slots.keys())}" + ) + + # 4. 从 context 提取槽位值(优先级最低) + if context: + context_slots = self._extract_slots_from_context(context) + for slot_key, value in context_slots.items(): + if slot_key not in state.filled_slots and value is not None: + state.filled_slots[slot_key] = value + state.slot_sources[slot_key] = "context" + state.slot_confidence[slot_key] = 0.5 + + # 5. 加载槽位定义并建立关联 + await self._build_slot_mappings(state) + + # 6. 识别缺失的必填槽位 + await self._identify_missing_required_slots(state, scene_slot_context) + + # [AC-MRS-SLOT-CACHE-01] 7. 
回写缓存 + if use_cache and self._session_id: + await self._save_to_cache(state) + + logger.info( + f"[AC-MRS-SLOT-META-01] Slot state aggregated: " + f"tenant={self._tenant_id}, filled={len(state.filled_slots)}, " + f"missing={len(state.missing_required_slots)}" + ) + + return state + + async def _save_to_cache(self, state: SlotState) -> None: + """ + [AC-MRS-SLOT-CACHE-01] 保存槽位状态到缓存 + """ + if not self._session_id: + return + + cached_slots = {} + for slot_key, value in state.filled_slots.items(): + source = state.slot_sources.get(slot_key, "unknown") + confidence = state.slot_confidence.get(slot_key, 1.0) + cached_slots[slot_key] = CachedSlotValue( + value=value, + source=source, + confidence=confidence, + ) + + cached_state = CachedSlotState( + filled_slots=cached_slots, + slot_to_field_map=state.slot_to_field_map.copy(), + ) + + await self._cache.set(self._tenant_id, self._session_id, cached_state) + logger.debug( + f"[AC-MRS-SLOT-CACHE-01] Saved to cache: " + f"tenant={self._tenant_id}, session={self._session_id}" + ) + + async def update_slot( + self, + slot_key: str, + value: Any, + source: str = "user_confirmed", + confidence: float = 1.0, + ) -> SlotState | None: + """ + [AC-MRS-SLOT-CACHE-01] 更新单个槽位值并保存到缓存 + + Args: + slot_key: 槽位键名 + value: 槽位值 + source: 值来源 + confidence: 置信度 + + Returns: + 更新后的槽位状态,如果没有 session_id 则返回 None + """ + if not self._session_id: + return None + + cached_value = CachedSlotValue( + value=value, + source=source, + confidence=confidence, + ) + + cached_state = await self._cache.merge_and_set( + tenant_id=self._tenant_id, + session_id=self._session_id, + new_slots={slot_key: cached_value}, + ) + + state = SlotState() + state.filled_slots = cached_state.get_simple_filled_slots() + state.slot_sources = cached_state.get_slot_sources() + state.slot_confidence = cached_state.get_slot_confidence() + state.slot_to_field_map = cached_state.slot_to_field_map.copy() + + await self._identify_missing_required_slots(state) + + return state + + 
def _extract_slots_from_context(self, context: dict[str, Any]) -> dict[str, Any]:
    """
    Pick candidate slot values out of a context dict.

    Only a fixed set of well-known keys is considered (scene, product_line,
    region, grade, category, type, status, priority). A key is copied when
    it is present in ``context`` and its value is not ``None``.

    Args:
        context: Context mapping to inspect.

    Returns:
        A new dict with the matching key/value pairs, in candidate order.
    """
    candidate_keys = (
        "scene", "product_line", "region", "grade",
        "category", "type", "status", "priority",
    )
    return {
        name: context[name]
        for name in candidate_keys
        if name in context and context[name] is not None
    }
+ slot_def: Any, + linked_field: Any, + state: SlotState, + ) -> None: + """ + 检查槽位与关联元数据字段的类型一致性 + + 当不一致时记录告警日志(不强拦截) + """ + # 检查类型一致性 + if slot_def.type != linked_field.type: + logger.warning( + f"[AC-MRS-SLOT-META-01] Type mismatch: " + f"slot='{slot_def.slot_key}' type={slot_def.type} vs " + f"field='{linked_field.field_key}' type={linked_field.type}" + ) + + # 检查 required 一致性 + if slot_def.required != linked_field.required: + logger.warning( + f"[AC-MRS-SLOT-META-01] Required mismatch: " + f"slot='{slot_def.slot_key}' required={slot_def.required} vs " + f"field='{linked_field.field_key}' required={linked_field.required}" + ) + + # 检查 options 一致性(对于 enum/array_enum 类型) + if slot_def.type in ["enum", "array_enum"]: + slot_options = set(slot_def.options or []) + field_options = set(linked_field.options or []) + if slot_options != field_options: + logger.warning( + f"[AC-MRS-SLOT-META-01] Options mismatch: " + f"slot='{slot_def.slot_key}' options={slot_options} vs " + f"field='{linked_field.field_key}' options={field_options}" + ) + + async def _identify_missing_required_slots( + self, + state: SlotState, + scene_slot_context: Any = None, # [AC-SCENE-SLOT-02] 场景槽位上下文 + ) -> None: + """ + 识别缺失的必填槽位 + + 基于 SlotDefinition 的 required 字段和 linked_field 的 required 字段 + [AC-SCENE-SLOT-02] 当有场景槽位上下文时,优先使用场景定义的必填槽位 + """ + try: + # [AC-SCENE-SLOT-02] 如果有场景槽位上下文,使用场景定义的必填槽位 + if scene_slot_context: + scene_required_keys = set(scene_slot_context.get_required_slot_keys()) + logger.info( + f"[AC-SCENE-SLOT-02] Using scene required slots: " + f"scene={scene_slot_context.scene_key}, " + f"required_keys={scene_required_keys}" + ) + + # 获取场景中定义的所有槽位 + all_slot_defs = await self._slot_def_service.list_slot_definitions( + tenant_id=self._tenant_id, + ) + slot_def_map = {slot.slot_key: slot for slot in all_slot_defs} + + for slot_key in scene_required_keys: + if slot_key not in state.filled_slots: + slot_def = slot_def_map.get(slot_key) + ask_back_prompt = slot_def.ask_back_prompt 
async def generate_ask_back_response(
    self,
    state: SlotState,
    missing_slot_key: str | None = None,
) -> str | None:
    """
    Build the ask-back text for a missing required slot.

    Args:
        state: Current slot state; its ``missing_required_slots`` list is read.
        missing_slot_key: Key of the slot to ask about. When ``None`` (or
            empty), the first missing slot is used.

    Returns:
        The ask-back text, or ``None`` when nothing is missing or the
        requested slot key is not among the missing slots.
    """
    pending = state.missing_required_slots
    if not pending:
        return None

    if missing_slot_key:
        # Look up the entry for the requested slot; None when it is not missing.
        target = next(
            (entry for entry in pending if entry.get("slot_key") == missing_slot_key),
            None,
        )
    else:
        target = pending[0]

    if target is None:
        return None

    # A configured prompt wins over the generic template.
    configured_prompt = target.get("ask_back_prompt")
    if configured_prompt:
        return configured_prompt

    label = target.get("label", target.get("slot_key", "相关信息"))
    return f"为了更好地为您提供帮助,请告诉我您的{label}。"
@dataclass
class StrategyChainResult:
    """Result of running a whole extraction strategy chain for one slot."""

    slot_key: str
    success: bool
    final_value: Any = None
    final_strategy: str | None = None
    steps: list[StrategyStepResult] = field(default_factory=list)
    total_execution_time_ms: float = 0.0
    ask_back_prompt: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """
        Convert the result, including every recorded step, to a plain dict.

        A step's ``value`` is only exposed when that step succeeded, and its
        ``failure_type`` is reduced to the enum's ``value`` (or ``None``).
        """
        serialized_steps = []
        for item in self.steps:
            failure = item.failure_type
            serialized_steps.append(
                {
                    "strategy": item.strategy,
                    "success": item.success,
                    "value": item.value if item.success else None,
                    "failure_type": failure.value if failure else None,
                    "failure_reason": item.failure_reason,
                    "execution_time_ms": item.execution_time_ms,
                }
            )

        payload: dict[str, Any] = {
            "slot_key": self.slot_key,
            "success": self.success,
            "final_value": self.final_value,
            "final_strategy": self.final_strategy,
            "steps": serialized_steps,
            "total_execution_time_ms": self.total_execution_time_ms,
            "ask_back_prompt": self.ask_back_prompt,
        }
        return payload
async def execute_chain(
    self,
    strategies: list[str],
    context: ExtractContext,
    ask_back_prompt: str | None = None,
) -> StrategyChainResult:
    """
    Run the extraction strategy chain for one slot.

    Strategies are executed in order. The first step that succeeds ends the
    chain; a failed step is logged and the next strategy is tried. When
    every strategy fails, a failure result carrying ``ask_back_prompt`` is
    returned.

    Args:
        strategies: Strategy chain, e.g. ["user_input", "rule", "llm"].
        context: Extraction context.
        ask_back_prompt: Ask-back prompt attached to the result when all
            strategies fail.

    Returns:
        StrategyChainResult: The chain result with every executed step.
    """
    import time

    # perf_counter() is monotonic, so durations cannot go negative or jump
    # when the system clock is adjusted (which can happen with time.time()).
    chain_started = time.perf_counter()
    steps: list[StrategyStepResult] = []

    logger.info(
        f"[SlotStrategyExecutor] Starting strategy chain for slot '{context.slot_key}': "
        f"strategies={strategies}, tenant={context.tenant_id}"
    )

    for idx, strategy in enumerate(strategies):
        step_started = time.perf_counter()

        logger.info(
            f"[SlotStrategyExecutor] Executing step {idx + 1}/{len(strategies)}: "
            f"slot_key={context.slot_key}, strategy={strategy}"
        )

        step_result = await self._execute_single_strategy(strategy, context)
        step_result.execution_time_ms = (time.perf_counter() - step_started) * 1000
        steps.append(step_result)

        if step_result.success:
            total_time = (time.perf_counter() - chain_started) * 1000
            logger.info(
                f"[SlotStrategyExecutor] Strategy chain succeeded at step {idx + 1}: "
                f"slot_key={context.slot_key}, strategy={strategy}, "
                f"total_time_ms={total_time:.2f}"
            )
            return StrategyChainResult(
                slot_key=context.slot_key,
                success=True,
                final_value=step_result.value,
                final_strategy=strategy,
                steps=steps,
                total_execution_time_ms=total_time,
            )

        # Record why this strategy failed, then fall through to the next one.
        logger.warning(
            f"[SlotStrategyExecutor] Step {idx + 1} failed: "
            f"slot_key={context.slot_key}, strategy={strategy}, "
            f"failure_type={step_result.failure_type}, "
            f"reason={step_result.failure_reason}"
        )

    # Every strategy failed (or the chain was empty).
    total_time = (time.perf_counter() - chain_started) * 1000
    logger.warning(
        f"[SlotStrategyExecutor] All strategies failed for slot '{context.slot_key}': "
        f"attempted={len(strategies)}, total_time_ms={total_time:.2f}"
    )

    return StrategyChainResult(
        slot_key=context.slot_key,
        success=False,
        final_value=None,
        final_strategy=None,
        steps=steps,
        total_execution_time_ms=total_time,
        ask_back_prompt=ask_back_prompt,
    )
def _validate_value(self, value: Any, context: ExtractContext) -> tuple[bool, str]:
    """
    Validate an extracted value against the context's validation rule.

    The rule is treated as a regular expression and searched for in
    ``str(value)``, the same way ``SlotValidationService._validate_regex``
    does. Rules that start with ``{`` or ``[`` follow the JSON-Schema
    convention of ``SlotValidationService._validate_rule``; they are not
    regular expressions and are passed through here unchecked.

    Previously only rules starting with ``^`` or ending with ``$`` were
    checked at all, so any other pattern silently accepted every value.

    Args:
        value: Extracted value.
        context: Extraction context; ``validation_rule`` is read from it.

    Returns:
        Tuple of (passed, error message). The message is empty on success.
    """
    import re

    validation_rule = context.validation_rule
    if not validation_rule:
        return True, ""

    pattern = str(validation_rule).strip()
    # Blank rules and JSON-Schema style rules are not regexes; let them pass.
    if not pattern or pattern.startswith(("{", "[")):
        return True, ""

    try:
        matched = re.search(pattern, str(value)) is not None
    except re.error as e:
        logger.warning(f"[SlotStrategyExecutor] Invalid validation rule pattern: {e}")
        return True, ""  # deliberately fail open on a broken pattern

    if matched:
        return True, ""
    return False, f"Value '{value}' does not match pattern '{validation_rule}'"
None, + llm_extractor: Callable[[ExtractContext], Any] | None = None, +) -> StrategyChainResult: + """ + 便捷函数:执行提取策略链 + + Args: + strategies: 策略链 + tenant_id: 租户ID + slot_key: 槽位键名 + user_input: 用户输入 + slot_type: 槽位类型 + validation_rule: 校验规则 + ask_back_prompt: 追问提示语 + history: 对话历史 + rule_extractor: 规则提取器 + llm_extractor: LLM提取器 + + Returns: + StrategyChainResult: 执行结果 + """ + executor = SlotStrategyExecutor( + rule_extractor=rule_extractor, + llm_extractor=llm_extractor, + ) + + context = ExtractContext( + tenant_id=tenant_id, + slot_key=slot_key, + user_input=user_input, + slot_type=slot_type, + validation_rule=validation_rule, + history=history, + ) + + return await executor.execute_chain(strategies, context, ask_back_prompt) diff --git a/ai-service/app/services/mid/slot_validation_service.py b/ai-service/app/services/mid/slot_validation_service.py new file mode 100644 index 0000000..6b45176 --- /dev/null +++ b/ai-service/app/services/mid/slot_validation_service.py @@ -0,0 +1,572 @@ +""" +Slot Validation Service. +槽位校验规则 runtime 生效服务 + +提供槽位值的运行时校验能力,支持: +1. 正则表达式校验 +2. JSON Schema 校验 +3. 类型校验(string/number/boolean/enum/array_enum) +4. 
def validate_slot_value(
    self,
    slot_def: dict[str, Any] | SlotDefinition,
    value: Any,
    tenant_id: str | None = None,
) -> ValidationResult:
    """
    Validate a single slot value.

    Validation order:
    1. Required check (``required`` is true and the value is empty).
    2. Type check.
    3. ``validation_rule`` check (regex or JSON Schema).

    Args:
        slot_def: Slot definition, either a dict or a SlotDefinition object.
        value: Value to validate.
        tenant_id: Tenant id, used for logging only.

    Returns:
        ValidationResult: ``ok`` with the normalized value on success;
        otherwise the error code/message plus the slot's ask-back prompt.
    """
    # Work on a plain dict regardless of the input form.
    if isinstance(slot_def, SlotDefinition):
        slot_dict = {
            "slot_key": slot_def.slot_key,
            "type": slot_def.type,
            "required": slot_def.required,
            "validation_rule": slot_def.validation_rule,
            "ask_back_prompt": slot_def.ask_back_prompt,
            # _validate_type reads "options" for enum/array_enum slots; it was
            # previously dropped here, so options were never enforced for
            # SlotDefinition objects. getattr keeps this safe if the
            # attribute is absent.
            "options": getattr(slot_def, "options", None),
        }
    else:
        slot_dict = slot_def

    slot_key = slot_dict.get("slot_key", "unknown")
    slot_type = slot_dict.get("type", "string")
    required = slot_dict.get("required", False)
    validation_rule = slot_dict.get("validation_rule")
    ask_back_prompt = slot_dict.get("ask_back_prompt")

    value_is_empty = self._is_empty_value(value)

    # 1. Required check
    if required and value_is_empty:
        logger.info(
            f"[SlotValidation] Required slot missing: "
            f"tenant_id={tenant_id}, slot_key={slot_key}"
        )
        return ValidationResult(
            ok=False,
            error_code=SlotValidationErrorCode.SLOT_REQUIRED_MISSING,
            error_message=f"槽位 '{slot_key}' 为必填项",
            ask_back_prompt=ask_back_prompt,
        )

    # An empty, non-required value skips the remaining checks.
    if value_is_empty:
        return ValidationResult(ok=True, normalized_value=value)

    # 2. Type check
    type_result = self._validate_type(slot_dict, value, tenant_id)
    if not type_result.ok:
        return ValidationResult(
            ok=False,
            error_code=type_result.error_code,
            error_message=type_result.error_message,
            ask_back_prompt=ask_back_prompt,
        )

    normalized_value = type_result.normalized_value

    # 3. validation_rule check
    if validation_rule and str(validation_rule).strip():
        rule_result = self._validate_rule(
            slot_dict, normalized_value, tenant_id
        )
        if not rule_result.ok:
            return ValidationResult(
                ok=False,
                error_code=rule_result.error_code,
                error_message=rule_result.error_message,
                ask_back_prompt=ask_back_prompt,
            )
        # Compare against None so a falsy normalized value (0, False) from
        # the rule step is not mistaken for "no value".
        if rule_result.normalized_value is not None:
            normalized_value = rule_result.normalized_value

    logger.debug(
        f"[SlotValidation] Slot validation passed: "
        f"tenant_id={tenant_id}, slot_key={slot_key}, type={slot_type}"
    )

    return ValidationResult(ok=True, normalized_value=normalized_value)
def _is_empty_value(self, value: Any) -> bool:
    """
    Tell whether a slot value counts as empty.

    ``None``, a blank or whitespace-only string and an empty list are
    empty; every other value (including 0 and False) is not.
    """
    if value is None:
        return True
    if isinstance(value, str):
        return value.strip() == ""
    if isinstance(value, list):
        return not value
    return False
尝试转换为数字 + if isinstance(value, str): + try: + if "." in value: + normalized = float(value) + else: + normalized = int(value) + return ValidationResult(ok=True, normalized_value=normalized) + except ValueError: + pass + return ValidationResult( + ok=False, + error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID, + error_message=f"槽位 '{slot_key}' 类型应为数字", + ) + + elif slot_type == "boolean": + if isinstance(value, bool): + return ValidationResult(ok=True, normalized_value=value) + if isinstance(value, str): + lower_val = value.lower() + if lower_val in ("true", "1", "yes", "是", "真"): + return ValidationResult(ok=True, normalized_value=True) + if lower_val in ("false", "0", "no", "否", "假"): + return ValidationResult(ok=True, normalized_value=False) + return ValidationResult( + ok=False, + error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID, + error_message=f"槽位 '{slot_key}' 类型应为布尔值", + ) + + elif slot_type == "enum": + if not isinstance(value, str): + return ValidationResult( + ok=False, + error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID, + error_message=f"槽位 '{slot_key}' 类型应为字符串(枚举)", + ) + # 如果有选项定义,校验值是否在选项中 + options = slot_def.get("options") or [] + if options and value not in options: + return ValidationResult( + ok=False, + error_code=SlotValidationErrorCode.SLOT_ENUM_INVALID, + error_message=f"槽位 '{slot_key}' 的值 '{value}' 不在允许选项 {options} 中", + ) + return ValidationResult(ok=True, normalized_value=value) + + elif slot_type == "array_enum": + if not isinstance(value, list): + return ValidationResult( + ok=False, + error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID, + error_message=f"槽位 '{slot_key}' 类型应为数组", + ) + # 校验数组元素 + options = slot_def.get("options") or [] + for item in value: + if not isinstance(item, str): + return ValidationResult( + ok=False, + error_code=SlotValidationErrorCode.SLOT_ARRAY_ENUM_INVALID, + error_message=f"槽位 '{slot_key}' 的数组元素应为字符串", + ) + if options and item not in options: + return ValidationResult( + ok=False, + 
def _validate_rule(
    self,
    slot_def: dict[str, Any],
    value: Any,
    tenant_id: str | None,
) -> ValidationResult:
    """
    Validate a value against the slot's validation rule (regex or JSON Schema).

    Dispatch logic:
    1. A rule whose text starts with ``{`` or ``[`` is handled as JSON Schema.
    2. Any other non-empty rule is handled as a regular expression.

    Args:
        slot_def: Slot definition dict.
        value: Value to validate.
        tenant_id: Tenant id.

    Returns:
        ValidationResult: The validation result; passes when no rule is set.
    """
    slot_key = slot_def.get("slot_key", "unknown")

    # .get(key, "") only covers a missing key. A key that is present with an
    # explicit None used to be stringified to "None" and run as a regex.
    raw_rule = slot_def.get("validation_rule")
    validation_rule = "" if raw_rule is None else str(raw_rule).strip()

    if not validation_rule:
        return ValidationResult(ok=True, normalized_value=value)

    # JSON Schema rules are recognised by their leading "{" or "[".
    if validation_rule.startswith(("{", "[")):
        return self._validate_json_schema(
            slot_key, validation_rule, value, tenant_id
        )

    return self._validate_regex(
        slot_key, validation_rule, value, tenant_id
    )
error_code=SlotValidationErrorCode.SLOT_REGEX_MISMATCH, + error_message=f"槽位 '{slot_key}' 的值不符合格式要求", + ) + + return ValidationResult(ok=True, normalized_value=value) + + except re.error as e: + logger.warning( + f"[SlotValidation] Invalid regex pattern: " + f"tenant_id={tenant_id}, slot_key={slot_key}, pattern={pattern}, error={e}" + ) + return ValidationResult( + ok=False, + error_code=SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID, + error_message=f"槽位 '{slot_key}' 的校验规则配置无效(非法正则)", + ) + + def _validate_json_schema( + self, + slot_key: str, + schema_str: str, + value: Any, + tenant_id: str | None, + ) -> ValidationResult: + """ + JSON Schema 校验 + + Args: + slot_key: 槽位键名 + schema_str: JSON Schema 字符串 + value: 待校验的值 + tenant_id: 租户 ID + + Returns: + ValidationResult: 校验结果 + """ + try: + # 解析 JSON Schema + schema = self._schema_cache.get(schema_str) + if schema is None: + schema = json.loads(schema_str) + self._schema_cache[schema_str] = schema + + # 执行校验 + jsonschema.validate(instance=value, schema=schema) + return ValidationResult(ok=True, normalized_value=value) + + except json.JSONDecodeError as e: + logger.warning( + f"[SlotValidation] Invalid JSON schema: " + f"tenant_id={tenant_id}, slot_key={slot_key}, error={e}" + ) + return ValidationResult( + ok=False, + error_code=SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID, + error_message=f"槽位 '{slot_key}' 的校验规则配置无效(非法 JSON)", + ) + + except JsonSchemaValidationError as e: + logger.info( + f"[SlotValidation] JSON schema mismatch: " + f"tenant_id={tenant_id}, slot_key={slot_key}, error={e.message}" + ) + return ValidationResult( + ok=False, + error_code=SlotValidationErrorCode.SLOT_JSON_SCHEMA_MISMATCH, + error_message=f"槽位 '{slot_key}' 的值不符合格式要求: {e.message}", + ) + + except Exception as e: + logger.error( + f"[SlotValidation] JSON schema validation error: " + f"tenant_id={tenant_id}, slot_key={slot_key}, error={e}" + ) + return ValidationResult( + ok=False, + 
class SceneSlotBundleService:
    """
    [AC-SCENE-SLOT-01] Service for scene slot bundles.

    Responsibilities:
    1. CRUD operations on ``SceneSlotBundle`` rows, always scoped by tenant.
    2. Validation that the slot keys referenced by a bundle exist.
    3. Runtime lookup of the active bundle for a scene.
    4. [AC-SCENE-SLOT-03] Cache invalidation after updates and deletes.

    The service only flushes the session; committing is left to the caller.
    """

    # Attributes copied from an update request onto the bundle when the
    # request supplies a non-None value (None means "leave unchanged").
    _UPDATABLE_FIELDS = (
        "scene_name",
        "description",
        "required_slots",
        "optional_slots",
        "slot_priority",
        "completion_threshold",
        "ask_back_order",
        "status",
    )

    def __init__(self, session: AsyncSession):
        """Bind the service to the async session used for every query."""
        self._session = session

    async def list_bundles(
        self,
        tenant_id: str,
        status: str | None = None,
    ) -> list[SceneSlotBundle]:
        """
        List the scene slot bundles of a tenant, newest first.

        Args:
            tenant_id: Tenant ID.
            status: Optional status filter; a falsy value disables the filter.

        Returns:
            Bundles ordered by ``created_at`` descending.
        """
        stmt = select(SceneSlotBundle).where(
            SceneSlotBundle.tenant_id == tenant_id
        )

        if status:
            stmt = stmt.where(SceneSlotBundle.status == status)

        stmt = stmt.order_by(SceneSlotBundle.created_at.desc())

        result = await self._session.execute(stmt)
        return list(result.scalars().all())

    @staticmethod
    def _parse_bundle_id(bundle_id: str) -> uuid.UUID | None:
        """Return ``bundle_id`` as a UUID, or None if it is not a valid UUID string."""
        try:
            return uuid.UUID(bundle_id)
        except ValueError:
            return None

    async def get_bundle(
        self,
        tenant_id: str,
        bundle_id: str,
    ) -> SceneSlotBundle | None:
        """
        Fetch a single scene slot bundle by ID.

        Args:
            tenant_id: Tenant ID.
            bundle_id: Bundle ID as a UUID string.

        Returns:
            The bundle, or None when ``bundle_id`` is not a valid UUID or no
            matching row exists for the tenant.
        """
        # Only the UUID parsing is guarded: a ValueError raised while
        # executing the query must propagate instead of being reported
        # as "not found".
        bundle_uuid = self._parse_bundle_id(bundle_id)
        if bundle_uuid is None:
            return None

        stmt = select(SceneSlotBundle).where(
            SceneSlotBundle.tenant_id == tenant_id,
            SceneSlotBundle.id == bundle_uuid,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_bundle_by_scene_key(
        self,
        tenant_id: str,
        scene_key: str,
    ) -> SceneSlotBundle | None:
        """
        Fetch a bundle by its scene key, regardless of status.

        Args:
            tenant_id: Tenant ID.
            scene_key: Scene key.

        Returns:
            The bundle, or None if the tenant has no bundle for that scene key.
        """
        stmt = select(SceneSlotBundle).where(
            SceneSlotBundle.tenant_id == tenant_id,
            SceneSlotBundle.scene_key == scene_key,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def get_active_bundle_by_scene(
        self,
        tenant_id: str,
        scene_key: str,
    ) -> SceneSlotBundle | None:
        """
        Fetch the bundle of a scene only if its status is ACTIVE (runtime use).

        Args:
            tenant_id: Tenant ID.
            scene_key: Scene key.

        Returns:
            The active bundle, or None if there is none.
        """
        stmt = select(SceneSlotBundle).where(
            SceneSlotBundle.tenant_id == tenant_id,
            SceneSlotBundle.scene_key == scene_key,
            SceneSlotBundle.status == SceneSlotBundleStatus.ACTIVE.value,
        )
        result = await self._session.execute(stmt)
        return result.scalar_one_or_none()

    async def create_bundle(
        self,
        tenant_id: str,
        bundle_create: SceneSlotBundleCreate,
    ) -> SceneSlotBundle:
        """
        Create a scene slot bundle.

        The slot configuration is validated first, then the scene key is
        checked for uniqueness within the tenant. The new row is added and
        flushed, not committed.

        Args:
            tenant_id: Tenant ID.
            bundle_create: Creation request data.

        Returns:
            The newly created bundle.

        Raises:
            ValueError: Validation failed or the scene key already exists.
        """
        validation_errors = await self._validate_bundle_data(
            tenant_id=tenant_id,
            required_slots=bundle_create.required_slots,
            optional_slots=bundle_create.optional_slots,
            slot_priority=bundle_create.slot_priority,
        )
        if validation_errors:
            raise ValueError("; ".join(validation_errors))

        existing = await self.get_bundle_by_scene_key(
            tenant_id=tenant_id,
            scene_key=bundle_create.scene_key,
        )
        if existing:
            raise ValueError(f"场景标识 '{bundle_create.scene_key}' 已存在")

        bundle = SceneSlotBundle(
            tenant_id=tenant_id,
            scene_key=bundle_create.scene_key,
            scene_name=bundle_create.scene_name,
            description=bundle_create.description,
            # Missing slot lists are stored as empty lists, never None.
            required_slots=bundle_create.required_slots or [],
            optional_slots=bundle_create.optional_slots or [],
            slot_priority=bundle_create.slot_priority,
            completion_threshold=bundle_create.completion_threshold,
            ask_back_order=bundle_create.ask_back_order,
            status=bundle_create.status,
        )

        self._session.add(bundle)
        await self._session.flush()

        logger.info(
            f"[AC-SCENE-SLOT-01] Created scene slot bundle: "
            f"tenant={tenant_id}, scene_key={bundle.scene_key}, "
            f"required={len(bundle.required_slots)}, optional={len(bundle.optional_slots)}"
        )

        return bundle

    async def update_bundle(
        self,
        tenant_id: str,
        bundle_id: str,
        bundle_update: SceneSlotBundleUpdate,
    ) -> SceneSlotBundle | None:
        """
        Update a scene slot bundle.

        Fields that are None in ``bundle_update`` are left unchanged.
        Validation runs on the effective slot configuration, i.e. the updated
        values merged with the stored ones. On success the version is
        incremented, the session is flushed and the cache is invalidated.

        Args:
            tenant_id: Tenant ID.
            bundle_id: Bundle ID as a UUID string.
            bundle_update: Update request data.

        Returns:
            The updated bundle, or None if it does not exist.

        Raises:
            ValueError: Validation failed.
        """
        bundle = await self.get_bundle(tenant_id, bundle_id)
        if not bundle:
            return None

        required_slots = (
            bundle_update.required_slots
            if bundle_update.required_slots is not None
            else bundle.required_slots
        )
        optional_slots = (
            bundle_update.optional_slots
            if bundle_update.optional_slots is not None
            else bundle.optional_slots
        )
        slot_priority = (
            bundle_update.slot_priority
            if bundle_update.slot_priority is not None
            else bundle.slot_priority
        )

        validation_errors = await self._validate_bundle_data(
            tenant_id=tenant_id,
            required_slots=required_slots,
            optional_slots=optional_slots,
            slot_priority=slot_priority,
        )
        if validation_errors:
            raise ValueError("; ".join(validation_errors))

        for field_name in self._UPDATABLE_FIELDS:
            new_value = getattr(bundle_update, field_name)
            if new_value is not None:
                setattr(bundle, field_name, new_value)

        bundle.version += 1

        await self._session.flush()

        # [AC-SCENE-SLOT-03] Invalidate the cache.
        # NOTE(review): this runs after flush but before the caller commits;
        # confirm a concurrent reader cannot re-cache the pre-commit row.
        await self._invalidate_cache(tenant_id, bundle.scene_key)

        logger.info(
            f"[AC-SCENE-SLOT-01] Updated scene slot bundle: "
            f"tenant={tenant_id}, scene_key={bundle.scene_key}, version={bundle.version}"
        )

        return bundle

    async def delete_bundle(
        self,
        tenant_id: str,
        bundle_id: str,
    ) -> bool:
        """
        Delete a scene slot bundle.

        Args:
            tenant_id: Tenant ID.
            bundle_id: Bundle ID as a UUID string.

        Returns:
            True if the bundle was deleted, False if it was not found.
        """
        bundle = await self.get_bundle(tenant_id, bundle_id)
        if not bundle:
            return False

        # Read the key before deleting; it is needed for invalidation below.
        scene_key = bundle.scene_key

        await self._session.delete(bundle)
        await self._session.flush()

        # [AC-SCENE-SLOT-03] Invalidate the cache.
        await self._invalidate_cache(tenant_id, scene_key)

        logger.info(
            f"[AC-SCENE-SLOT-01] Deleted scene slot bundle: "
            f"tenant={tenant_id}, scene_key={scene_key}"
        )

        return True

    async def _validate_bundle_data(
        self,
        tenant_id: str,
        required_slots: list[str] | None,
        optional_slots: list[str] | None,
        slot_priority: list[str] | None,
    ) -> list[str]:
        """
        Validate the slot configuration of a bundle.

        Checks that required and optional slots do not overlap, that every
        entry of the priority list is one of the bundle's slots, and that all
        referenced slot keys exist for the tenant.

        Args:
            tenant_id: Tenant ID.
            required_slots: Required slot keys (None is treated as empty).
            optional_slots: Optional slot keys (None is treated as empty).
            slot_priority: Priority list; skipped when empty or None.

        Returns:
            Error messages; an empty list means validation passed. Keys
            inside a message are sorted so the text is deterministic.
        """
        errors = []

        required_set = set(required_slots or [])
        optional_set = set(optional_slots or [])
        all_slots = required_set | optional_set

        overlap = required_set & optional_set
        if overlap:
            errors.append(f"必填和可选槽位存在交叉: {sorted(overlap)}")

        if slot_priority:
            unknown_in_priority = set(slot_priority) - all_slots
            if unknown_in_priority:
                errors.append(
                    f"优先级列表包含未定义的槽位: {sorted(unknown_in_priority)}"
                )

        valid_slots = await self._validate_slot_keys(
            tenant_id=tenant_id,
            slot_keys=list(all_slots),
        )

        invalid_slots = all_slots - valid_slots
        if invalid_slots:
            errors.append(f"以下槽位不存在: {sorted(invalid_slots)}")

        return errors

    async def _validate_slot_keys(
        self,
        tenant_id: str,
        slot_keys: list[str],
    ) -> set[str]:
        """
        Determine which of the given slot keys exist for the tenant.

        Args:
            tenant_id: Tenant ID.
            slot_keys: Slot keys to check.

        Returns:
            The subset of ``slot_keys`` that have a slot definition.
        """
        # Nothing to check: skip the slot definition lookup entirely.
        if not slot_keys:
            return set()

        # Function-scope import, as in the original code.
        from app.services.slot_definition_service import SlotDefinitionService

        slot_service = SlotDefinitionService(self._session)
        all_slots = await slot_service.list_slot_definitions(tenant_id)

        valid_keys = {slot.slot_key for slot in all_slots}

        return set(slot_keys) & valid_keys

    @staticmethod
    def _build_slot_details(
        slot_keys: list[str] | None,
        slot_map: dict[str, Any],
    ) -> list[dict[str, Any]]:
        """Build detail dicts for the keys found in ``slot_map``; unknown keys are skipped."""
        details = []
        for slot_key in slot_keys or ():
            slot = slot_map.get(slot_key)
            if slot is None:
                continue
            details.append({
                "slot_key": slot.slot_key,
                "type": slot.type,
                "required": slot.required,
                "ask_back_prompt": slot.ask_back_prompt,
                "linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
            })
        return details

    async def get_bundle_with_slot_details(
        self,
        tenant_id: str,
        bundle_id: str,
    ) -> dict[str, Any] | None:
        """
        Fetch a bundle together with the details of its slots.

        Slot keys that have no slot definition are omitted from the
        ``*_slot_details`` lists but remain in ``required_slots`` and
        ``optional_slots``.

        Args:
            tenant_id: Tenant ID.
            bundle_id: Bundle ID as a UUID string.

        Returns:
            A dict with the bundle fields and slot details, or None if the
            bundle does not exist.
        """
        bundle = await self.get_bundle(tenant_id, bundle_id)
        if not bundle:
            return None

        from app.services.slot_definition_service import SlotDefinitionService

        slot_service = SlotDefinitionService(self._session)
        all_slots = await slot_service.list_slot_definitions(tenant_id)

        slot_map = {slot.slot_key: slot for slot in all_slots}

        return {
            "id": str(bundle.id),
            "tenant_id": str(bundle.tenant_id),
            "scene_key": bundle.scene_key,
            "scene_name": bundle.scene_name,
            "description": bundle.description,
            "required_slots": bundle.required_slots,
            "optional_slots": bundle.optional_slots,
            "required_slot_details": self._build_slot_details(
                bundle.required_slots, slot_map
            ),
            "optional_slot_details": self._build_slot_details(
                bundle.optional_slots, slot_map
            ),
            "slot_priority": bundle.slot_priority,
            "completion_threshold": bundle.completion_threshold,
            "ask_back_order": bundle.ask_back_order,
            "status": bundle.status,
            "version": bundle.version,
            "created_at": bundle.created_at.isoformat() if bundle.created_at else None,
            "updated_at": bundle.updated_at.isoformat() if bundle.updated_at else None,
        }

    async def _invalidate_cache(
        self,
        tenant_id: str,
        scene_key: str,
    ) -> None:
        """
        [AC-SCENE-SLOT-03] Invalidate the cached bundle of a scene.

        Best effort: a failure is logged as a warning and never propagated,
        so a cache problem cannot fail the surrounding update or delete.

        Args:
            tenant_id: Tenant ID.
            scene_key: Scene key.
        """
        try:
            cache = get_scene_slot_bundle_cache()
            await cache.invalidate_on_update(tenant_id, scene_key)
        except Exception as e:
            logger.warning(
                f"[AC-SCENE-SLOT-03] Failed to invalidate cache: {e}"
            )