feat: add slot management system with validation, backfill, state aggregation and scene bundle support [AC-SLOT-MGMT]
This commit is contained in:
parent
248a225436
commit
9769f7ccf0
|
|
@ -0,0 +1,265 @@
|
|||
"""
|
||||
Scene Slot Bundle API.
|
||||
[AC-SCENE-SLOT-01] 场景-槽位映射配置管理接口
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from fastapi.responses import JSONResponse
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.database import get_session
|
||||
from app.core.exceptions import MissingTenantIdException
|
||||
from app.core.tenant import get_tenant_id
|
||||
from app.models.entities import SceneSlotBundleCreate, SceneSlotBundleUpdate
|
||||
from app.services.scene_slot_bundle_service import SceneSlotBundleService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/admin/scene-slot-bundles", tags=["SceneSlotBundle"])
|
||||
|
||||
|
||||
def get_current_tenant_id() -> str:
|
||||
"""Get current tenant ID from context."""
|
||||
tenant_id = get_tenant_id()
|
||||
if not tenant_id:
|
||||
raise MissingTenantIdException()
|
||||
return tenant_id
|
||||
|
||||
|
||||
def _bundle_to_dict(bundle: Any) -> dict[str, Any]:
|
||||
"""Convert bundle to dict"""
|
||||
return {
|
||||
"id": str(bundle.id),
|
||||
"tenant_id": str(bundle.tenant_id),
|
||||
"scene_key": bundle.scene_key,
|
||||
"scene_name": bundle.scene_name,
|
||||
"description": bundle.description,
|
||||
"required_slots": bundle.required_slots,
|
||||
"optional_slots": bundle.optional_slots,
|
||||
"slot_priority": bundle.slot_priority,
|
||||
"completion_threshold": bundle.completion_threshold,
|
||||
"ask_back_order": bundle.ask_back_order,
|
||||
"status": bundle.status,
|
||||
"version": bundle.version,
|
||||
"created_at": bundle.created_at.isoformat() if bundle.created_at else None,
|
||||
"updated_at": bundle.updated_at.isoformat() if bundle.updated_at else None,
|
||||
}
|
||||
|
||||
|
||||
@router.get(
|
||||
"",
|
||||
operation_id="listSceneSlotBundles",
|
||||
summary="List scene slot bundles",
|
||||
description="获取场景槽位包列表",
|
||||
)
|
||||
async def list_scene_slot_bundles(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
status: Annotated[str | None, Query(
|
||||
description="按状态过滤: draft/active/deprecated"
|
||||
)] = None,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
列出场景槽位包
|
||||
"""
|
||||
logger.info(
|
||||
f"Listing scene slot bundles: tenant={tenant_id}, status={status}"
|
||||
)
|
||||
|
||||
service = SceneSlotBundleService(session)
|
||||
bundles = await service.list_bundles(tenant_id, status)
|
||||
|
||||
return JSONResponse(
|
||||
content=[_bundle_to_dict(b) for b in bundles]
|
||||
)
|
||||
|
||||
|
||||
@router.post(
|
||||
"",
|
||||
operation_id="createSceneSlotBundle",
|
||||
summary="Create scene slot bundle",
|
||||
description="[AC-SCENE-SLOT-01] 创建新的场景槽位包",
|
||||
status_code=201,
|
||||
)
|
||||
async def create_scene_slot_bundle(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
bundle_create: SceneSlotBundleCreate,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
[AC-SCENE-SLOT-01] 创建场景槽位包
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-SCENE-SLOT-01] Creating scene slot bundle: "
|
||||
f"tenant={tenant_id}, scene_key={bundle_create.scene_key}"
|
||||
)
|
||||
|
||||
service = SceneSlotBundleService(session)
|
||||
|
||||
try:
|
||||
bundle = await service.create_bundle(tenant_id, bundle_create)
|
||||
await session.commit()
|
||||
except ValueError as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"error_code": "VALIDATION_ERROR",
|
||||
"message": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(
|
||||
status_code=201,
|
||||
content=_bundle_to_dict(bundle)
|
||||
)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/by-scene/{scene_key}",
|
||||
operation_id="getSceneSlotBundleBySceneKey",
|
||||
summary="Get scene slot bundle by scene key",
|
||||
description="根据场景标识获取槽位包",
|
||||
)
|
||||
async def get_scene_slot_bundle_by_scene_key(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
scene_key: str,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
根据场景标识获取槽位包
|
||||
"""
|
||||
logger.info(
|
||||
f"Getting scene slot bundle by scene_key: tenant={tenant_id}, scene_key={scene_key}"
|
||||
)
|
||||
|
||||
service = SceneSlotBundleService(session)
|
||||
bundle = await service.get_bundle_by_scene_key(tenant_id, scene_key)
|
||||
|
||||
if not bundle:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"error_code": "NOT_FOUND",
|
||||
"message": f"Scene slot bundle with scene_key '{scene_key}' not found",
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(content=_bundle_to_dict(bundle))
|
||||
|
||||
|
||||
@router.get(
|
||||
"/{id}",
|
||||
operation_id="getSceneSlotBundle",
|
||||
summary="Get scene slot bundle by ID",
|
||||
description="获取单个场景槽位包详情(含槽位详情)",
|
||||
)
|
||||
async def get_scene_slot_bundle(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
id: str,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
获取单个场景槽位包详情
|
||||
"""
|
||||
logger.info(
|
||||
f"Getting scene slot bundle: tenant={tenant_id}, id={id}"
|
||||
)
|
||||
|
||||
service = SceneSlotBundleService(session)
|
||||
bundle = await service.get_bundle_with_slot_details(tenant_id, id)
|
||||
|
||||
if not bundle:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"error_code": "NOT_FOUND",
|
||||
"message": f"Scene slot bundle {id} not found",
|
||||
}
|
||||
)
|
||||
|
||||
return JSONResponse(content=bundle)
|
||||
|
||||
|
||||
@router.put(
|
||||
"/{id}",
|
||||
operation_id="updateSceneSlotBundle",
|
||||
summary="Update scene slot bundle",
|
||||
description="更新场景槽位包",
|
||||
)
|
||||
async def update_scene_slot_bundle(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
id: str,
|
||||
bundle_update: SceneSlotBundleUpdate,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
更新场景槽位包
|
||||
"""
|
||||
logger.info(
|
||||
f"Updating scene slot bundle: tenant={tenant_id}, id={id}"
|
||||
)
|
||||
|
||||
service = SceneSlotBundleService(session)
|
||||
|
||||
try:
|
||||
bundle = await service.update_bundle(tenant_id, id, bundle_update)
|
||||
except ValueError as e:
|
||||
return JSONResponse(
|
||||
status_code=400,
|
||||
content={
|
||||
"error_code": "VALIDATION_ERROR",
|
||||
"message": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
if not bundle:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"error_code": "NOT_FOUND",
|
||||
"message": f"Scene slot bundle {id} not found",
|
||||
}
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
return JSONResponse(content=_bundle_to_dict(bundle))
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/{id}",
|
||||
operation_id="deleteSceneSlotBundle",
|
||||
summary="Delete scene slot bundle",
|
||||
description="[AC-SCENE-SLOT-01] 删除场景槽位包",
|
||||
status_code=204,
|
||||
)
|
||||
async def delete_scene_slot_bundle(
|
||||
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
|
||||
session: Annotated[AsyncSession, Depends(get_session)],
|
||||
id: str,
|
||||
) -> JSONResponse:
|
||||
"""
|
||||
[AC-SCENE-SLOT-01] 删除场景槽位包
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-SCENE-SLOT-01] Deleting scene slot bundle: tenant={tenant_id}, id={id}"
|
||||
)
|
||||
|
||||
service = SceneSlotBundleService(session)
|
||||
success = await service.delete_bundle(tenant_id, id)
|
||||
|
||||
if not success:
|
||||
return JSONResponse(
|
||||
status_code=404,
|
||||
content={
|
||||
"error_code": "NOT_FOUND",
|
||||
"message": f"Scene slot bundle not found: {id}",
|
||||
}
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
return JSONResponse(status_code=204, content=None)
|
||||
|
|
@ -38,9 +38,13 @@ def _slot_to_dict(slot: dict[str, Any] | Any) -> dict[str, Any]:
|
|||
"id": str(slot.id),
|
||||
"tenant_id": str(slot.tenant_id),
|
||||
"slot_key": slot.slot_key,
|
||||
"display_name": slot.display_name,
|
||||
"description": slot.description,
|
||||
"type": slot.type,
|
||||
"required": slot.required,
|
||||
# [AC-MRS-07-UPGRADE] 返回新旧字段
|
||||
"extract_strategy": slot.extract_strategy,
|
||||
"extract_strategies": slot.extract_strategies,
|
||||
"validation_rule": slot.validation_rule,
|
||||
"ask_back_prompt": slot.ask_back_prompt,
|
||||
"default_value": slot.default_value,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,274 @@
|
|||
"""
|
||||
Scene Slot Bundle Cache Service.
|
||||
[AC-SCENE-SLOT-03] 场景槽位包缓存服务
|
||||
|
||||
职责:
|
||||
1. 缓存场景槽位包配置,减少数据库查询
|
||||
2. 支持缓存失效和刷新
|
||||
3. 支持租户隔离
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
try:
|
||||
import redis.asyncio as redis
|
||||
REDIS_AVAILABLE = True
|
||||
except ImportError:
|
||||
REDIS_AVAILABLE = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedSceneSlotBundle:
|
||||
"""缓存的场景槽位包"""
|
||||
scene_key: str
|
||||
scene_name: str
|
||||
description: str | None
|
||||
required_slots: list[str]
|
||||
optional_slots: list[str]
|
||||
slot_priority: list[str] | None
|
||||
completion_threshold: float
|
||||
ask_back_order: str
|
||||
status: str
|
||||
version: int
|
||||
cached_at: datetime = field(default_factory=datetime.utcnow)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"scene_key": self.scene_key,
|
||||
"scene_name": self.scene_name,
|
||||
"description": self.description,
|
||||
"required_slots": self.required_slots,
|
||||
"optional_slots": self.optional_slots,
|
||||
"slot_priority": self.slot_priority,
|
||||
"completion_threshold": self.completion_threshold,
|
||||
"ask_back_order": self.ask_back_order,
|
||||
"status": self.status,
|
||||
"version": self.version,
|
||||
"cached_at": self.cached_at.isoformat(),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "CachedSceneSlotBundle":
|
||||
return cls(
|
||||
scene_key=data["scene_key"],
|
||||
scene_name=data["scene_name"],
|
||||
description=data.get("description"),
|
||||
required_slots=data.get("required_slots", []),
|
||||
optional_slots=data.get("optional_slots", []),
|
||||
slot_priority=data.get("slot_priority"),
|
||||
completion_threshold=data.get("completion_threshold", 1.0),
|
||||
ask_back_order=data.get("ask_back_order", "priority"),
|
||||
status=data.get("status", "draft"),
|
||||
version=data.get("version", 1),
|
||||
cached_at=datetime.fromisoformat(data["cached_at"]) if data.get("cached_at") else datetime.utcnow(),
|
||||
)
|
||||
|
||||
|
||||
class SceneSlotBundleCache:
|
||||
"""
|
||||
[AC-SCENE-SLOT-03] 场景槽位包缓存
|
||||
|
||||
使用 Redis 或内存缓存场景槽位包配置
|
||||
"""
|
||||
|
||||
CACHE_PREFIX = "scene_slot_bundle"
|
||||
CACHE_TTL_SECONDS = 300 # 5分钟缓存
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Any = None,
|
||||
ttl_seconds: int = 300,
|
||||
):
|
||||
self._redis = redis_client
|
||||
self._ttl = ttl_seconds or self.CACHE_TTL_SECONDS
|
||||
self._memory_cache: dict[str, tuple[CachedSceneSlotBundle, datetime]] = {}
|
||||
|
||||
def _get_cache_key(self, tenant_id: str, scene_key: str) -> str:
|
||||
"""生成缓存键"""
|
||||
return f"{self.CACHE_PREFIX}:{tenant_id}:{scene_key}"
|
||||
|
||||
async def get(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
) -> CachedSceneSlotBundle | None:
|
||||
"""
|
||||
获取缓存的场景槽位包
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
|
||||
Returns:
|
||||
缓存的场景槽位包或 None
|
||||
"""
|
||||
cache_key = self._get_cache_key(tenant_id, scene_key)
|
||||
|
||||
if self._redis and REDIS_AVAILABLE:
|
||||
try:
|
||||
cached_data = await self._redis.get(cache_key)
|
||||
if cached_data:
|
||||
data = json.loads(cached_data)
|
||||
return CachedSceneSlotBundle.from_dict(data)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-SCENE-SLOT-03] Redis get failed: {e}, falling back to memory cache"
|
||||
)
|
||||
|
||||
if cache_key in self._memory_cache:
|
||||
cached_bundle, cached_at = self._memory_cache[cache_key]
|
||||
if datetime.utcnow() - cached_at < timedelta(seconds=self._ttl):
|
||||
return cached_bundle
|
||||
else:
|
||||
del self._memory_cache[cache_key]
|
||||
|
||||
return None
|
||||
|
||||
async def set(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
bundle: CachedSceneSlotBundle,
|
||||
) -> bool:
|
||||
"""
|
||||
设置缓存
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
bundle: 要缓存的场景槽位包
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
cache_key = self._get_cache_key(tenant_id, scene_key)
|
||||
|
||||
if self._redis and REDIS_AVAILABLE:
|
||||
try:
|
||||
await self._redis.setex(
|
||||
cache_key,
|
||||
self._ttl,
|
||||
json.dumps(bundle.to_dict()),
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-SCENE-SLOT-03] Redis set failed: {e}, falling back to memory cache"
|
||||
)
|
||||
|
||||
self._memory_cache[cache_key] = (bundle, datetime.utcnow())
|
||||
return True
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
) -> bool:
|
||||
"""
|
||||
删除缓存
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
cache_key = self._get_cache_key(tenant_id, scene_key)
|
||||
|
||||
if self._redis and REDIS_AVAILABLE:
|
||||
try:
|
||||
await self._redis.delete(cache_key)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-SCENE-SLOT-03] Redis delete failed: {e}"
|
||||
)
|
||||
|
||||
if cache_key in self._memory_cache:
|
||||
del self._memory_cache[cache_key]
|
||||
|
||||
return True
|
||||
|
||||
async def delete_by_tenant(self, tenant_id: str) -> bool:
|
||||
"""
|
||||
删除租户下所有缓存
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
pattern = f"{self.CACHE_PREFIX}:{tenant_id}:*"
|
||||
|
||||
if self._redis and REDIS_AVAILABLE:
|
||||
try:
|
||||
keys = []
|
||||
async for key in self._redis.scan_iter(match=pattern):
|
||||
keys.append(key)
|
||||
if keys:
|
||||
await self._redis.delete(*keys)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-SCENE-SLOT-03] Redis delete by tenant failed: {e}"
|
||||
)
|
||||
|
||||
keys_to_delete = [
|
||||
k for k in self._memory_cache
|
||||
if k.startswith(f"{self.CACHE_PREFIX}:{tenant_id}:")
|
||||
]
|
||||
for key in keys_to_delete:
|
||||
del self._memory_cache[key]
|
||||
|
||||
return True
|
||||
|
||||
async def invalidate_on_update(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
) -> bool:
|
||||
"""
|
||||
当场景槽位包更新时使缓存失效
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
logger.info(
|
||||
f"[AC-SCENE-SLOT-03] Invalidating cache for scene slot bundle: "
|
||||
f"tenant={tenant_id}, scene={scene_key}"
|
||||
)
|
||||
return await self.delete(tenant_id, scene_key)
|
||||
|
||||
|
||||
_scene_slot_bundle_cache: SceneSlotBundleCache | None = None
|
||||
|
||||
|
||||
def get_scene_slot_bundle_cache() -> SceneSlotBundleCache:
|
||||
"""获取场景槽位包缓存实例"""
|
||||
global _scene_slot_bundle_cache
|
||||
if _scene_slot_bundle_cache is None:
|
||||
_scene_slot_bundle_cache = SceneSlotBundleCache()
|
||||
return _scene_slot_bundle_cache
|
||||
|
||||
|
||||
def init_scene_slot_bundle_cache(redis_client: Any = None, ttl_seconds: int = 300) -> None:
|
||||
"""初始化场景槽位包缓存"""
|
||||
global _scene_slot_bundle_cache
|
||||
_scene_slot_bundle_cache = SceneSlotBundleCache(
|
||||
redis_client=redis_client,
|
||||
ttl_seconds=ttl_seconds,
|
||||
)
|
||||
logger.info(
|
||||
f"[AC-SCENE-SLOT-03] Scene slot bundle cache initialized: "
|
||||
f"ttl={ttl_seconds}s, redis={redis_client is not None}"
|
||||
)
|
||||
|
|
@ -0,0 +1,397 @@
|
|||
"""
|
||||
Slot State Cache Layer.
|
||||
槽位状态缓存层 - 提供会话级槽位状态持久化
|
||||
|
||||
[AC-MRS-SLOT-CACHE-01] 多轮状态持久化
|
||||
|
||||
Features:
|
||||
- L1: In-memory cache (process-level, 5 min TTL)
|
||||
- L2: Redis cache (shared, configurable TTL)
|
||||
- Automatic fallback on cache miss
|
||||
- Support for slot value source priority
|
||||
|
||||
Key format: slot_state:{tenant_id}:{session_id}
|
||||
TTL: Configurable (default 30 minutes)
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import redis.asyncio as redis
|
||||
|
||||
from app.core.config import get_settings
|
||||
from app.models.mid.schemas import SlotSource
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedSlotValue:
|
||||
"""
|
||||
缓存的槽位值
|
||||
|
||||
Attributes:
|
||||
value: 槽位值
|
||||
source: 值来源 (user_confirmed, rule_extracted, llm_inferred, default, context)
|
||||
confidence: 置信度
|
||||
updated_at: 更新时间戳
|
||||
"""
|
||||
value: Any
|
||||
source: str
|
||||
confidence: float = 1.0
|
||||
updated_at: float = field(default_factory=time.time)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"value": self.value,
|
||||
"source": self.source,
|
||||
"confidence": self.confidence,
|
||||
"updated_at": self.updated_at,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "CachedSlotValue":
|
||||
return cls(
|
||||
value=data["value"],
|
||||
source=data["source"],
|
||||
confidence=data.get("confidence", 1.0),
|
||||
updated_at=data.get("updated_at", time.time()),
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class CachedSlotState:
|
||||
"""
|
||||
缓存的槽位状态
|
||||
|
||||
Attributes:
|
||||
filled_slots: 已填充的槽位值字典 {slot_key: CachedSlotValue}
|
||||
slot_to_field_map: 槽位到元数据字段的映射
|
||||
created_at: 创建时间
|
||||
updated_at: 最后更新时间
|
||||
"""
|
||||
filled_slots: dict[str, CachedSlotValue] = field(default_factory=dict)
|
||||
slot_to_field_map: dict[str, str] = field(default_factory=dict)
|
||||
created_at: float = field(default_factory=time.time)
|
||||
updated_at: float = field(default_factory=time.time)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"filled_slots": {
|
||||
k: v.to_dict() for k, v in self.filled_slots.items()
|
||||
},
|
||||
"slot_to_field_map": self.slot_to_field_map,
|
||||
"created_at": self.created_at,
|
||||
"updated_at": self.updated_at,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> "CachedSlotState":
|
||||
filled_slots = {}
|
||||
for k, v in data.get("filled_slots", {}).items():
|
||||
if isinstance(v, dict):
|
||||
filled_slots[k] = CachedSlotValue.from_dict(v)
|
||||
else:
|
||||
filled_slots[k] = CachedSlotValue(value=v, source="unknown")
|
||||
|
||||
return cls(
|
||||
filled_slots=filled_slots,
|
||||
slot_to_field_map=data.get("slot_to_field_map", {}),
|
||||
created_at=data.get("created_at", time.time()),
|
||||
updated_at=data.get("updated_at", time.time()),
|
||||
)
|
||||
|
||||
def get_simple_filled_slots(self) -> dict[str, Any]:
|
||||
"""获取简化的已填充槽位字典(仅值)"""
|
||||
return {k: v.value for k, v in self.filled_slots.items()}
|
||||
|
||||
def get_slot_sources(self) -> dict[str, str]:
|
||||
"""获取槽位来源字典"""
|
||||
return {k: v.source for k, v in self.filled_slots.items()}
|
||||
|
||||
def get_slot_confidence(self) -> dict[str, float]:
|
||||
"""获取槽位置信度字典"""
|
||||
return {k: v.confidence for k, v in self.filled_slots.items()}
|
||||
|
||||
|
||||
class SlotStateCache:
|
||||
"""
|
||||
[AC-MRS-SLOT-CACHE-01] 槽位状态缓存层
|
||||
|
||||
提供会话级槽位状态持久化,支持:
|
||||
- L1: 内存缓存(进程级,5分钟 TTL)
|
||||
- L2: Redis 缓存(共享,可配置 TTL)
|
||||
- 自动降级(Redis 不可用时仅使用内存缓存)
|
||||
- 槽位值来源优先级合并
|
||||
|
||||
Key format: slot_state:{tenant_id}:{session_id}
|
||||
TTL: Configurable via settings.slot_state_cache_ttl (default 1800 seconds = 30 minutes)
|
||||
"""
|
||||
|
||||
_local_cache: dict[str, tuple[CachedSlotState, float]] = {}
|
||||
_local_cache_ttl = 300
|
||||
|
||||
SOURCE_PRIORITY = {
|
||||
SlotSource.USER_CONFIRMED.value: 100,
|
||||
"user_confirmed": 100,
|
||||
SlotSource.RULE_EXTRACTED.value: 80,
|
||||
"rule_extracted": 80,
|
||||
SlotSource.LLM_INFERRED.value: 60,
|
||||
"llm_inferred": 60,
|
||||
"context": 40,
|
||||
SlotSource.DEFAULT.value: 20,
|
||||
"default": 20,
|
||||
"unknown": 0,
|
||||
}
|
||||
|
||||
def __init__(self, redis_client: redis.Redis | None = None):
|
||||
self._redis = redis_client
|
||||
self._settings = get_settings()
|
||||
self._enabled = self._settings.redis_enabled
|
||||
self._cache_ttl = getattr(self._settings, "slot_state_cache_ttl", 1800)
|
||||
|
||||
async def _get_client(self) -> redis.Redis | None:
|
||||
"""Get or create Redis client."""
|
||||
if not self._enabled:
|
||||
return None
|
||||
if self._redis is None:
|
||||
try:
|
||||
self._redis = redis.from_url(
|
||||
self._settings.redis_url,
|
||||
encoding="utf-8",
|
||||
decode_responses=True,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[SlotStateCache] Failed to connect to Redis: {e}")
|
||||
self._enabled = False
|
||||
return None
|
||||
return self._redis
|
||||
|
||||
def _make_key(self, tenant_id: str, session_id: str) -> str:
|
||||
"""Generate cache key."""
|
||||
return f"slot_state:{tenant_id}:{session_id}"
|
||||
|
||||
def _make_local_key(self, tenant_id: str, session_id: str) -> str:
|
||||
"""Generate local cache key."""
|
||||
return f"{tenant_id}:{session_id}"
|
||||
|
||||
def _get_source_priority(self, source: str) -> int:
|
||||
"""Get priority for a source."""
|
||||
return self.SOURCE_PRIORITY.get(source, 0)
|
||||
|
||||
async def get(
|
||||
self,
|
||||
tenant_id: str,
|
||||
session_id: str,
|
||||
) -> CachedSlotState | None:
|
||||
"""
|
||||
Get cached slot state (L1 -> L2).
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID for isolation
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
CachedSlotState or None if not found
|
||||
"""
|
||||
local_key = self._make_local_key(tenant_id, session_id)
|
||||
if local_key in self._local_cache:
|
||||
state, timestamp = self._local_cache[local_key]
|
||||
if time.time() - timestamp < self._local_cache_ttl:
|
||||
logger.debug(f"[SlotStateCache] L1 hit: {local_key}")
|
||||
return state
|
||||
else:
|
||||
del self._local_cache[local_key]
|
||||
|
||||
client = await self._get_client()
|
||||
if client is None:
|
||||
return None
|
||||
|
||||
key = self._make_key(tenant_id, session_id)
|
||||
|
||||
try:
|
||||
data = await client.get(key)
|
||||
if data:
|
||||
logger.debug(f"[SlotStateCache] L2 hit: {key}")
|
||||
state_dict = json.loads(data)
|
||||
state = CachedSlotState.from_dict(state_dict)
|
||||
|
||||
self._local_cache[local_key] = (state, time.time())
|
||||
|
||||
return state
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.warning(f"[SlotStateCache] Failed to get from cache: {e}")
|
||||
return None
|
||||
|
||||
async def set(
|
||||
self,
|
||||
tenant_id: str,
|
||||
session_id: str,
|
||||
state: CachedSlotState,
|
||||
) -> bool:
|
||||
"""
|
||||
Set slot state to cache (L1 + L2).
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID for isolation
|
||||
session_id: Session ID
|
||||
state: CachedSlotState to cache
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
local_key = self._make_local_key(tenant_id, session_id)
|
||||
state.updated_at = time.time()
|
||||
self._local_cache[local_key] = (state, time.time())
|
||||
|
||||
client = await self._get_client()
|
||||
if client is None:
|
||||
return False
|
||||
|
||||
key = self._make_key(tenant_id, session_id)
|
||||
|
||||
try:
|
||||
await client.setex(
|
||||
key,
|
||||
self._cache_ttl,
|
||||
json.dumps(state.to_dict(), default=str),
|
||||
)
|
||||
logger.debug(f"[SlotStateCache] Set cache: {key}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"[SlotStateCache] Failed to set cache: {e}")
|
||||
return False
|
||||
|
||||
async def merge_and_set(
|
||||
self,
|
||||
tenant_id: str,
|
||||
session_id: str,
|
||||
new_slots: dict[str, CachedSlotValue],
|
||||
slot_to_field_map: dict[str, str] | None = None,
|
||||
) -> CachedSlotState:
|
||||
"""
|
||||
Merge new slot values with cached state and save.
|
||||
|
||||
Priority: user_confirmed > rule_extracted > llm_inferred > context > default
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
session_id: Session ID
|
||||
new_slots: New slot values to merge
|
||||
slot_to_field_map: Slot to field mapping
|
||||
|
||||
Returns:
|
||||
Updated CachedSlotState
|
||||
"""
|
||||
state = await self.get(tenant_id, session_id)
|
||||
if state is None:
|
||||
state = CachedSlotState()
|
||||
|
||||
for slot_key, new_value in new_slots.items():
|
||||
if slot_key in state.filled_slots:
|
||||
existing = state.filled_slots[slot_key]
|
||||
existing_priority = self._get_source_priority(existing.source)
|
||||
new_priority = self._get_source_priority(new_value.source)
|
||||
|
||||
if new_priority >= existing_priority:
|
||||
state.filled_slots[slot_key] = new_value
|
||||
logger.debug(
|
||||
f"[SlotStateCache] Slot '{slot_key}' updated: "
|
||||
f"{existing.source}({existing_priority}) -> "
|
||||
f"{new_value.source}({new_priority})"
|
||||
)
|
||||
else:
|
||||
state.filled_slots[slot_key] = new_value
|
||||
logger.debug(
|
||||
f"[SlotStateCache] Slot '{slot_key}' added: "
|
||||
f"source={new_value.source}, value={new_value.value}"
|
||||
)
|
||||
|
||||
if slot_to_field_map:
|
||||
state.slot_to_field_map.update(slot_to_field_map)
|
||||
|
||||
await self.set(tenant_id, session_id, state)
|
||||
|
||||
return state
|
||||
|
||||
async def delete(
|
||||
self,
|
||||
tenant_id: str,
|
||||
session_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete slot state from cache (L1 + L2).
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID for isolation
|
||||
session_id: Session ID
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
local_key = self._make_local_key(tenant_id, session_id)
|
||||
if local_key in self._local_cache:
|
||||
del self._local_cache[local_key]
|
||||
|
||||
client = await self._get_client()
|
||||
if client is None:
|
||||
return False
|
||||
|
||||
key = self._make_key(tenant_id, session_id)
|
||||
|
||||
try:
|
||||
await client.delete(key)
|
||||
logger.debug(f"[SlotStateCache] Deleted cache: {key}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"[SlotStateCache] Failed to delete cache: {e}")
|
||||
return False
|
||||
|
||||
async def clear_slot(
|
||||
self,
|
||||
tenant_id: str,
|
||||
session_id: str,
|
||||
slot_key: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Clear a specific slot from cached state.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant ID
|
||||
session_id: Session ID
|
||||
slot_key: Slot key to clear
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
"""
|
||||
state = await self.get(tenant_id, session_id)
|
||||
if state is None:
|
||||
return True
|
||||
|
||||
if slot_key in state.filled_slots:
|
||||
del state.filled_slots[slot_key]
|
||||
await self.set(tenant_id, session_id, state)
|
||||
logger.debug(f"[SlotStateCache] Cleared slot: {slot_key}")
|
||||
|
||||
return True
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close Redis connection."""
|
||||
if self._redis:
|
||||
await self._redis.close()
|
||||
|
||||
|
||||
_slot_state_cache: SlotStateCache | None = None
|
||||
|
||||
|
||||
def get_slot_state_cache() -> SlotStateCache:
|
||||
"""Get singleton SlotStateCache instance."""
|
||||
global _slot_state_cache
|
||||
if _slot_state_cache is None:
|
||||
_slot_state_cache = SlotStateCache()
|
||||
return _slot_state_cache
|
||||
|
|
@ -0,0 +1,202 @@
|
|||
"""
|
||||
元数据字段定义缓存服务
|
||||
使用 Redis 缓存 metadata_field_definitions,减少数据库查询
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.core.config import get_settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MetadataCacheService:
|
||||
"""
|
||||
元数据字段定义缓存服务
|
||||
|
||||
缓存策略:
|
||||
- Key: metadata:fields:{tenant_id}
|
||||
- Value: JSON 序列化的字段定义列表
|
||||
- TTL: 1小时(3600秒)
|
||||
- 更新策略:写时更新 + 定时刷新
|
||||
"""
|
||||
|
||||
CACHE_KEY_PREFIX = "metadata:fields"
|
||||
DEFAULT_TTL = 3600 # 1小时
|
||||
|
||||
def __init__(self):
|
||||
self._settings = get_settings()
|
||||
self._redis_client = None
|
||||
self._enabled = self._settings.redis_enabled
|
||||
|
||||
async def _get_redis(self):
|
||||
"""获取 Redis 连接(延迟初始化)"""
|
||||
if not self._enabled:
|
||||
return None
|
||||
|
||||
if self._redis_client is None:
|
||||
try:
|
||||
import redis.asyncio as redis
|
||||
self._redis_client = redis.from_url(
|
||||
self._settings.redis_url,
|
||||
decode_responses=True
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"[MetadataCache] Failed to connect Redis: {e}")
|
||||
self._enabled = False
|
||||
return None
|
||||
|
||||
return self._redis_client
|
||||
|
||||
def _make_key(self, tenant_id: str) -> str:
|
||||
"""生成缓存 key"""
|
||||
return f"{self.CACHE_KEY_PREFIX}:{tenant_id}"
|
||||
|
||||
async def get_fields(self, tenant_id: str) -> list[dict[str, Any]] | None:
|
||||
"""
|
||||
获取缓存的字段定义
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
|
||||
Returns:
|
||||
字段定义列表,未缓存返回 None
|
||||
"""
|
||||
if not self._enabled:
|
||||
return None
|
||||
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
if not redis:
|
||||
return None
|
||||
|
||||
key = self._make_key(tenant_id)
|
||||
cached_data = await redis.get(key)
|
||||
|
||||
if cached_data:
|
||||
logger.info(f"[MetadataCache] Cache hit for tenant={tenant_id}")
|
||||
return json.loads(cached_data)
|
||||
|
||||
logger.info(f"[MetadataCache] Cache miss for tenant={tenant_id}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MetadataCache] Get cache error: {e}")
|
||||
return None
|
||||
|
||||
async def set_fields(
|
||||
self,
|
||||
tenant_id: str,
|
||||
fields: list[dict[str, Any]],
|
||||
ttl: int | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
缓存字段定义
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
fields: 字段定义列表
|
||||
ttl: 过期时间(秒),默认 1小时
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
if not redis:
|
||||
return False
|
||||
|
||||
key = self._make_key(tenant_id)
|
||||
ttl = ttl or self.DEFAULT_TTL
|
||||
|
||||
await redis.setex(
|
||||
key,
|
||||
ttl,
|
||||
json.dumps(fields, ensure_ascii=False, default=str)
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[MetadataCache] Cached {len(fields)} fields for tenant={tenant_id}, "
|
||||
f"ttl={ttl}s"
|
||||
)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MetadataCache] Set cache error: {e}")
|
||||
return False
|
||||
|
||||
async def invalidate(self, tenant_id: str) -> bool:
|
||||
"""
|
||||
使缓存失效(字段定义更新时调用)
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
if not redis:
|
||||
return False
|
||||
|
||||
key = self._make_key(tenant_id)
|
||||
result = await redis.delete(key)
|
||||
|
||||
if result:
|
||||
logger.info(f"[MetadataCache] Invalidated cache for tenant={tenant_id}")
|
||||
return bool(result)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MetadataCache] Invalidate error: {e}")
|
||||
return False
|
||||
|
||||
async def invalidate_all(self) -> bool:
|
||||
"""
|
||||
使所有元数据缓存失效
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not self._enabled:
|
||||
return False
|
||||
|
||||
try:
|
||||
redis = await self._get_redis()
|
||||
if not redis:
|
||||
return False
|
||||
|
||||
# 查找所有元数据缓存 key
|
||||
pattern = f"{self.CACHE_KEY_PREFIX}:*"
|
||||
keys = []
|
||||
async for key in redis.scan_iter(match=pattern):
|
||||
keys.append(key)
|
||||
|
||||
if keys:
|
||||
await redis.delete(*keys)
|
||||
logger.info(f"[MetadataCache] Invalidated {len(keys)} cache entries")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[MetadataCache] Invalidate all error: {e}")
|
||||
return False
|
||||
|
||||
|
||||
# 全局缓存服务实例
|
||||
_metadata_cache_service: MetadataCacheService | None = None
|
||||
|
||||
|
||||
async def get_metadata_cache_service() -> MetadataCacheService:
|
||||
"""获取元数据缓存服务实例(单例)"""
|
||||
global _metadata_cache_service
|
||||
if _metadata_cache_service is None:
|
||||
_metadata_cache_service = MetadataCacheService()
|
||||
return _metadata_cache_service
|
||||
|
|
@ -0,0 +1,342 @@
|
|||
"""
|
||||
Batch Ask-Back Service.
|
||||
批量追问服务 - 支持一次追问多个缺失槽位
|
||||
|
||||
[AC-MRS-SLOT-ASKBACK-01] 批量追问
|
||||
|
||||
职责:
|
||||
1. 支持一次追问多个缺失槽位
|
||||
2. 选择策略:必填优先、场景相关优先、最近未追问过优先
|
||||
3. 输出形式:单条自然语言合并提问或分段提问
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.services.cache.slot_state_cache import get_slot_state_cache
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AskBackSlot:
|
||||
"""
|
||||
追问槽位信息
|
||||
|
||||
Attributes:
|
||||
slot_key: 槽位键名
|
||||
label: 显示标签
|
||||
ask_back_prompt: 追问提示
|
||||
priority: 优先级(数值越大越优先)
|
||||
last_asked_at: 上次追问时间戳
|
||||
is_required: 是否必填
|
||||
scene_relevance: 场景相关度
|
||||
"""
|
||||
slot_key: str
|
||||
label: str
|
||||
ask_back_prompt: str | None = None
|
||||
priority: int = 0
|
||||
last_asked_at: float | None = None
|
||||
is_required: bool = False
|
||||
scene_relevance: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchAskBackConfig:
|
||||
"""
|
||||
批量追问配置
|
||||
|
||||
Attributes:
|
||||
max_ask_back_slots_per_turn: 每轮最多追问槽位数
|
||||
prefer_required: 是否优先追问必填槽位
|
||||
prefer_scene_relevant: 是否优先追问场景相关槽位
|
||||
avoid_recent_asked: 是否避免最近追问过的槽位
|
||||
recent_asked_threshold_seconds: 最近追问阈值(秒)
|
||||
merge_prompts: 是否合并追问提示
|
||||
merge_template: 合并模板
|
||||
"""
|
||||
max_ask_back_slots_per_turn: int = 2
|
||||
prefer_required: bool = True
|
||||
prefer_scene_relevant: bool = True
|
||||
avoid_recent_asked: bool = True
|
||||
recent_asked_threshold_seconds: float = 60.0
|
||||
merge_prompts: bool = True
|
||||
merge_template: str = "为了更好地为您服务,请告诉我:{prompts}"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchAskBackResult:
|
||||
"""
|
||||
批量追问结果
|
||||
|
||||
Attributes:
|
||||
selected_slots: 选中的追问槽位列表
|
||||
prompts: 追问提示列表
|
||||
merged_prompt: 合并后的追问提示
|
||||
ask_back_count: 追问数量
|
||||
"""
|
||||
selected_slots: list[AskBackSlot] = field(default_factory=list)
|
||||
prompts: list[str] = field(default_factory=list)
|
||||
merged_prompt: str | None = None
|
||||
ask_back_count: int = 0
|
||||
|
||||
def has_ask_back(self) -> bool:
|
||||
return self.ask_back_count > 0
|
||||
|
||||
def get_prompt(self) -> str:
|
||||
"""获取最终追问提示"""
|
||||
if self.merged_prompt:
|
||||
return self.merged_prompt
|
||||
if self.prompts:
|
||||
return self.prompts[0]
|
||||
return "请提供更多信息以便我更好地帮助您。"
|
||||
|
||||
|
||||
class BatchAskBackService:
|
||||
"""
|
||||
[AC-MRS-SLOT-ASKBACK-01] 批量追问服务
|
||||
|
||||
支持一次追问多个缺失槽位,提高补全效率。
|
||||
"""
|
||||
|
||||
ASK_BACK_HISTORY_KEY_PREFIX = "slot_ask_back_history"
|
||||
ASK_BACK_HISTORY_TTL = 300
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
tenant_id: str,
|
||||
session_id: str,
|
||||
config: BatchAskBackConfig | None = None,
|
||||
):
|
||||
self._session = session
|
||||
self._tenant_id = tenant_id
|
||||
self._session_id = session_id
|
||||
self._config = config or BatchAskBackConfig()
|
||||
self._slot_def_service = SlotDefinitionService(session)
|
||||
self._cache = get_slot_state_cache()
|
||||
|
||||
async def generate_batch_ask_back(
|
||||
self,
|
||||
missing_slots: list[dict[str, str]],
|
||||
current_scene: str | None = None,
|
||||
) -> BatchAskBackResult:
|
||||
"""
|
||||
生成批量追问
|
||||
|
||||
Args:
|
||||
missing_slots: 缺失槽位列表
|
||||
current_scene: 当前场景
|
||||
|
||||
Returns:
|
||||
BatchAskBackResult: 批量追问结果
|
||||
"""
|
||||
if not missing_slots:
|
||||
return BatchAskBackResult()
|
||||
|
||||
ask_back_slots = await self._prepare_ask_back_slots(missing_slots, current_scene)
|
||||
|
||||
selected_slots = self._select_slots_for_ask_back(ask_back_slots)
|
||||
|
||||
asked_history = await self._get_asked_history()
|
||||
selected_slots = self._filter_recently_asked(selected_slots, asked_history)
|
||||
|
||||
if not selected_slots:
|
||||
selected_slots = ask_back_slots[:self._config.max_ask_back_slots_per_turn]
|
||||
|
||||
prompts = self._generate_prompts(selected_slots)
|
||||
merged_prompt = self._merge_prompts(prompts) if self._config.merge_prompts else None
|
||||
|
||||
await self._record_ask_back_history([s.slot_key for s in selected_slots])
|
||||
|
||||
return BatchAskBackResult(
|
||||
selected_slots=selected_slots,
|
||||
prompts=prompts,
|
||||
merged_prompt=merged_prompt,
|
||||
ask_back_count=len(selected_slots),
|
||||
)
|
||||
|
||||
async def _prepare_ask_back_slots(
|
||||
self,
|
||||
missing_slots: list[dict[str, str]],
|
||||
current_scene: str | None,
|
||||
) -> list[AskBackSlot]:
|
||||
"""准备追问槽位列表"""
|
||||
ask_back_slots = []
|
||||
|
||||
for missing in missing_slots:
|
||||
slot_key = missing.get("slot_key", "")
|
||||
label = missing.get("label", slot_key)
|
||||
ask_back_prompt = missing.get("ask_back_prompt")
|
||||
field_key = missing.get("field_key")
|
||||
|
||||
slot_def = await self._slot_def_service.get_slot_definition_by_key(
|
||||
self._tenant_id, slot_key
|
||||
)
|
||||
|
||||
is_required = False
|
||||
scene_relevance = 0.0
|
||||
|
||||
if slot_def:
|
||||
is_required = slot_def.required
|
||||
if not ask_back_prompt:
|
||||
ask_back_prompt = slot_def.ask_back_prompt
|
||||
|
||||
if current_scene and slot_def.scene_scope:
|
||||
if current_scene in slot_def.scene_scope:
|
||||
scene_relevance = 1.0
|
||||
|
||||
priority = self._calculate_priority(is_required, scene_relevance)
|
||||
|
||||
ask_back_slots.append(AskBackSlot(
|
||||
slot_key=slot_key,
|
||||
label=label,
|
||||
ask_back_prompt=ask_back_prompt,
|
||||
priority=priority,
|
||||
is_required=is_required,
|
||||
scene_relevance=scene_relevance,
|
||||
))
|
||||
|
||||
return ask_back_slots
|
||||
|
||||
def _calculate_priority(self, is_required: bool, scene_relevance: float) -> int:
|
||||
"""计算槽位优先级"""
|
||||
priority = 0
|
||||
|
||||
if self._config.prefer_required and is_required:
|
||||
priority += 100
|
||||
|
||||
if self._config.prefer_scene_relevant:
|
||||
priority += int(scene_relevance * 50)
|
||||
|
||||
return priority
|
||||
|
||||
def _select_slots_for_ask_back(
|
||||
self,
|
||||
ask_back_slots: list[AskBackSlot],
|
||||
) -> list[AskBackSlot]:
|
||||
"""选择要追问的槽位"""
|
||||
sorted_slots = sorted(
|
||||
ask_back_slots,
|
||||
key=lambda s: s.priority,
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
return sorted_slots[:self._config.max_ask_back_slots_per_turn]
|
||||
|
||||
async def _get_asked_history(self) -> dict[str, float]:
|
||||
"""获取最近追问历史"""
|
||||
history_key = f"{self.ASK_BACK_HISTORY_KEY_PREFIX}:{self._tenant_id}:{self._session_id}"
|
||||
|
||||
try:
|
||||
client = await self._cache._get_client()
|
||||
if client is None:
|
||||
return {}
|
||||
|
||||
import json
|
||||
data = await client.get(history_key)
|
||||
if data:
|
||||
return json.loads(data)
|
||||
except Exception as e:
|
||||
logger.warning(f"[BatchAskBack] Failed to get asked history: {e}")
|
||||
|
||||
return {}
|
||||
|
||||
async def _record_ask_back_history(self, slot_keys: list[str]) -> None:
|
||||
"""记录追问历史"""
|
||||
history_key = f"{self.ASK_BACK_HISTORY_KEY_PREFIX}:{self._tenant_id}:{self._session_id}"
|
||||
|
||||
try:
|
||||
client = await self._cache._get_client()
|
||||
if client is None:
|
||||
return
|
||||
|
||||
history = await self._get_asked_history()
|
||||
current_time = time.time()
|
||||
|
||||
for slot_key in slot_keys:
|
||||
history[slot_key] = current_time
|
||||
|
||||
import json
|
||||
await client.setex(
|
||||
history_key,
|
||||
self.ASK_BACK_HISTORY_TTL,
|
||||
json.dumps(history),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"[BatchAskBack] Failed to record asked history: {e}")
|
||||
|
||||
def _filter_recently_asked(
|
||||
self,
|
||||
slots: list[AskBackSlot],
|
||||
asked_history: dict[str, float],
|
||||
) -> list[AskBackSlot]:
|
||||
"""过滤最近追问过的槽位"""
|
||||
if not self._config.avoid_recent_asked:
|
||||
return slots
|
||||
|
||||
current_time = time.time()
|
||||
threshold = self._config.recent_asked_threshold_seconds
|
||||
|
||||
return [
|
||||
slot for slot in slots
|
||||
if slot.slot_key not in asked_history or
|
||||
current_time - asked_history[slot.slot_key] > threshold
|
||||
]
|
||||
|
||||
def _generate_prompts(self, slots: list[AskBackSlot]) -> list[str]:
|
||||
"""生成追问提示列表"""
|
||||
prompts = []
|
||||
|
||||
for slot in slots:
|
||||
if slot.ask_back_prompt:
|
||||
prompts.append(slot.ask_back_prompt)
|
||||
else:
|
||||
prompts.append(f"请告诉我您的{slot.label}")
|
||||
|
||||
return prompts
|
||||
|
||||
def _merge_prompts(self, prompts: list[str]) -> str | None:
|
||||
"""合并追问提示"""
|
||||
if not prompts:
|
||||
return None
|
||||
|
||||
if len(prompts) == 1:
|
||||
return prompts[0]
|
||||
|
||||
if len(prompts) == 2:
|
||||
return f"{prompts[0]},以及{prompts[1]}"
|
||||
|
||||
all_but_last = "、".join(prompts[:-1])
|
||||
return f"{all_but_last},以及{prompts[-1]}"
|
||||
|
||||
|
||||
def create_batch_ask_back_service(
|
||||
session: AsyncSession,
|
||||
tenant_id: str,
|
||||
session_id: str,
|
||||
config: BatchAskBackConfig | None = None,
|
||||
) -> BatchAskBackService:
|
||||
"""
|
||||
创建批量追问服务实例
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
tenant_id: 租户 ID
|
||||
session_id: 会话 ID
|
||||
config: 配置
|
||||
|
||||
Returns:
|
||||
BatchAskBackService: 批量追问服务实例
|
||||
"""
|
||||
return BatchAskBackService(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
session_id=session_id,
|
||||
config=config,
|
||||
)
|
||||
|
|
@ -0,0 +1,423 @@
|
|||
"""
|
||||
Scene Slot Bundle Loader Service.
|
||||
[AC-SCENE-SLOT-02] 运行时场景槽位包加载器
|
||||
[AC-SCENE-SLOT-03] 支持缓存层
|
||||
|
||||
职责:
|
||||
1. 根据场景标识加载槽位包配置
|
||||
2. 聚合槽位定义详情
|
||||
3. 计算缺失槽位
|
||||
4. 生成追问提示
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.entities import SceneSlotBundleStatus
|
||||
from app.services.scene_slot_bundle_service import SceneSlotBundleService
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
from app.services.cache.scene_slot_bundle_cache import (
|
||||
CachedSceneSlotBundle,
|
||||
get_scene_slot_bundle_cache,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlotInfo:
|
||||
"""槽位信息"""
|
||||
slot_key: str
|
||||
type: str
|
||||
required: bool
|
||||
ask_back_prompt: str | None = None
|
||||
validation_rule: str | None = None
|
||||
linked_field_id: str | None = None
|
||||
default_value: Any = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SceneSlotContext:
|
||||
"""
|
||||
[AC-SCENE-SLOT-02] 场景槽位上下文
|
||||
|
||||
运行时使用的场景槽位包信息
|
||||
"""
|
||||
scene_key: str
|
||||
scene_name: str
|
||||
required_slots: list[SlotInfo] = field(default_factory=list)
|
||||
optional_slots: list[SlotInfo] = field(default_factory=list)
|
||||
slot_priority: list[str] = field(default_factory=list)
|
||||
completion_threshold: float = 1.0
|
||||
ask_back_order: str = "priority"
|
||||
|
||||
def get_all_slot_keys(self) -> list[str]:
|
||||
"""获取所有槽位键名"""
|
||||
return [s.slot_key for s in self.required_slots] + [s.slot_key for s in self.optional_slots]
|
||||
|
||||
def get_required_slot_keys(self) -> list[str]:
|
||||
"""获取必填槽位键名"""
|
||||
return [s.slot_key for s in self.required_slots]
|
||||
|
||||
def get_optional_slot_keys(self) -> list[str]:
|
||||
"""获取可选槽位键名"""
|
||||
return [s.slot_key for s in self.optional_slots]
|
||||
|
||||
def get_missing_slots(
|
||||
self,
|
||||
filled_slots: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取缺失的必填槽位信息
|
||||
|
||||
Args:
|
||||
filled_slots: 已填充的槽位值
|
||||
|
||||
Returns:
|
||||
缺失槽位信息列表
|
||||
"""
|
||||
missing = []
|
||||
|
||||
for slot_info in self.required_slots:
|
||||
if slot_info.slot_key not in filled_slots:
|
||||
missing.append({
|
||||
"slot_key": slot_info.slot_key,
|
||||
"type": slot_info.type,
|
||||
"required": True,
|
||||
"ask_back_prompt": slot_info.ask_back_prompt,
|
||||
"validation_rule": slot_info.validation_rule,
|
||||
"linked_field_id": slot_info.linked_field_id,
|
||||
})
|
||||
|
||||
return missing
|
||||
|
||||
def get_ordered_missing_slots(
|
||||
self,
|
||||
filled_slots: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
按优先级顺序获取缺失的必填槽位
|
||||
|
||||
Args:
|
||||
filled_slots: 已填充的槽位值
|
||||
|
||||
Returns:
|
||||
按优先级排序的缺失槽位信息列表
|
||||
"""
|
||||
missing = self.get_missing_slots(filled_slots)
|
||||
|
||||
if not missing:
|
||||
return []
|
||||
|
||||
if self.ask_back_order == "required_first":
|
||||
return missing
|
||||
|
||||
if self.ask_back_order == "priority" and self.slot_priority:
|
||||
priority_map = {slot_key: idx for idx, slot_key in enumerate(self.slot_priority)}
|
||||
missing.sort(key=lambda x: priority_map.get(x["slot_key"], 999))
|
||||
|
||||
return missing
|
||||
|
||||
def get_completion_ratio(
|
||||
self,
|
||||
filled_slots: dict[str, Any],
|
||||
) -> float:
|
||||
"""
|
||||
计算完成比例
|
||||
|
||||
Args:
|
||||
filled_slots: 已填充的槽位值
|
||||
|
||||
Returns:
|
||||
完成比例 (0.0 - 1.0)
|
||||
"""
|
||||
if not self.required_slots:
|
||||
return 1.0
|
||||
|
||||
filled_count = sum(
|
||||
1 for slot_info in self.required_slots
|
||||
if slot_info.slot_key in filled_slots
|
||||
)
|
||||
|
||||
return filled_count / len(self.required_slots)
|
||||
|
||||
def is_complete(
|
||||
self,
|
||||
filled_slots: dict[str, Any],
|
||||
) -> bool:
|
||||
"""
|
||||
检查是否完成
|
||||
|
||||
Args:
|
||||
filled_slots: 已填充的槽位值
|
||||
|
||||
Returns:
|
||||
是否达到完成阈值
|
||||
"""
|
||||
return self.get_completion_ratio(filled_slots) >= self.completion_threshold
|
||||
|
||||
|
||||
class SceneSlotBundleLoader:
|
||||
"""
|
||||
[AC-SCENE-SLOT-02] 场景槽位包加载器
|
||||
[AC-SCENE-SLOT-03] 支持缓存层
|
||||
|
||||
运行时加载场景槽位包配置
|
||||
"""
|
||||
|
||||
def __init__(self, session: AsyncSession, use_cache: bool = True):
|
||||
self._session = session
|
||||
self._bundle_service = SceneSlotBundleService(session)
|
||||
self._slot_service = SlotDefinitionService(session)
|
||||
self._use_cache = use_cache
|
||||
self._cache = get_scene_slot_bundle_cache() if use_cache else None
|
||||
|
||||
async def load_scene_context(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
) -> SceneSlotContext | None:
|
||||
"""
|
||||
加载场景槽位上下文
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
|
||||
Returns:
|
||||
场景槽位上下文或 None
|
||||
"""
|
||||
# [AC-SCENE-SLOT-03] 尝试从缓存获取
|
||||
if self._use_cache and self._cache:
|
||||
cached = await self._cache.get(tenant_id, scene_key)
|
||||
if cached and cached.status == SceneSlotBundleStatus.ACTIVE.value:
|
||||
logger.debug(
|
||||
f"[AC-SCENE-SLOT-03] Cache hit for scene: {scene_key}"
|
||||
)
|
||||
return await self._build_context_from_cached(tenant_id, cached)
|
||||
|
||||
bundle = await self._bundle_service.get_active_bundle_by_scene(
|
||||
tenant_id=tenant_id,
|
||||
scene_key=scene_key,
|
||||
)
|
||||
|
||||
if not bundle:
|
||||
logger.debug(
|
||||
f"[AC-SCENE-SLOT-02] No active bundle found for scene: {scene_key}"
|
||||
)
|
||||
return None
|
||||
|
||||
# [AC-SCENE-SLOT-03] 写入缓存
|
||||
if self._use_cache and self._cache:
|
||||
cached_bundle = CachedSceneSlotBundle(
|
||||
scene_key=bundle.scene_key,
|
||||
scene_name=bundle.scene_name,
|
||||
description=bundle.description,
|
||||
required_slots=bundle.required_slots,
|
||||
optional_slots=bundle.optional_slots,
|
||||
slot_priority=bundle.slot_priority,
|
||||
completion_threshold=bundle.completion_threshold,
|
||||
ask_back_order=bundle.ask_back_order,
|
||||
status=bundle.status,
|
||||
version=bundle.version,
|
||||
)
|
||||
await self._cache.set(tenant_id, scene_key, cached_bundle)
|
||||
|
||||
return await self._build_context_from_bundle(tenant_id, bundle)
|
||||
|
||||
async def _build_context_from_cached(
|
||||
self,
|
||||
tenant_id: str,
|
||||
cached: CachedSceneSlotBundle,
|
||||
) -> SceneSlotContext:
|
||||
"""从缓存构建场景槽位上下文"""
|
||||
all_slots = await self._slot_service.list_slot_definitions(tenant_id)
|
||||
slot_map = {slot.slot_key: slot for slot in all_slots}
|
||||
|
||||
return self._build_context(cached, slot_map)
|
||||
|
||||
async def _build_context_from_bundle(
|
||||
self,
|
||||
tenant_id: str,
|
||||
bundle: Any,
|
||||
) -> SceneSlotContext:
|
||||
"""从数据库模型构建场景槽位上下文"""
|
||||
all_slots = await self._slot_service.list_slot_definitions(tenant_id)
|
||||
slot_map = {slot.slot_key: slot for slot in all_slots}
|
||||
|
||||
cached = CachedSceneSlotBundle(
|
||||
scene_key=bundle.scene_key,
|
||||
scene_name=bundle.scene_name,
|
||||
description=bundle.description,
|
||||
required_slots=bundle.required_slots,
|
||||
optional_slots=bundle.optional_slots,
|
||||
slot_priority=bundle.slot_priority,
|
||||
completion_threshold=bundle.completion_threshold,
|
||||
ask_back_order=bundle.ask_back_order,
|
||||
status=bundle.status,
|
||||
version=bundle.version,
|
||||
)
|
||||
|
||||
return self._build_context(cached, slot_map)
|
||||
|
||||
def _build_context(
|
||||
self,
|
||||
cached: CachedSceneSlotBundle,
|
||||
slot_map: dict[str, Any],
|
||||
) -> SceneSlotContext:
|
||||
"""构建场景槽位上下文"""
|
||||
required_slot_infos = []
|
||||
for slot_key in cached.required_slots:
|
||||
if slot_key in slot_map:
|
||||
slot_def = slot_map[slot_key]
|
||||
required_slot_infos.append(SlotInfo(
|
||||
slot_key=slot_def.slot_key,
|
||||
type=slot_def.type,
|
||||
required=True,
|
||||
ask_back_prompt=slot_def.ask_back_prompt,
|
||||
validation_rule=slot_def.validation_rule,
|
||||
linked_field_id=str(slot_def.linked_field_id) if slot_def.linked_field_id else None,
|
||||
default_value=slot_def.default_value,
|
||||
))
|
||||
else:
|
||||
logger.warning(
|
||||
f"[AC-SCENE-SLOT-02] Required slot not found: {slot_key}"
|
||||
)
|
||||
|
||||
optional_slot_infos = []
|
||||
for slot_key in cached.optional_slots:
|
||||
if slot_key in slot_map:
|
||||
slot_def = slot_map[slot_key]
|
||||
optional_slot_infos.append(SlotInfo(
|
||||
slot_key=slot_def.slot_key,
|
||||
type=slot_def.type,
|
||||
required=slot_def.required,
|
||||
ask_back_prompt=slot_def.ask_back_prompt,
|
||||
validation_rule=slot_def.validation_rule,
|
||||
linked_field_id=str(slot_def.linked_field_id) if slot_def.linked_field_id else None,
|
||||
default_value=slot_def.default_value,
|
||||
))
|
||||
else:
|
||||
logger.warning(
|
||||
f"[AC-SCENE-SLOT-02] Optional slot not found: {slot_key}"
|
||||
)
|
||||
|
||||
context = SceneSlotContext(
|
||||
scene_key=cached.scene_key,
|
||||
scene_name=cached.scene_name,
|
||||
required_slots=required_slot_infos,
|
||||
optional_slots=optional_slot_infos,
|
||||
slot_priority=cached.slot_priority or [],
|
||||
completion_threshold=cached.completion_threshold,
|
||||
ask_back_order=cached.ask_back_order,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"[AC-SCENE-SLOT-02] Loaded scene context: scene={cached.scene_key}, "
|
||||
f"required={len(required_slot_infos)}, optional={len(optional_slot_infos)}, "
|
||||
f"threshold={cached.completion_threshold}"
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
async def invalidate_cache(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
) -> bool:
|
||||
"""
|
||||
[AC-SCENE-SLOT-03] 使缓存失效
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if self._cache:
|
||||
return await self._cache.invalidate_on_update(tenant_id, scene_key)
|
||||
return True
|
||||
|
||||
async def get_missing_slots_for_scene(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
filled_slots: dict[str, Any],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
获取场景缺失的必填槽位
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
filled_slots: 已填充的槽位值
|
||||
|
||||
Returns:
|
||||
缺失槽位信息列表
|
||||
"""
|
||||
context = await self.load_scene_context(tenant_id, scene_key)
|
||||
|
||||
if not context:
|
||||
return []
|
||||
|
||||
return context.get_ordered_missing_slots(filled_slots)
|
||||
|
||||
async def generate_ask_back_prompt(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
filled_slots: dict[str, Any],
|
||||
max_slots: int = 2,
|
||||
) -> str | None:
|
||||
"""
|
||||
生成追问提示
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
filled_slots: 已填充的槽位值
|
||||
max_slots: 最多追问的槽位数量
|
||||
|
||||
Returns:
|
||||
追问提示或 None
|
||||
"""
|
||||
missing_slots = await self.get_missing_slots_for_scene(
|
||||
tenant_id=tenant_id,
|
||||
scene_key=scene_key,
|
||||
filled_slots=filled_slots,
|
||||
)
|
||||
|
||||
if not missing_slots:
|
||||
return None
|
||||
|
||||
context = await self.load_scene_context(tenant_id, scene_key)
|
||||
ask_back_order = context.ask_back_order if context else "priority"
|
||||
|
||||
if ask_back_order == "parallel":
|
||||
prompts = []
|
||||
for missing in missing_slots[:max_slots]:
|
||||
if missing.get("ask_back_prompt"):
|
||||
prompts.append(missing["ask_back_prompt"])
|
||||
else:
|
||||
slot_key = missing.get("slot_key", "相关信息")
|
||||
prompts.append(f"您的{slot_key}")
|
||||
|
||||
if len(prompts) == 1:
|
||||
return prompts[0]
|
||||
elif len(prompts) == 2:
|
||||
return f"为了更好地为您服务,请告诉我{prompts[0]}和{prompts[1]}。"
|
||||
else:
|
||||
all_but_last = "、".join(prompts[:-1])
|
||||
return f"为了更好地为您服务,请告诉我{all_but_last},以及{prompts[-1]}。"
|
||||
else:
|
||||
first_missing = missing_slots[0]
|
||||
ask_back_prompt = first_missing.get("ask_back_prompt")
|
||||
if ask_back_prompt:
|
||||
return ask_back_prompt
|
||||
|
||||
slot_key = first_missing.get("slot_key", "相关信息")
|
||||
return f"为了更好地为您提供帮助,请告诉我您的{slot_key}。"
|
||||
|
|
@ -0,0 +1,334 @@
|
|||
"""
|
||||
Scene Slot Bundle Metrics Service.
|
||||
[AC-SCENE-SLOT-04] 场景槽位包监控指标
|
||||
|
||||
职责:
|
||||
1. 收集场景槽位包相关的监控指标
|
||||
2. 提供告警检测接口
|
||||
3. 支持指标导出
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from collections import defaultdict
|
||||
import threading
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SceneSlotMetricPoint:
|
||||
"""单个指标数据点"""
|
||||
timestamp: datetime
|
||||
tenant_id: str
|
||||
scene_key: str
|
||||
metric_name: str
|
||||
value: float
|
||||
tags: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SceneSlotMetricsSummary:
|
||||
"""指标汇总"""
|
||||
total_requests: int = 0
|
||||
cache_hits: int = 0
|
||||
cache_misses: int = 0
|
||||
missing_slots_triggered: int = 0
|
||||
ask_back_triggered: int = 0
|
||||
scene_not_configured: int = 0
|
||||
slot_not_found: int = 0
|
||||
avg_completion_ratio: float = 0.0
|
||||
|
||||
@property
|
||||
def cache_hit_rate(self) -> float:
|
||||
if self.total_requests == 0:
|
||||
return 0.0
|
||||
return self.cache_hits / self.total_requests
|
||||
|
||||
|
||||
class SceneSlotMetricsCollector:
|
||||
"""
|
||||
[AC-SCENE-SLOT-04] 场景槽位包指标收集器
|
||||
|
||||
收集以下指标:
|
||||
- scene_slot_requests_total: 场景槽位请求总数
|
||||
- scene_slot_cache_hits: 缓存命中次数
|
||||
- scene_slot_cache_misses: 缓存未命中次数
|
||||
- scene_slot_missing_triggered: 缺失槽位触发次数
|
||||
- scene_slot_ask_back_triggered: 追问触发次数
|
||||
- scene_slot_not_configured: 场景未配置次数
|
||||
- scene_slot_not_found: 槽位未找到次数
|
||||
- scene_slot_completion_ratio: 槽位完成比例
|
||||
"""
|
||||
|
||||
def __init__(self, max_points: int = 10000):
|
||||
self._max_points = max_points
|
||||
self._points: list[SceneSlotMetricPoint] = []
|
||||
self._counters: dict[str, int] = defaultdict(int)
|
||||
self._lock = threading.Lock()
|
||||
self._start_time = datetime.utcnow()
|
||||
|
||||
def record(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
metric_name: str,
|
||||
value: float = 1.0,
|
||||
tags: dict[str, str] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
记录指标
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
metric_name: 指标名称
|
||||
value: 指标值
|
||||
tags: 额外标签
|
||||
"""
|
||||
point = SceneSlotMetricPoint(
|
||||
timestamp=datetime.utcnow(),
|
||||
tenant_id=tenant_id,
|
||||
scene_key=scene_key,
|
||||
metric_name=metric_name,
|
||||
value=value,
|
||||
tags=tags or {},
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
self._points.append(point)
|
||||
if len(self._points) > self._max_points:
|
||||
self._points = self._points[-self._max_points:]
|
||||
|
||||
counter_key = f"{tenant_id}:{scene_key}:{metric_name}"
|
||||
self._counters[counter_key] += 1
|
||||
|
||||
def record_cache_hit(self, tenant_id: str, scene_key: str) -> None:
|
||||
"""记录缓存命中"""
|
||||
self.record(tenant_id, scene_key, "cache_hit")
|
||||
|
||||
def record_cache_miss(self, tenant_id: str, scene_key: str) -> None:
|
||||
"""记录缓存未命中"""
|
||||
self.record(tenant_id, scene_key, "cache_miss")
|
||||
|
||||
def record_missing_slots(self, tenant_id: str, scene_key: str, count: int = 1) -> None:
|
||||
"""记录缺失槽位触发"""
|
||||
self.record(tenant_id, scene_key, "missing_slots_triggered", float(count))
|
||||
|
||||
def record_ask_back(self, tenant_id: str, scene_key: str) -> None:
|
||||
"""记录追问触发"""
|
||||
self.record(tenant_id, scene_key, "ask_back_triggered")
|
||||
|
||||
def record_scene_not_configured(self, tenant_id: str, scene_key: str) -> None:
|
||||
"""记录场景未配置"""
|
||||
self.record(tenant_id, scene_key, "scene_not_configured")
|
||||
|
||||
def record_slot_not_found(self, tenant_id: str, scene_key: str, slot_key: str) -> None:
|
||||
"""记录槽位未找到"""
|
||||
self.record(tenant_id, scene_key, "slot_not_found", tags={"slot_key": slot_key})
|
||||
|
||||
def record_completion_ratio(self, tenant_id: str, scene_key: str, ratio: float) -> None:
|
||||
"""记录槽位完成比例"""
|
||||
self.record(tenant_id, scene_key, "completion_ratio", ratio)
|
||||
|
||||
def get_summary(self, tenant_id: str | None = None) -> SceneSlotMetricsSummary:
|
||||
"""
|
||||
获取指标汇总
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID(可选,为 None 时返回所有租户的汇总)
|
||||
|
||||
Returns:
|
||||
指标汇总
|
||||
"""
|
||||
summary = SceneSlotMetricsSummary()
|
||||
|
||||
with self._lock:
|
||||
points = self._points.copy()
|
||||
|
||||
for point in points:
|
||||
if tenant_id and point.tenant_id != tenant_id:
|
||||
continue
|
||||
|
||||
summary.total_requests += 1
|
||||
|
||||
if point.metric_name == "cache_hit":
|
||||
summary.cache_hits += 1
|
||||
elif point.metric_name == "cache_miss":
|
||||
summary.cache_misses += 1
|
||||
elif point.metric_name == "missing_slots_triggered":
|
||||
summary.missing_slots_triggered += int(point.value)
|
||||
elif point.metric_name == "ask_back_triggered":
|
||||
summary.ask_back_triggered += 1
|
||||
elif point.metric_name == "scene_not_configured":
|
||||
summary.scene_not_configured += 1
|
||||
elif point.metric_name == "slot_not_found":
|
||||
summary.slot_not_found += 1
|
||||
elif point.metric_name == "completion_ratio":
|
||||
summary.avg_completion_ratio = (
|
||||
summary.avg_completion_ratio * summary.total_requests + point.value
|
||||
) / (summary.total_requests + 1) if summary.total_requests > 0 else point.value
|
||||
|
||||
return summary
|
||||
|
||||
def get_metrics_by_scene(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
获取特定场景的指标
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
|
||||
Returns:
|
||||
指标字典
|
||||
"""
|
||||
metrics = {
|
||||
"requests": 0,
|
||||
"cache_hits": 0,
|
||||
"cache_misses": 0,
|
||||
"missing_slots_triggered": 0,
|
||||
"ask_back_triggered": 0,
|
||||
"slot_not_found": 0,
|
||||
"avg_completion_ratio": 0.0,
|
||||
}
|
||||
|
||||
with self._lock:
|
||||
points = [p for p in self._points if p.tenant_id == tenant_id and p.scene_key == scene_key]
|
||||
|
||||
if not points:
|
||||
return metrics
|
||||
|
||||
completion_ratios = []
|
||||
|
||||
for point in points:
|
||||
metrics["requests"] += 1
|
||||
if point.metric_name in metrics:
|
||||
if point.metric_name == "completion_ratio":
|
||||
completion_ratios.append(point.value)
|
||||
else:
|
||||
metrics[point.metric_name] += int(point.value)
|
||||
|
||||
if completion_ratios:
|
||||
metrics["avg_completion_ratio"] = sum(completion_ratios) / len(completion_ratios)
|
||||
|
||||
return metrics
|
||||
|
||||
def check_alerts(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
thresholds: dict[str, float] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""
|
||||
检查告警条件
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
thresholds: 告警阈值配置
|
||||
|
||||
Returns:
|
||||
告警列表
|
||||
"""
|
||||
default_thresholds = {
|
||||
"cache_hit_rate_low": 0.5,
|
||||
"missing_slots_rate_high": 0.3,
|
||||
"scene_not_configured_rate_high": 0.1,
|
||||
}
|
||||
|
||||
effective_thresholds = {**default_thresholds, **(thresholds or {})}
|
||||
|
||||
metrics = self.get_metrics_by_scene(tenant_id, scene_key)
|
||||
alerts = []
|
||||
|
||||
if metrics["requests"] > 0:
|
||||
cache_hit_rate = metrics["cache_hits"] / metrics["requests"]
|
||||
if cache_hit_rate < effective_thresholds["cache_hit_rate_low"]:
|
||||
alerts.append({
|
||||
"alert_type": "cache_hit_rate_low",
|
||||
"severity": "warning",
|
||||
"message": f"场景 {scene_key} 的缓存命中率 ({cache_hit_rate:.2%}) 低于阈值 ({effective_thresholds['cache_hit_rate_low']:.0%})",
|
||||
"current_value": cache_hit_rate,
|
||||
"threshold": effective_thresholds["cache_hit_rate_low"],
|
||||
"suggestion": "检查场景槽位包配置是否频繁变更,或增加缓存 TTL",
|
||||
})
|
||||
|
||||
missing_slots_rate = metrics["missing_slots_triggered"] / metrics["requests"]
|
||||
if missing_slots_rate > effective_thresholds["missing_slots_rate_high"]:
|
||||
alerts.append({
|
||||
"alert_type": "missing_slots_rate_high",
|
||||
"severity": "warning",
|
||||
"message": f"场景 {scene_key} 的缺失槽位触发率 ({missing_slots_rate:.2%}) 高于阈值 ({effective_thresholds['missing_slots_rate_high']:.0%})",
|
||||
"current_value": missing_slots_rate,
|
||||
"threshold": effective_thresholds["missing_slots_rate_high"],
|
||||
"suggestion": "检查槽位配置是否合理,或优化槽位提取策略",
|
||||
})
|
||||
|
||||
scene_not_configured_rate = metrics.get("scene_not_configured", 0) / metrics["requests"]
|
||||
if scene_not_configured_rate > effective_thresholds["scene_not_configured_rate_high"]:
|
||||
alerts.append({
|
||||
"alert_type": "scene_not_configured_rate_high",
|
||||
"severity": "error",
|
||||
"message": f"场景 {scene_key} 未配置率 ({scene_not_configured_rate:.2%}) 高于阈值 ({effective_thresholds['scene_not_configured_rate_high']:.0%})",
|
||||
"current_value": scene_not_configured_rate,
|
||||
"threshold": effective_thresholds["scene_not_configured_rate_high"],
|
||||
"suggestion": "请为该场景创建场景槽位包配置",
|
||||
})
|
||||
|
||||
return alerts
|
||||
|
||||
def export_metrics(self, tenant_id: str | None = None) -> dict[str, Any]:
|
||||
"""
|
||||
导出指标数据
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID(可选)
|
||||
|
||||
Returns:
|
||||
指标数据字典
|
||||
"""
|
||||
with self._lock:
|
||||
points = [
|
||||
{
|
||||
"timestamp": p.timestamp.isoformat(),
|
||||
"tenant_id": p.tenant_id,
|
||||
"scene_key": p.scene_key,
|
||||
"metric_name": p.metric_name,
|
||||
"value": p.value,
|
||||
"tags": p.tags,
|
||||
}
|
||||
for p in self._points
|
||||
if tenant_id is None or p.tenant_id == tenant_id
|
||||
]
|
||||
|
||||
return {
|
||||
"start_time": self._start_time.isoformat(),
|
||||
"end_time": datetime.utcnow().isoformat(),
|
||||
"total_points": len(points),
|
||||
"points": points,
|
||||
"summary": self.get_summary(tenant_id).__dict__,
|
||||
}
|
||||
|
||||
def reset(self) -> None:
|
||||
"""重置指标收集器"""
|
||||
with self._lock:
|
||||
self._points = []
|
||||
self._counters = defaultdict(int)
|
||||
self._start_time = datetime.utcnow()
|
||||
|
||||
|
||||
_metrics_collector: SceneSlotMetricsCollector | None = None
|
||||
|
||||
|
||||
def get_scene_slot_metrics_collector() -> SceneSlotMetricsCollector:
|
||||
"""获取场景槽位指标收集器实例"""
|
||||
global _metrics_collector
|
||||
if _metrics_collector is None:
|
||||
_metrics_collector = SceneSlotMetricsCollector()
|
||||
return _metrics_collector
|
||||
|
|
@ -0,0 +1,500 @@
|
|||
"""
|
||||
Slot Backfill Service.
|
||||
槽位回填服务 - 处理槽位值的提取、校验、确认、写回
|
||||
|
||||
[AC-MRS-SLOT-BACKFILL-01] 槽位值回填确认
|
||||
|
||||
职责:
|
||||
1. 从用户回复提取候选槽位值
|
||||
2. 调用 SlotManager 校验并归一化
|
||||
3. 校验失败返回 ask_back_prompt 二次追问
|
||||
4. 校验通过写入状态并标记 source/confidence
|
||||
5. 对低置信度值增加确认话术
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.mid.schemas import SlotSource
|
||||
from app.services.mid.slot_manager import SlotManager, SlotWriteResult
|
||||
from app.services.mid.slot_state_aggregator import SlotStateAggregator, SlotState
|
||||
from app.services.mid.slot_strategy_executor import (
|
||||
ExtractContext,
|
||||
SlotStrategyExecutor,
|
||||
StrategyChainResult,
|
||||
)
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BackfillStatus(str, Enum):
|
||||
"""回填状态"""
|
||||
SUCCESS = "success"
|
||||
VALIDATION_FAILED = "validation_failed"
|
||||
EXTRACTION_FAILED = "extraction_failed"
|
||||
NEEDS_CONFIRMATION = "needs_confirmation"
|
||||
NO_CANDIDATES = "no_candidates"
|
||||
|
||||
|
||||
@dataclass
|
||||
class BackfillResult:
|
||||
"""
|
||||
回填结果
|
||||
|
||||
Attributes:
|
||||
status: 回填状态
|
||||
slot_key: 槽位键名
|
||||
value: 最终值(校验通过后)
|
||||
normalized_value: 归一化后的值
|
||||
source: 值来源
|
||||
confidence: 置信度
|
||||
error_message: 错误信息
|
||||
ask_back_prompt: 追问提示
|
||||
confirmation_prompt: 确认提示(低置信度时)
|
||||
updated_state: 更新后的槽位状态
|
||||
"""
|
||||
status: BackfillStatus
|
||||
slot_key: str | None = None
|
||||
value: Any = None
|
||||
normalized_value: Any = None
|
||||
source: str = "unknown"
|
||||
confidence: float = 0.0
|
||||
error_message: str | None = None
|
||||
ask_back_prompt: str | None = None
|
||||
confirmation_prompt: str | None = None
|
||||
updated_state: SlotState | None = None
|
||||
|
||||
def is_success(self) -> bool:
|
||||
return self.status == BackfillStatus.SUCCESS
|
||||
|
||||
def needs_ask_back(self) -> bool:
|
||||
return self.status in (
|
||||
BackfillStatus.VALIDATION_FAILED,
|
||||
BackfillStatus.EXTRACTION_FAILED,
|
||||
)
|
||||
|
||||
def needs_confirmation(self) -> bool:
|
||||
return self.status == BackfillStatus.NEEDS_CONFIRMATION
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"status": self.status.value,
|
||||
"slot_key": self.slot_key,
|
||||
"value": self.value,
|
||||
"normalized_value": self.normalized_value,
|
||||
"source": self.source,
|
||||
"confidence": self.confidence,
|
||||
"error_message": self.error_message,
|
||||
"ask_back_prompt": self.ask_back_prompt,
|
||||
"confirmation_prompt": self.confirmation_prompt,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchBackfillResult:
|
||||
"""批量回填结果"""
|
||||
results: list[BackfillResult] = field(default_factory=list)
|
||||
success_count: int = 0
|
||||
failed_count: int = 0
|
||||
confirmation_needed_count: int = 0
|
||||
|
||||
def add_result(self, result: BackfillResult) -> None:
|
||||
self.results.append(result)
|
||||
if result.is_success():
|
||||
self.success_count += 1
|
||||
elif result.needs_confirmation():
|
||||
self.confirmation_needed_count += 1
|
||||
else:
|
||||
self.failed_count += 1
|
||||
|
||||
def get_ask_back_prompts(self) -> list[str]:
|
||||
"""获取所有追问提示"""
|
||||
return [
|
||||
r.ask_back_prompt
|
||||
for r in self.results
|
||||
if r.ask_back_prompt
|
||||
]
|
||||
|
||||
def get_confirmation_prompts(self) -> list[str]:
|
||||
"""获取所有确认提示"""
|
||||
return [
|
||||
r.confirmation_prompt
|
||||
for r in self.results
|
||||
if r.confirmation_prompt
|
||||
]
|
||||
|
||||
|
||||
class SlotBackfillService:
|
||||
"""
|
||||
[AC-MRS-SLOT-BACKFILL-01] 槽位回填服务
|
||||
|
||||
处理槽位值的提取、校验、确认、写回流程:
|
||||
1. 从用户回复提取候选槽位值
|
||||
2. SlotManager 校验并归一化
|
||||
3. 校验失败返回 ask_back_prompt 二次追问
|
||||
4. 校验通过写入状态并标记 source/confidence
|
||||
5. 对低置信度值增加确认话术
|
||||
"""
|
||||
|
||||
CONFIDENCE_THRESHOLD_LOW = 0.5
|
||||
CONFIDENCE_THRESHOLD_HIGH = 0.8
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
tenant_id: str,
|
||||
session_id: str | None = None,
|
||||
slot_manager: SlotManager | None = None,
|
||||
strategy_executor: SlotStrategyExecutor | None = None,
|
||||
):
|
||||
self._session = session
|
||||
self._tenant_id = tenant_id
|
||||
self._session_id = session_id
|
||||
self._slot_manager = slot_manager or SlotManager(session=session)
|
||||
self._strategy_executor = strategy_executor or SlotStrategyExecutor()
|
||||
self._slot_def_service = SlotDefinitionService(session)
|
||||
self._state_aggregator: SlotStateAggregator | None = None
|
||||
|
||||
async def _get_state_aggregator(self) -> SlotStateAggregator:
|
||||
"""获取状态聚合器"""
|
||||
if self._state_aggregator is None:
|
||||
self._state_aggregator = SlotStateAggregator(
|
||||
session=self._session,
|
||||
tenant_id=self._tenant_id,
|
||||
session_id=self._session_id,
|
||||
)
|
||||
return self._state_aggregator
|
||||
|
||||
async def backfill_single_slot(
|
||||
self,
|
||||
slot_key: str,
|
||||
candidate_value: Any,
|
||||
source: str = "user_confirmed",
|
||||
confidence: float = 1.0,
|
||||
strategies: list[str] | None = None,
|
||||
) -> BackfillResult:
|
||||
"""
|
||||
回填单个槽位
|
||||
|
||||
执行流程:
|
||||
1. 如果有提取策略,执行提取
|
||||
2. 校验候选值
|
||||
3. 根据校验结果决定下一步
|
||||
4. 写入状态
|
||||
|
||||
Args:
|
||||
slot_key: 槽位键名
|
||||
candidate_value: 候选值
|
||||
source: 值来源
|
||||
confidence: 初始置信度
|
||||
strategies: 提取策略链(可选)
|
||||
|
||||
Returns:
|
||||
BackfillResult: 回填结果
|
||||
"""
|
||||
final_value = candidate_value
|
||||
final_source = source
|
||||
final_confidence = confidence
|
||||
|
||||
if strategies:
|
||||
extracted_result = await self._extract_value(
|
||||
slot_key=slot_key,
|
||||
user_input=str(candidate_value),
|
||||
strategies=strategies,
|
||||
)
|
||||
|
||||
if extracted_result.success:
|
||||
final_value = extracted_result.final_value
|
||||
final_source = extracted_result.final_strategy or source
|
||||
final_confidence = self._get_confidence_for_strategy(final_source)
|
||||
else:
|
||||
ask_back_prompt = await self._slot_manager.get_ask_back_prompt(
|
||||
self._tenant_id, slot_key
|
||||
)
|
||||
return BackfillResult(
|
||||
status=BackfillStatus.EXTRACTION_FAILED,
|
||||
slot_key=slot_key,
|
||||
value=candidate_value,
|
||||
source=source,
|
||||
confidence=confidence,
|
||||
error_message=extracted_result.steps[-1].failure_reason if extracted_result.steps else "提取失败",
|
||||
ask_back_prompt=ask_back_prompt,
|
||||
)
|
||||
|
||||
write_result = await self._slot_manager.write_slot(
|
||||
tenant_id=self._tenant_id,
|
||||
slot_key=slot_key,
|
||||
value=final_value,
|
||||
source=SlotSource(final_source) if final_source in [s.value for s in SlotSource] else SlotSource.USER_CONFIRMED,
|
||||
confidence=final_confidence,
|
||||
)
|
||||
|
||||
if not write_result.success:
|
||||
return BackfillResult(
|
||||
status=BackfillStatus.VALIDATION_FAILED,
|
||||
slot_key=slot_key,
|
||||
value=final_value,
|
||||
source=final_source,
|
||||
confidence=final_confidence,
|
||||
error_message=write_result.error.error_message if write_result.error else "校验失败",
|
||||
ask_back_prompt=write_result.ask_back_prompt,
|
||||
)
|
||||
|
||||
normalized_value = write_result.value
|
||||
updated_state = None
|
||||
|
||||
if self._session_id:
|
||||
aggregator = await self._get_state_aggregator()
|
||||
updated_state = await aggregator.update_slot(
|
||||
slot_key=slot_key,
|
||||
value=normalized_value,
|
||||
source=final_source,
|
||||
confidence=final_confidence,
|
||||
)
|
||||
|
||||
result_status = BackfillStatus.SUCCESS
|
||||
confirmation_prompt = None
|
||||
|
||||
if final_confidence < self.CONFIDENCE_THRESHOLD_LOW:
|
||||
result_status = BackfillStatus.NEEDS_CONFIRMATION
|
||||
confirmation_prompt = self._generate_confirmation_prompt(
|
||||
slot_key, normalized_value
|
||||
)
|
||||
|
||||
return BackfillResult(
|
||||
status=result_status,
|
||||
slot_key=slot_key,
|
||||
value=final_value,
|
||||
normalized_value=normalized_value,
|
||||
source=final_source,
|
||||
confidence=final_confidence,
|
||||
confirmation_prompt=confirmation_prompt,
|
||||
updated_state=updated_state,
|
||||
)
|
||||
|
||||
async def backfill_multiple_slots(
|
||||
self,
|
||||
candidates: dict[str, Any],
|
||||
source: str = "user_confirmed",
|
||||
confidence: float = 1.0,
|
||||
) -> BatchBackfillResult:
|
||||
"""
|
||||
批量回填槽位
|
||||
|
||||
Args:
|
||||
candidates: 候选值字典 {slot_key: value}
|
||||
source: 值来源
|
||||
confidence: 初始置信度
|
||||
|
||||
Returns:
|
||||
BatchBackfillResult: 批量回填结果
|
||||
"""
|
||||
batch_result = BatchBackfillResult()
|
||||
|
||||
for slot_key, value in candidates.items():
|
||||
result = await self.backfill_single_slot(
|
||||
slot_key=slot_key,
|
||||
candidate_value=value,
|
||||
source=source,
|
||||
confidence=confidence,
|
||||
)
|
||||
batch_result.add_result(result)
|
||||
|
||||
return batch_result
|
||||
|
||||
async def backfill_from_user_response(
|
||||
self,
|
||||
user_response: str,
|
||||
expected_slots: list[str],
|
||||
strategies: list[str] | None = None,
|
||||
) -> BatchBackfillResult:
|
||||
"""
|
||||
从用户回复中提取并回填槽位
|
||||
|
||||
Args:
|
||||
user_response: 用户回复文本
|
||||
expected_slots: 期望提取的槽位列表
|
||||
strategies: 提取策略链
|
||||
|
||||
Returns:
|
||||
BatchBackfillResult: 批量回填结果
|
||||
"""
|
||||
batch_result = BatchBackfillResult()
|
||||
|
||||
for slot_key in expected_slots:
|
||||
slot_def = await self._slot_def_service.get_slot_definition_by_key(
|
||||
self._tenant_id, slot_key
|
||||
)
|
||||
|
||||
if not slot_def:
|
||||
continue
|
||||
|
||||
extract_strategies = strategies or ["rule", "llm"]
|
||||
|
||||
extracted_result = await self._extract_value(
|
||||
slot_key=slot_key,
|
||||
user_input=user_response,
|
||||
strategies=extract_strategies,
|
||||
slot_type=slot_def.type,
|
||||
validation_rule=slot_def.validation_rule,
|
||||
)
|
||||
|
||||
if not extracted_result.success:
|
||||
ask_back_prompt = slot_def.ask_back_prompt or f"请提供{slot_key}信息"
|
||||
batch_result.add_result(BackfillResult(
|
||||
status=BackfillStatus.EXTRACTION_FAILED,
|
||||
slot_key=slot_key,
|
||||
error_message="无法从回复中提取",
|
||||
ask_back_prompt=ask_back_prompt,
|
||||
))
|
||||
continue
|
||||
|
||||
source = self._get_source_for_strategy(extracted_result.final_strategy)
|
||||
confidence = self._get_confidence_for_strategy(source)
|
||||
|
||||
result = await self.backfill_single_slot(
|
||||
slot_key=slot_key,
|
||||
candidate_value=extracted_result.final_value,
|
||||
source=source,
|
||||
confidence=confidence,
|
||||
)
|
||||
batch_result.add_result(result)
|
||||
|
||||
return batch_result
|
||||
|
||||
async def _extract_value(
|
||||
self,
|
||||
slot_key: str,
|
||||
user_input: str,
|
||||
strategies: list[str],
|
||||
slot_type: str = "string",
|
||||
validation_rule: str | None = None,
|
||||
) -> StrategyChainResult:
|
||||
"""
|
||||
执行槽位值提取
|
||||
|
||||
Args:
|
||||
slot_key: 槽位键名
|
||||
user_input: 用户输入
|
||||
strategies: 提取策略链
|
||||
slot_type: 槽位类型
|
||||
validation_rule: 校验规则
|
||||
|
||||
Returns:
|
||||
StrategyChainResult: 提取结果
|
||||
"""
|
||||
context = ExtractContext(
|
||||
tenant_id=self._tenant_id,
|
||||
slot_key=slot_key,
|
||||
user_input=user_input,
|
||||
slot_type=slot_type,
|
||||
validation_rule=validation_rule,
|
||||
)
|
||||
|
||||
return await self._strategy_executor.execute_chain(
|
||||
strategies=strategies,
|
||||
context=context,
|
||||
)
|
||||
|
||||
def _get_source_for_strategy(self, strategy: str | None) -> str:
|
||||
"""根据策略获取来源"""
|
||||
strategy_source_map = {
|
||||
"rule": SlotSource.RULE_EXTRACTED.value,
|
||||
"llm": SlotSource.LLM_INFERRED.value,
|
||||
"user_input": SlotSource.USER_CONFIRMED.value,
|
||||
}
|
||||
return strategy_source_map.get(strategy or "", "unknown")
|
||||
|
||||
def _get_confidence_for_strategy(self, source: str) -> float:
|
||||
"""根据来源获取置信度"""
|
||||
confidence_map = {
|
||||
SlotSource.USER_CONFIRMED.value: 1.0,
|
||||
SlotSource.RULE_EXTRACTED.value: 0.9,
|
||||
SlotSource.LLM_INFERRED.value: 0.7,
|
||||
"context": 0.5,
|
||||
SlotSource.DEFAULT.value: 0.3,
|
||||
}
|
||||
return confidence_map.get(source, 0.5)
|
||||
|
||||
def _generate_confirmation_prompt(
|
||||
self,
|
||||
slot_key: str,
|
||||
value: Any,
|
||||
) -> str:
|
||||
"""生成确认提示"""
|
||||
return f"我理解您说的是「{value}」,对吗?"
|
||||
|
||||
async def confirm_low_confidence_slot(
|
||||
self,
|
||||
slot_key: str,
|
||||
confirmed: bool,
|
||||
) -> BackfillResult:
|
||||
"""
|
||||
确认低置信度槽位
|
||||
|
||||
Args:
|
||||
slot_key: 槽位键名
|
||||
confirmed: 用户是否确认
|
||||
|
||||
Returns:
|
||||
BackfillResult: 确认结果
|
||||
"""
|
||||
if not self._session_id:
|
||||
return BackfillResult(
|
||||
status=BackfillStatus.SUCCESS,
|
||||
slot_key=slot_key,
|
||||
)
|
||||
|
||||
aggregator = await self._get_state_aggregator()
|
||||
|
||||
if confirmed:
|
||||
updated_state = await aggregator.update_slot(
|
||||
slot_key=slot_key,
|
||||
source=SlotSource.USER_CONFIRMED.value,
|
||||
confidence=1.0,
|
||||
)
|
||||
return BackfillResult(
|
||||
status=BackfillStatus.SUCCESS,
|
||||
slot_key=slot_key,
|
||||
source=SlotSource.USER_CONFIRMED.value,
|
||||
confidence=1.0,
|
||||
updated_state=updated_state,
|
||||
)
|
||||
else:
|
||||
await aggregator.clear_slot(slot_key)
|
||||
ask_back_prompt = await self._slot_manager.get_ask_back_prompt(
|
||||
self._tenant_id, slot_key
|
||||
)
|
||||
return BackfillResult(
|
||||
status=BackfillStatus.VALIDATION_FAILED,
|
||||
slot_key=slot_key,
|
||||
ask_back_prompt=ask_back_prompt or f"请重新提供{slot_key}信息",
|
||||
)
|
||||
|
||||
|
||||
def create_slot_backfill_service(
|
||||
session: AsyncSession,
|
||||
tenant_id: str,
|
||||
session_id: str | None = None,
|
||||
) -> SlotBackfillService:
|
||||
"""
|
||||
创建槽位回填服务实例
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
tenant_id: 租户 ID
|
||||
session_id: 会话 ID
|
||||
|
||||
Returns:
|
||||
SlotBackfillService: 槽位回填服务实例
|
||||
"""
|
||||
return SlotBackfillService(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
|
@ -0,0 +1,368 @@
|
|||
"""
|
||||
Slot Extraction Integration Service.
|
||||
槽位提取集成服务 - 将自动提取能力接入主链路
|
||||
|
||||
[AC-MRS-SLOT-EXTRACT-01] slot extraction 集成
|
||||
|
||||
职责:
|
||||
1. 接入点:memory_recall 之后、KB 检索之前
|
||||
2. 执行策略链:rule -> llm -> user_input
|
||||
3. 抽取结果统一走 SlotManager 校验
|
||||
4. 提供 trace:extracted_slots、validation_pass/fail、ask_back_triggered
|
||||
"""
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.mid.schemas import SlotSource
|
||||
from app.services.mid.slot_backfill_service import (
|
||||
BackfillResult,
|
||||
BackfillStatus,
|
||||
SlotBackfillService,
|
||||
)
|
||||
from app.services.mid.slot_manager import SlotManager
|
||||
from app.services.mid.slot_state_aggregator import SlotState, SlotStateAggregator
|
||||
from app.services.mid.slot_strategy_executor import (
|
||||
ExtractContext,
|
||||
SlotStrategyExecutor,
|
||||
StrategyChainResult,
|
||||
)
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionTrace:
|
||||
"""
|
||||
提取追踪信息
|
||||
|
||||
Attributes:
|
||||
slot_key: 槽位键名
|
||||
strategy: 使用的策略
|
||||
extracted_value: 提取的值
|
||||
validation_passed: 校验是否通过
|
||||
final_value: 最终值(校验后)
|
||||
execution_time_ms: 执行时间
|
||||
failure_reason: 失败原因
|
||||
"""
|
||||
slot_key: str
|
||||
strategy: str | None = None
|
||||
extracted_value: Any = None
|
||||
validation_passed: bool = False
|
||||
final_value: Any = None
|
||||
execution_time_ms: float = 0.0
|
||||
failure_reason: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"slot_key": self.slot_key,
|
||||
"strategy": self.strategy,
|
||||
"extracted_value": self.extracted_value,
|
||||
"validation_passed": self.validation_passed,
|
||||
"final_value": self.final_value,
|
||||
"execution_time_ms": self.execution_time_ms,
|
||||
"failure_reason": self.failure_reason,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractionResult:
|
||||
"""
|
||||
提取结果
|
||||
|
||||
Attributes:
|
||||
success: 是否成功
|
||||
extracted_slots: 成功提取的槽位
|
||||
failed_slots: 提取失败的槽位
|
||||
traces: 提取追踪信息列表
|
||||
total_execution_time_ms: 总执行时间
|
||||
ask_back_triggered: 是否触发追问
|
||||
ask_back_prompts: 追问提示列表
|
||||
"""
|
||||
success: bool = False
|
||||
extracted_slots: dict[str, Any] = field(default_factory=dict)
|
||||
failed_slots: list[str] = field(default_factory=list)
|
||||
traces: list[ExtractionTrace] = field(default_factory=list)
|
||||
total_execution_time_ms: float = 0.0
|
||||
ask_back_triggered: bool = False
|
||||
ask_back_prompts: list[str] = field(default_factory=list)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"success": self.success,
|
||||
"extracted_slots": self.extracted_slots,
|
||||
"failed_slots": self.failed_slots,
|
||||
"traces": [t.to_dict() for t in self.traces],
|
||||
"total_execution_time_ms": self.total_execution_time_ms,
|
||||
"ask_back_triggered": self.ask_back_triggered,
|
||||
"ask_back_prompts": self.ask_back_prompts,
|
||||
}
|
||||
|
||||
|
||||
class SlotExtractionIntegration:
|
||||
"""
|
||||
[AC-MRS-SLOT-EXTRACT-01] 槽位提取集成服务
|
||||
|
||||
将自动提取能力接入主链路:
|
||||
- 接入点:memory_recall 之后、KB 检索之前
|
||||
- 执行策略链:rule -> llm -> user_input
|
||||
- 抽取结果统一走 SlotManager 校验
|
||||
- 提供 trace
|
||||
"""
|
||||
|
||||
DEFAULT_STRATEGIES = ["rule", "llm"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
tenant_id: str,
|
||||
session_id: str | None = None,
|
||||
slot_manager: SlotManager | None = None,
|
||||
strategy_executor: SlotStrategyExecutor | None = None,
|
||||
):
|
||||
self._session = session
|
||||
self._tenant_id = tenant_id
|
||||
self._session_id = session_id
|
||||
self._slot_manager = slot_manager or SlotManager(session=session)
|
||||
self._strategy_executor = strategy_executor or SlotStrategyExecutor()
|
||||
self._slot_def_service = SlotDefinitionService(session)
|
||||
self._backfill_service: SlotBackfillService | None = None
|
||||
|
||||
async def _get_backfill_service(self) -> SlotBackfillService:
|
||||
"""获取回填服务"""
|
||||
if self._backfill_service is None:
|
||||
self._backfill_service = SlotBackfillService(
|
||||
session=self._session,
|
||||
tenant_id=self._tenant_id,
|
||||
session_id=self._session_id,
|
||||
slot_manager=self._slot_manager,
|
||||
strategy_executor=self._strategy_executor,
|
||||
)
|
||||
return self._backfill_service
|
||||
|
||||
async def extract_and_fill(
|
||||
self,
|
||||
user_input: str,
|
||||
target_slots: list[str] | None = None,
|
||||
strategies: list[str] | None = None,
|
||||
slot_state: SlotState | None = None,
|
||||
) -> ExtractionResult:
|
||||
"""
|
||||
执行提取并填充槽位
|
||||
|
||||
Args:
|
||||
user_input: 用户输入
|
||||
target_slots: 目标槽位列表(为空则提取所有必填槽位)
|
||||
strategies: 提取策略链(默认 rule -> llm)
|
||||
slot_state: 当前槽位状态(用于识别缺失槽位)
|
||||
|
||||
Returns:
|
||||
ExtractionResult: 提取结果
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
strategies = strategies or self.DEFAULT_STRATEGIES
|
||||
|
||||
if target_slots is None:
|
||||
target_slots = await self._get_missing_required_slots(slot_state)
|
||||
|
||||
if not target_slots:
|
||||
return ExtractionResult(success=True)
|
||||
|
||||
result = ExtractionResult()
|
||||
|
||||
for slot_key in target_slots:
|
||||
trace = await self._extract_single_slot(
|
||||
slot_key=slot_key,
|
||||
user_input=user_input,
|
||||
strategies=strategies,
|
||||
)
|
||||
result.traces.append(trace)
|
||||
|
||||
if trace.validation_passed and trace.final_value is not None:
|
||||
result.extracted_slots[slot_key] = trace.final_value
|
||||
else:
|
||||
result.failed_slots.append(slot_key)
|
||||
if trace.failure_reason:
|
||||
ask_back_prompt = await self._slot_manager.get_ask_back_prompt(
|
||||
self._tenant_id, slot_key
|
||||
)
|
||||
if ask_back_prompt:
|
||||
result.ask_back_prompts.append(ask_back_prompt)
|
||||
|
||||
result.total_execution_time_ms = (time.time() - start_time) * 1000
|
||||
result.success = len(result.extracted_slots) > 0
|
||||
result.ask_back_triggered = len(result.ask_back_prompts) > 0
|
||||
|
||||
if result.extracted_slots and self._session_id:
|
||||
await self._save_extracted_slots(result.extracted_slots)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-EXTRACT-01] Extraction completed: "
|
||||
f"tenant={self._tenant_id}, extracted={len(result.extracted_slots)}, "
|
||||
f"failed={len(result.failed_slots)}, time_ms={result.total_execution_time_ms:.2f}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _extract_single_slot(
|
||||
self,
|
||||
slot_key: str,
|
||||
user_input: str,
|
||||
strategies: list[str],
|
||||
) -> ExtractionTrace:
|
||||
"""提取单个槽位"""
|
||||
start_time = time.time()
|
||||
trace = ExtractionTrace(slot_key=slot_key)
|
||||
|
||||
slot_def = await self._slot_def_service.get_slot_definition_by_key(
|
||||
self._tenant_id, slot_key
|
||||
)
|
||||
|
||||
if not slot_def:
|
||||
trace.failure_reason = "Slot definition not found"
|
||||
trace.execution_time_ms = (time.time() - start_time) * 1000
|
||||
return trace
|
||||
|
||||
context = ExtractContext(
|
||||
tenant_id=self._tenant_id,
|
||||
slot_key=slot_key,
|
||||
user_input=user_input,
|
||||
slot_type=slot_def.type,
|
||||
validation_rule=slot_def.validation_rule,
|
||||
session_id=self._session_id,
|
||||
)
|
||||
|
||||
chain_result = await self._strategy_executor.execute_chain(
|
||||
strategies=strategies,
|
||||
context=context,
|
||||
ask_back_prompt=slot_def.ask_back_prompt,
|
||||
)
|
||||
|
||||
trace.strategy = chain_result.final_strategy
|
||||
trace.extracted_value = chain_result.final_value
|
||||
trace.execution_time_ms = (time.time() - start_time) * 1000
|
||||
|
||||
if not chain_result.success:
|
||||
trace.failure_reason = "Extraction failed"
|
||||
if chain_result.steps:
|
||||
last_step = chain_result.steps[-1]
|
||||
trace.failure_reason = last_step.failure_reason
|
||||
return trace
|
||||
|
||||
backfill_service = await self._get_backfill_service()
|
||||
source = self._get_source_for_strategy(chain_result.final_strategy)
|
||||
|
||||
backfill_result = await backfill_service.backfill_single_slot(
|
||||
slot_key=slot_key,
|
||||
candidate_value=chain_result.final_value,
|
||||
source=source,
|
||||
confidence=self._get_confidence_for_source(source),
|
||||
)
|
||||
|
||||
trace.validation_passed = backfill_result.is_success()
|
||||
trace.final_value = backfill_result.normalized_value
|
||||
|
||||
if not backfill_result.is_success():
|
||||
trace.failure_reason = backfill_result.error_message or "Validation failed"
|
||||
|
||||
return trace
|
||||
|
||||
async def _get_missing_required_slots(
|
||||
self,
|
||||
slot_state: SlotState | None,
|
||||
) -> list[str]:
|
||||
"""获取缺失的必填槽位"""
|
||||
if slot_state and slot_state.missing_required_slots:
|
||||
return [
|
||||
s.get("slot_key")
|
||||
for s in slot_state.missing_required_slots
|
||||
if s.get("slot_key")
|
||||
]
|
||||
|
||||
required_defs = await self._slot_def_service.list_slot_definitions(
|
||||
tenant_id=self._tenant_id,
|
||||
required=True,
|
||||
)
|
||||
|
||||
return [d.slot_key for d in required_defs]
|
||||
|
||||
async def _save_extracted_slots(
|
||||
self,
|
||||
extracted_slots: dict[str, Any],
|
||||
) -> None:
|
||||
"""保存提取的槽位到缓存"""
|
||||
if not self._session_id:
|
||||
return
|
||||
|
||||
aggregator = SlotStateAggregator(
|
||||
session=self._session,
|
||||
tenant_id=self._tenant_id,
|
||||
session_id=self._session_id,
|
||||
)
|
||||
|
||||
for slot_key, value in extracted_slots.items():
|
||||
await aggregator.update_slot(
|
||||
slot_key=slot_key,
|
||||
value=value,
|
||||
source=SlotSource.RULE_EXTRACTED.value,
|
||||
confidence=0.9,
|
||||
)
|
||||
|
||||
def _get_source_for_strategy(self, strategy: str | None) -> str:
|
||||
"""根据策略获取来源"""
|
||||
strategy_source_map = {
|
||||
"rule": SlotSource.RULE_EXTRACTED.value,
|
||||
"llm": SlotSource.LLM_INFERRED.value,
|
||||
"user_input": SlotSource.USER_CONFIRMED.value,
|
||||
}
|
||||
return strategy_source_map.get(strategy or "", "unknown")
|
||||
|
||||
def _get_confidence_for_source(self, source: str) -> float:
|
||||
"""根据来源获取置信度"""
|
||||
confidence_map = {
|
||||
SlotSource.USER_CONFIRMED.value: 1.0,
|
||||
SlotSource.RULE_EXTRACTED.value: 0.9,
|
||||
SlotSource.LLM_INFERRED.value: 0.7,
|
||||
}
|
||||
return confidence_map.get(source, 0.5)
|
||||
|
||||
|
||||
async def integrate_slot_extraction(
|
||||
session: AsyncSession,
|
||||
tenant_id: str,
|
||||
session_id: str,
|
||||
user_input: str,
|
||||
slot_state: SlotState | None = None,
|
||||
strategies: list[str] | None = None,
|
||||
) -> ExtractionResult:
|
||||
"""
|
||||
便捷函数:集成槽位提取
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
tenant_id: 租户 ID
|
||||
session_id: 会话 ID
|
||||
user_input: 用户输入
|
||||
slot_state: 当前槽位状态
|
||||
strategies: 提取策略链
|
||||
|
||||
Returns:
|
||||
ExtractionResult: 提取结果
|
||||
"""
|
||||
integration = SlotExtractionIntegration(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
return await integration.extract_and_fill(
|
||||
user_input=user_input,
|
||||
slot_state=slot_state,
|
||||
strategies=strategies,
|
||||
)
|
||||
|
|
@ -0,0 +1,379 @@
|
|||
"""
|
||||
Slot Manager Service.
|
||||
槽位管理服务 - 统一槽位写入入口,集成校验逻辑
|
||||
|
||||
职责:
|
||||
1. 在槽位值写入前执行校验
|
||||
2. 管理槽位值的来源和置信度
|
||||
3. 提供槽位写入的统一接口
|
||||
4. 返回校验失败时的追问提示
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.entities import SlotDefinition
|
||||
from app.models.mid.schemas import SlotSource
|
||||
from app.services.mid.slot_validation_service import (
|
||||
BatchValidationResult,
|
||||
SlotValidationError,
|
||||
SlotValidationService,
|
||||
)
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SlotWriteResult:
|
||||
"""
|
||||
槽位写入结果
|
||||
|
||||
Attributes:
|
||||
success: 是否成功(校验通过并写入)
|
||||
slot_key: 槽位键名
|
||||
value: 最终写入的值
|
||||
error: 校验错误信息(校验失败时)
|
||||
ask_back_prompt: 追问提示语(校验失败时)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
success: bool,
|
||||
slot_key: str,
|
||||
value: Any | None = None,
|
||||
error: SlotValidationError | None = None,
|
||||
ask_back_prompt: str | None = None,
|
||||
):
|
||||
self.success = success
|
||||
self.slot_key = slot_key
|
||||
self.value = value
|
||||
self.error = error
|
||||
self.ask_back_prompt = ask_back_prompt
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
result = {
|
||||
"success": self.success,
|
||||
"slot_key": self.slot_key,
|
||||
"value": self.value,
|
||||
}
|
||||
if self.error:
|
||||
result["error"] = {
|
||||
"error_code": self.error.error_code,
|
||||
"error_message": self.error.error_message,
|
||||
}
|
||||
if self.ask_back_prompt:
|
||||
result["ask_back_prompt"] = self.ask_back_prompt
|
||||
return result
|
||||
|
||||
|
||||
class SlotManager:
|
||||
"""
|
||||
槽位管理器
|
||||
|
||||
统一槽位写入入口,在写入前执行校验。
|
||||
支持从 SlotDefinition 加载校验规则并执行。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession | None = None,
|
||||
validation_service: SlotValidationService | None = None,
|
||||
slot_def_service: SlotDefinitionService | None = None,
|
||||
):
|
||||
"""
|
||||
初始化槽位管理器
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
validation_service: 校验服务实例
|
||||
slot_def_service: 槽位定义服务实例
|
||||
"""
|
||||
self._session = session
|
||||
self._validation_service = validation_service or SlotValidationService()
|
||||
self._slot_def_service = slot_def_service
|
||||
self._slot_def_cache: dict[str, SlotDefinition | None] = {}
|
||||
|
||||
async def write_slot(
|
||||
self,
|
||||
tenant_id: str,
|
||||
slot_key: str,
|
||||
value: Any,
|
||||
source: SlotSource = SlotSource.USER_CONFIRMED,
|
||||
confidence: float = 1.0,
|
||||
skip_validation: bool = False,
|
||||
) -> SlotWriteResult:
|
||||
"""
|
||||
写入单个槽位值(带校验)
|
||||
|
||||
执行流程:
|
||||
1. 加载槽位定义
|
||||
2. 执行校验(如果未跳过)
|
||||
3. 返回校验结果
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
slot_key: 槽位键名
|
||||
value: 槽位值
|
||||
source: 值来源
|
||||
confidence: 置信度
|
||||
skip_validation: 是否跳过校验(用于特殊场景)
|
||||
|
||||
Returns:
|
||||
SlotWriteResult: 写入结果
|
||||
"""
|
||||
# 加载槽位定义
|
||||
slot_def = await self._get_slot_definition(tenant_id, slot_key)
|
||||
|
||||
# 如果没有定义且非跳过校验,允许写入(动态槽位)
|
||||
if slot_def is None and skip_validation:
|
||||
logger.debug(
|
||||
f"[SlotManager] Writing slot without definition: "
|
||||
f"tenant_id={tenant_id}, slot_key={slot_key}"
|
||||
)
|
||||
return SlotWriteResult(
|
||||
success=True,
|
||||
slot_key=slot_key,
|
||||
value=value,
|
||||
)
|
||||
|
||||
if slot_def is None:
|
||||
# 未定义槽位,允许写入但记录日志
|
||||
logger.info(
|
||||
f"[SlotManager] Slot definition not found, allowing write: "
|
||||
f"tenant_id={tenant_id}, slot_key={slot_key}"
|
||||
)
|
||||
return SlotWriteResult(
|
||||
success=True,
|
||||
slot_key=slot_key,
|
||||
value=value,
|
||||
)
|
||||
|
||||
# 执行校验
|
||||
if not skip_validation:
|
||||
validation_result = self._validation_service.validate_slot_value(
|
||||
slot_def, value, tenant_id
|
||||
)
|
||||
|
||||
if not validation_result.ok:
|
||||
logger.info(
|
||||
f"[SlotManager] Slot validation failed: "
|
||||
f"tenant_id={tenant_id}, slot_key={slot_key}, "
|
||||
f"error_code={validation_result.error_code}"
|
||||
)
|
||||
return SlotWriteResult(
|
||||
success=False,
|
||||
slot_key=slot_key,
|
||||
error=SlotValidationError(
|
||||
slot_key=slot_key,
|
||||
error_code=validation_result.error_code or "VALIDATION_FAILED",
|
||||
error_message=validation_result.error_message or "校验失败",
|
||||
ask_back_prompt=validation_result.ask_back_prompt,
|
||||
),
|
||||
ask_back_prompt=validation_result.ask_back_prompt,
|
||||
)
|
||||
|
||||
# 使用归一化后的值
|
||||
value = validation_result.normalized_value
|
||||
|
||||
logger.debug(
|
||||
f"[SlotManager] Slot validation passed: "
|
||||
f"tenant_id={tenant_id}, slot_key={slot_key}"
|
||||
)
|
||||
|
||||
return SlotWriteResult(
|
||||
success=True,
|
||||
slot_key=slot_key,
|
||||
value=value,
|
||||
)
|
||||
|
||||
async def write_slots(
|
||||
self,
|
||||
tenant_id: str,
|
||||
values: dict[str, Any],
|
||||
source: SlotSource = SlotSource.USER_CONFIRMED,
|
||||
confidence: float = 1.0,
|
||||
skip_validation: bool = False,
|
||||
) -> BatchValidationResult:
|
||||
"""
|
||||
批量写入槽位值(带校验)
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
values: 槽位值字典 {slot_key: value}
|
||||
source: 值来源
|
||||
confidence: 置信度
|
||||
skip_validation: 是否跳过校验
|
||||
|
||||
Returns:
|
||||
BatchValidationResult: 批量校验结果
|
||||
"""
|
||||
if skip_validation:
|
||||
return BatchValidationResult(
|
||||
ok=True,
|
||||
validated_values=values,
|
||||
)
|
||||
|
||||
# 加载所有相关槽位定义
|
||||
slot_defs = await self._get_slot_definitions(tenant_id, list(values.keys()))
|
||||
|
||||
# 执行批量校验
|
||||
result = self._validation_service.validate_slots(
|
||||
slot_defs, values, tenant_id
|
||||
)
|
||||
|
||||
if not result.ok:
|
||||
logger.info(
|
||||
f"[SlotManager] Batch slot validation failed: "
|
||||
f"tenant_id={tenant_id}, errors={[e.slot_key for e in result.errors]}"
|
||||
)
|
||||
else:
|
||||
logger.debug(
|
||||
f"[SlotManager] Batch slot validation passed: "
|
||||
f"tenant_id={tenant_id}, slots={list(values.keys())}"
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def validate_before_write(
|
||||
self,
|
||||
tenant_id: str,
|
||||
slot_key: str,
|
||||
value: Any,
|
||||
) -> tuple[bool, SlotValidationError | None]:
|
||||
"""
|
||||
在写入前预校验槽位值
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
slot_key: 槽位键名
|
||||
value: 槽位值
|
||||
|
||||
Returns:
|
||||
Tuple of (是否通过, 错误信息)
|
||||
"""
|
||||
slot_def = await self._get_slot_definition(tenant_id, slot_key)
|
||||
|
||||
if slot_def is None:
|
||||
# 未定义槽位,视为通过
|
||||
return True, None
|
||||
|
||||
result = self._validation_service.validate_slot_value(
|
||||
slot_def, value, tenant_id
|
||||
)
|
||||
|
||||
if result.ok:
|
||||
return True, None
|
||||
|
||||
return False, SlotValidationError(
|
||||
slot_key=slot_key,
|
||||
error_code=result.error_code or "VALIDATION_FAILED",
|
||||
error_message=result.error_message or "校验失败",
|
||||
ask_back_prompt=result.ask_back_prompt,
|
||||
)
|
||||
|
||||
async def get_ask_back_prompt(
|
||||
self,
|
||||
tenant_id: str,
|
||||
slot_key: str,
|
||||
) -> str | None:
|
||||
"""
|
||||
获取槽位的追问提示语
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
slot_key: 槽位键名
|
||||
|
||||
Returns:
|
||||
追问提示语或 None
|
||||
"""
|
||||
slot_def = await self._get_slot_definition(tenant_id, slot_key)
|
||||
if slot_def is None:
|
||||
return None
|
||||
|
||||
if isinstance(slot_def, SlotDefinition):
|
||||
return slot_def.ask_back_prompt
|
||||
return slot_def.get("ask_back_prompt")
|
||||
|
||||
async def _get_slot_definition(
|
||||
self,
|
||||
tenant_id: str,
|
||||
slot_key: str,
|
||||
) -> SlotDefinition | dict[str, Any] | None:
|
||||
"""
|
||||
获取槽位定义(带缓存)
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
slot_key: 槽位键名
|
||||
|
||||
Returns:
|
||||
槽位定义或 None
|
||||
"""
|
||||
cache_key = f"{tenant_id}:{slot_key}"
|
||||
|
||||
if cache_key in self._slot_def_cache:
|
||||
return self._slot_def_cache[cache_key]
|
||||
|
||||
slot_def = None
|
||||
if self._slot_def_service:
|
||||
slot_def = await self._slot_def_service.get_slot_definition_by_key(
|
||||
tenant_id, slot_key
|
||||
)
|
||||
elif self._session:
|
||||
service = SlotDefinitionService(self._session)
|
||||
slot_def = await service.get_slot_definition_by_key(tenant_id, slot_key)
|
||||
|
||||
self._slot_def_cache[cache_key] = slot_def
|
||||
return slot_def
|
||||
|
||||
async def _get_slot_definitions(
|
||||
self,
|
||||
tenant_id: str,
|
||||
slot_keys: list[str],
|
||||
) -> list[SlotDefinition | dict[str, Any]]:
|
||||
"""
|
||||
批量获取槽位定义
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
slot_keys: 槽位键名列表
|
||||
|
||||
Returns:
|
||||
槽位定义列表
|
||||
"""
|
||||
slot_defs = []
|
||||
for key in slot_keys:
|
||||
slot_def = await self._get_slot_definition(tenant_id, key)
|
||||
if slot_def:
|
||||
slot_defs.append(slot_def)
|
||||
return slot_defs
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清除槽位定义缓存"""
|
||||
self._slot_def_cache.clear()
|
||||
|
||||
|
||||
def create_slot_manager(
|
||||
session: AsyncSession | None = None,
|
||||
validation_service: SlotValidationService | None = None,
|
||||
slot_def_service: SlotDefinitionService | None = None,
|
||||
) -> SlotManager:
|
||||
"""
|
||||
创建槽位管理器实例
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
validation_service: 校验服务实例
|
||||
slot_def_service: 槽位定义服务实例
|
||||
|
||||
Returns:
|
||||
SlotManager: 槽位管理器实例
|
||||
"""
|
||||
return SlotManager(
|
||||
session=session,
|
||||
validation_service=validation_service,
|
||||
slot_def_service=slot_def_service,
|
||||
)
|
||||
|
|
@ -0,0 +1,562 @@
|
|||
"""
|
||||
Slot State Aggregator Service.
|
||||
槽位状态聚合服务 - 统一维护本轮槽位状态
|
||||
|
||||
职责:
|
||||
1. 聚合来自 memory_recall 的槽位值
|
||||
2. 叠加本轮输入的槽位值
|
||||
3. 识别缺失的必填槽位
|
||||
4. 支持槽位与元数据字段的关联映射
|
||||
5. 为 KB 检索过滤提供统一的槽位值来源
|
||||
6. [AC-MRS-SLOT-CACHE-01] 多轮状态持久化
|
||||
|
||||
[AC-MRS-SLOT-META-01] 槽位与元数据关联机制
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.mid.schemas import MemorySlot, SlotSource
|
||||
from app.services.cache.slot_state_cache import (
|
||||
CachedSlotState,
|
||||
CachedSlotValue,
|
||||
get_slot_state_cache,
|
||||
)
|
||||
from app.services.mid.slot_manager import SlotManager
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlotState:
|
||||
"""
|
||||
槽位状态聚合结果
|
||||
|
||||
Attributes:
|
||||
filled_slots: 已填充的槽位值字典 {slot_key: value}
|
||||
missing_required_slots: 缺失的必填槽位列表
|
||||
slot_sources: 槽位值来源字典 {slot_key: source}
|
||||
slot_confidence: 槽位置信度字典 {slot_key: confidence}
|
||||
slot_to_field_map: 槽位到元数据字段的映射 {slot_key: field_key}
|
||||
"""
|
||||
filled_slots: dict[str, Any] = field(default_factory=dict)
|
||||
missing_required_slots: list[dict[str, str]] = field(default_factory=list)
|
||||
slot_sources: dict[str, str] = field(default_factory=dict)
|
||||
slot_confidence: dict[str, float] = field(default_factory=dict)
|
||||
slot_to_field_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def get_value_for_filter(self, field_key: str) -> Any:
|
||||
"""
|
||||
获取用于 KB 过滤的字段值
|
||||
|
||||
优先从 slot_to_field_map 反向查找
|
||||
"""
|
||||
# 直接匹配
|
||||
if field_key in self.filled_slots:
|
||||
return self.filled_slots[field_key]
|
||||
|
||||
# 通过 slot_to_field_map 反向查找
|
||||
for slot_key, mapped_field_key in self.slot_to_field_map.items():
|
||||
if mapped_field_key == field_key and slot_key in self.filled_slots:
|
||||
return self.filled_slots[slot_key]
|
||||
|
||||
return None
|
||||
|
||||
def to_debug_info(self) -> dict[str, Any]:
|
||||
"""转换为调试信息字典"""
|
||||
return {
|
||||
"filled_slots": self.filled_slots,
|
||||
"missing_required_slots": self.missing_required_slots,
|
||||
"slot_sources": self.slot_sources,
|
||||
"slot_to_field_map": self.slot_to_field_map,
|
||||
}
|
||||
|
||||
|
||||
class SlotStateAggregator:
|
||||
"""
|
||||
[AC-MRS-SLOT-META-01] 槽位状态聚合器
|
||||
|
||||
统一维护本轮槽位状态,支持:
|
||||
- 从 memory_recall 初始化槽位
|
||||
- 叠加本轮输入的槽位值
|
||||
- 识别缺失的必填槽位
|
||||
- 建立槽位与元数据字段的关联
|
||||
- [AC-MRS-SLOT-CACHE-01] 多轮状态持久化
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
session: AsyncSession,
|
||||
tenant_id: str,
|
||||
slot_manager: SlotManager | None = None,
|
||||
session_id: str | None = None,
|
||||
):
|
||||
self._session = session
|
||||
self._tenant_id = tenant_id
|
||||
self._session_id = session_id
|
||||
self._slot_manager = slot_manager or SlotManager(session=session)
|
||||
self._slot_def_service = SlotDefinitionService(session)
|
||||
self._cache = get_slot_state_cache()
|
||||
|
||||
async def aggregate(
|
||||
self,
|
||||
memory_slots: dict[str, MemorySlot] | None = None,
|
||||
current_input_slots: dict[str, Any] | None = None,
|
||||
context: dict[str, Any] | None = None,
|
||||
use_cache: bool = True,
|
||||
scene_slot_context: Any = None, # [AC-SCENE-SLOT-02] 场景槽位上下文
|
||||
) -> SlotState:
|
||||
"""
|
||||
聚合槽位状态
|
||||
|
||||
执行流程:
|
||||
1. [AC-MRS-SLOT-CACHE-01] 从缓存加载已有状态
|
||||
2. 从 memory_slots 初始化已填充槽位
|
||||
3. 叠加 current_input_slots(优先级更高)
|
||||
4. 从 context 提取槽位值
|
||||
5. 识别缺失的必填槽位
|
||||
6. 建立槽位与元数据字段的关联映射
|
||||
7. [AC-MRS-SLOT-CACHE-01] 回写缓存
|
||||
|
||||
Args:
|
||||
memory_slots: 从 memory_recall 召回的槽位
|
||||
current_input_slots: 本轮输入的槽位值
|
||||
context: 上下文信息,可能包含槽位值
|
||||
use_cache: 是否使用缓存(默认 True)
|
||||
|
||||
Returns:
|
||||
SlotState: 聚合后的槽位状态
|
||||
"""
|
||||
state = SlotState()
|
||||
|
||||
# [AC-MRS-SLOT-CACHE-01] 1. 从缓存加载已有状态
|
||||
cached_state = None
|
||||
if use_cache and self._session_id:
|
||||
cached_state = await self._cache.get(self._tenant_id, self._session_id)
|
||||
if cached_state:
|
||||
for slot_key, cached_value in cached_state.filled_slots.items():
|
||||
state.filled_slots[slot_key] = cached_value.value
|
||||
state.slot_sources[slot_key] = cached_value.source
|
||||
state.slot_confidence[slot_key] = cached_value.confidence
|
||||
state.slot_to_field_map = cached_state.slot_to_field_map.copy()
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-CACHE-01] Loaded from cache: "
|
||||
f"tenant={self._tenant_id}, session={self._session_id}, "
|
||||
f"slots={list(state.filled_slots.keys())}"
|
||||
)
|
||||
|
||||
# 2. 从 memory_slots 初始化
|
||||
if memory_slots:
|
||||
for slot_key, memory_slot in memory_slots.items():
|
||||
state.filled_slots[slot_key] = memory_slot.value
|
||||
state.slot_sources[slot_key] = memory_slot.source.value
|
||||
state.slot_confidence[slot_key] = memory_slot.confidence
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-01] Initialized from memory: "
|
||||
f"tenant={self._tenant_id}, slots={list(memory_slots.keys())}"
|
||||
)
|
||||
|
||||
# 3. 叠加本轮输入(优先级更高)
|
||||
if current_input_slots:
|
||||
for slot_key, value in current_input_slots.items():
|
||||
if value is not None:
|
||||
state.filled_slots[slot_key] = value
|
||||
state.slot_sources[slot_key] = SlotSource.USER_CONFIRMED.value
|
||||
state.slot_confidence[slot_key] = 1.0
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-01] Merged current input: "
|
||||
f"tenant={self._tenant_id}, slots={list(current_input_slots.keys())}"
|
||||
)
|
||||
|
||||
# 4. 从 context 提取槽位值(优先级最低)
|
||||
if context:
|
||||
context_slots = self._extract_slots_from_context(context)
|
||||
for slot_key, value in context_slots.items():
|
||||
if slot_key not in state.filled_slots and value is not None:
|
||||
state.filled_slots[slot_key] = value
|
||||
state.slot_sources[slot_key] = "context"
|
||||
state.slot_confidence[slot_key] = 0.5
|
||||
|
||||
# 5. 加载槽位定义并建立关联
|
||||
await self._build_slot_mappings(state)
|
||||
|
||||
# 6. 识别缺失的必填槽位
|
||||
await self._identify_missing_required_slots(state, scene_slot_context)
|
||||
|
||||
# [AC-MRS-SLOT-CACHE-01] 7. 回写缓存
|
||||
if use_cache and self._session_id:
|
||||
await self._save_to_cache(state)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-01] Slot state aggregated: "
|
||||
f"tenant={self._tenant_id}, filled={len(state.filled_slots)}, "
|
||||
f"missing={len(state.missing_required_slots)}"
|
||||
)
|
||||
|
||||
return state
|
||||
|
||||
async def _save_to_cache(self, state: SlotState) -> None:
|
||||
"""
|
||||
[AC-MRS-SLOT-CACHE-01] 保存槽位状态到缓存
|
||||
"""
|
||||
if not self._session_id:
|
||||
return
|
||||
|
||||
cached_slots = {}
|
||||
for slot_key, value in state.filled_slots.items():
|
||||
source = state.slot_sources.get(slot_key, "unknown")
|
||||
confidence = state.slot_confidence.get(slot_key, 1.0)
|
||||
cached_slots[slot_key] = CachedSlotValue(
|
||||
value=value,
|
||||
source=source,
|
||||
confidence=confidence,
|
||||
)
|
||||
|
||||
cached_state = CachedSlotState(
|
||||
filled_slots=cached_slots,
|
||||
slot_to_field_map=state.slot_to_field_map.copy(),
|
||||
)
|
||||
|
||||
await self._cache.set(self._tenant_id, self._session_id, cached_state)
|
||||
logger.debug(
|
||||
f"[AC-MRS-SLOT-CACHE-01] Saved to cache: "
|
||||
f"tenant={self._tenant_id}, session={self._session_id}"
|
||||
)
|
||||
|
||||
async def update_slot(
|
||||
self,
|
||||
slot_key: str,
|
||||
value: Any,
|
||||
source: str = "user_confirmed",
|
||||
confidence: float = 1.0,
|
||||
) -> SlotState | None:
|
||||
"""
|
||||
[AC-MRS-SLOT-CACHE-01] 更新单个槽位值并保存到缓存
|
||||
|
||||
Args:
|
||||
slot_key: 槽位键名
|
||||
value: 槽位值
|
||||
source: 值来源
|
||||
confidence: 置信度
|
||||
|
||||
Returns:
|
||||
更新后的槽位状态,如果没有 session_id 则返回 None
|
||||
"""
|
||||
if not self._session_id:
|
||||
return None
|
||||
|
||||
cached_value = CachedSlotValue(
|
||||
value=value,
|
||||
source=source,
|
||||
confidence=confidence,
|
||||
)
|
||||
|
||||
cached_state = await self._cache.merge_and_set(
|
||||
tenant_id=self._tenant_id,
|
||||
session_id=self._session_id,
|
||||
new_slots={slot_key: cached_value},
|
||||
)
|
||||
|
||||
state = SlotState()
|
||||
state.filled_slots = cached_state.get_simple_filled_slots()
|
||||
state.slot_sources = cached_state.get_slot_sources()
|
||||
state.slot_confidence = cached_state.get_slot_confidence()
|
||||
state.slot_to_field_map = cached_state.slot_to_field_map.copy()
|
||||
|
||||
await self._identify_missing_required_slots(state)
|
||||
|
||||
return state
|
||||
|
||||
async def clear_slot(self, slot_key: str) -> bool:
|
||||
"""
|
||||
[AC-MRS-SLOT-CACHE-01] 清除单个槽位值
|
||||
|
||||
Args:
|
||||
slot_key: 槽位键名
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not self._session_id:
|
||||
return False
|
||||
|
||||
return await self._cache.clear_slot(
|
||||
tenant_id=self._tenant_id,
|
||||
session_id=self._session_id,
|
||||
slot_key=slot_key,
|
||||
)
|
||||
|
||||
async def clear_all_slots(self) -> bool:
|
||||
"""
|
||||
[AC-MRS-SLOT-CACHE-01] 清除所有槽位状态
|
||||
|
||||
Returns:
|
||||
是否成功
|
||||
"""
|
||||
if not self._session_id:
|
||||
return False
|
||||
|
||||
return await self._cache.delete(
|
||||
tenant_id=self._tenant_id,
|
||||
session_id=self._session_id,
|
||||
)
|
||||
|
||||
def _extract_slots_from_context(self, context: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
从上下文中提取可能的槽位值
|
||||
|
||||
常见的槽位值可能存在于:
|
||||
- scene
|
||||
- product_line
|
||||
- region
|
||||
- grade
|
||||
等字段
|
||||
"""
|
||||
slots = {}
|
||||
slot_candidates = [
|
||||
"scene", "product_line", "region", "grade",
|
||||
"category", "type", "status", "priority"
|
||||
]
|
||||
|
||||
for key in slot_candidates:
|
||||
if key in context and context[key] is not None:
|
||||
slots[key] = context[key]
|
||||
|
||||
return slots
|
||||
|
||||
async def _build_slot_mappings(self, state: SlotState) -> None:
|
||||
"""
|
||||
建立槽位与元数据字段的关联映射
|
||||
|
||||
通过 linked_field_id 关联 SlotDefinition 和 MetadataFieldDefinition
|
||||
"""
|
||||
try:
|
||||
# 获取所有槽位定义
|
||||
slot_defs = await self._slot_def_service.list_slot_definitions(
|
||||
tenant_id=self._tenant_id
|
||||
)
|
||||
|
||||
for slot_def in slot_defs:
|
||||
if slot_def.linked_field_id:
|
||||
# 获取关联的元数据字段
|
||||
from app.services.metadata_field_definition_service import (
|
||||
MetadataFieldDefinitionService
|
||||
)
|
||||
field_service = MetadataFieldDefinitionService(self._session)
|
||||
linked_field = await field_service.get_field_definition(
|
||||
tenant_id=self._tenant_id,
|
||||
field_id=str(slot_def.linked_field_id)
|
||||
)
|
||||
|
||||
if linked_field:
|
||||
state.slot_to_field_map[slot_def.slot_key] = linked_field.field_key
|
||||
|
||||
# 检查类型一致性并告警
|
||||
await self._check_type_consistency(
|
||||
slot_def, linked_field, state
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-MRS-SLOT-META-01] Failed to build slot mappings: {e}"
|
||||
)
|
||||
|
||||
async def _check_type_consistency(
|
||||
self,
|
||||
slot_def: Any,
|
||||
linked_field: Any,
|
||||
state: SlotState,
|
||||
) -> None:
|
||||
"""
|
||||
检查槽位与关联元数据字段的类型一致性
|
||||
|
||||
当不一致时记录告警日志(不强拦截)
|
||||
"""
|
||||
# 检查类型一致性
|
||||
if slot_def.type != linked_field.type:
|
||||
logger.warning(
|
||||
f"[AC-MRS-SLOT-META-01] Type mismatch: "
|
||||
f"slot='{slot_def.slot_key}' type={slot_def.type} vs "
|
||||
f"field='{linked_field.field_key}' type={linked_field.type}"
|
||||
)
|
||||
|
||||
# 检查 required 一致性
|
||||
if slot_def.required != linked_field.required:
|
||||
logger.warning(
|
||||
f"[AC-MRS-SLOT-META-01] Required mismatch: "
|
||||
f"slot='{slot_def.slot_key}' required={slot_def.required} vs "
|
||||
f"field='{linked_field.field_key}' required={linked_field.required}"
|
||||
)
|
||||
|
||||
# 检查 options 一致性(对于 enum/array_enum 类型)
|
||||
if slot_def.type in ["enum", "array_enum"]:
|
||||
slot_options = set(slot_def.options or [])
|
||||
field_options = set(linked_field.options or [])
|
||||
if slot_options != field_options:
|
||||
logger.warning(
|
||||
f"[AC-MRS-SLOT-META-01] Options mismatch: "
|
||||
f"slot='{slot_def.slot_key}' options={slot_options} vs "
|
||||
f"field='{linked_field.field_key}' options={field_options}"
|
||||
)
|
||||
|
||||
async def _identify_missing_required_slots(
|
||||
self,
|
||||
state: SlotState,
|
||||
scene_slot_context: Any = None, # [AC-SCENE-SLOT-02] 场景槽位上下文
|
||||
) -> None:
|
||||
"""
|
||||
识别缺失的必填槽位
|
||||
|
||||
基于 SlotDefinition 的 required 字段和 linked_field 的 required 字段
|
||||
[AC-SCENE-SLOT-02] 当有场景槽位上下文时,优先使用场景定义的必填槽位
|
||||
"""
|
||||
try:
|
||||
# [AC-SCENE-SLOT-02] 如果有场景槽位上下文,使用场景定义的必填槽位
|
||||
if scene_slot_context:
|
||||
scene_required_keys = set(scene_slot_context.get_required_slot_keys())
|
||||
logger.info(
|
||||
f"[AC-SCENE-SLOT-02] Using scene required slots: "
|
||||
f"scene={scene_slot_context.scene_key}, "
|
||||
f"required_keys={scene_required_keys}"
|
||||
)
|
||||
|
||||
# 获取场景中定义的所有槽位
|
||||
all_slot_defs = await self._slot_def_service.list_slot_definitions(
|
||||
tenant_id=self._tenant_id,
|
||||
)
|
||||
slot_def_map = {slot.slot_key: slot for slot in all_slot_defs}
|
||||
|
||||
for slot_key in scene_required_keys:
|
||||
if slot_key not in state.filled_slots:
|
||||
slot_def = slot_def_map.get(slot_key)
|
||||
ask_back_prompt = slot_def.ask_back_prompt if slot_def else None
|
||||
|
||||
if not ask_back_prompt:
|
||||
ask_back_prompt = f"请提供{slot_key}信息"
|
||||
|
||||
missing_info = {
|
||||
"slot_key": slot_key,
|
||||
"label": slot_key,
|
||||
"reason": "scene_required_slot_missing",
|
||||
"ask_back_prompt": ask_back_prompt,
|
||||
"scene": scene_slot_context.scene_key,
|
||||
}
|
||||
|
||||
if slot_def and slot_def.linked_field_id:
|
||||
missing_info["linked_field_id"] = str(slot_def.linked_field_id)
|
||||
|
||||
state.missing_required_slots.append(missing_info)
|
||||
|
||||
return
|
||||
|
||||
# 获取所有 required 的槽位定义
|
||||
required_slot_defs = await self._slot_def_service.list_slot_definitions(
|
||||
tenant_id=self._tenant_id,
|
||||
required=True,
|
||||
)
|
||||
|
||||
for slot_def in required_slot_defs:
|
||||
if slot_def.slot_key not in state.filled_slots:
|
||||
# 获取追问提示
|
||||
ask_back_prompt = slot_def.ask_back_prompt
|
||||
|
||||
# 如果没有配置追问提示,使用通用模板
|
||||
if not ask_back_prompt:
|
||||
ask_back_prompt = f"请提供{slot_def.slot_key}信息"
|
||||
|
||||
missing_info = {
|
||||
"slot_key": slot_def.slot_key,
|
||||
"label": slot_def.slot_key,
|
||||
"reason": "required_slot_missing",
|
||||
"ask_back_prompt": ask_back_prompt,
|
||||
}
|
||||
|
||||
# 如果有关联字段,使用字段的 label
|
||||
if slot_def.linked_field_id:
|
||||
from app.services.metadata_field_definition_service import (
|
||||
MetadataFieldDefinitionService
|
||||
)
|
||||
field_service = MetadataFieldDefinitionService(self._session)
|
||||
linked_field = await field_service.get_field_definition(
|
||||
tenant_id=self._tenant_id,
|
||||
field_id=str(slot_def.linked_field_id)
|
||||
)
|
||||
if linked_field:
|
||||
missing_info["label"] = linked_field.label
|
||||
missing_info["field_key"] = linked_field.field_key
|
||||
|
||||
state.missing_required_slots.append(missing_info)
|
||||
|
||||
logger.info(
|
||||
f"[AC-MRS-SLOT-META-01] Missing required slot: "
|
||||
f"slot_key={slot_def.slot_key}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-MRS-SLOT-META-01] Failed to identify missing slots: {e}"
|
||||
)
|
||||
|
||||
async def generate_ask_back_response(
|
||||
self,
|
||||
state: SlotState,
|
||||
missing_slot_key: str | None = None,
|
||||
) -> str | None:
|
||||
"""
|
||||
生成追问响应文案
|
||||
|
||||
Args:
|
||||
state: 当前槽位状态
|
||||
missing_slot_key: 指定要追问的槽位键名,为 None 时追问第一个缺失槽位
|
||||
|
||||
Returns:
|
||||
追问文案或 None(如果没有缺失槽位)
|
||||
"""
|
||||
if not state.missing_required_slots:
|
||||
return None
|
||||
|
||||
missing_info = None
|
||||
|
||||
# 如果指定了槽位键名,查找对应的追问提示
|
||||
if missing_slot_key:
|
||||
for missing in state.missing_required_slots:
|
||||
if missing.get("slot_key") == missing_slot_key:
|
||||
missing_info = missing
|
||||
break
|
||||
else:
|
||||
# 使用第一个缺失槽位
|
||||
missing_info = state.missing_required_slots[0]
|
||||
|
||||
if missing_info is None:
|
||||
return None
|
||||
|
||||
# 优先使用配置的 ask_back_prompt
|
||||
ask_back_prompt = missing_info.get("ask_back_prompt")
|
||||
if ask_back_prompt:
|
||||
return ask_back_prompt
|
||||
|
||||
# 使用通用模板
|
||||
label = missing_info.get("label", missing_info.get("slot_key", "相关信息"))
|
||||
return f"为了更好地为您提供帮助,请告诉我您的{label}。"
|
||||
|
||||
|
||||
def create_slot_state_aggregator(
|
||||
session: AsyncSession,
|
||||
tenant_id: str,
|
||||
) -> SlotStateAggregator:
|
||||
"""
|
||||
创建槽位状态聚合器实例
|
||||
|
||||
Args:
|
||||
session: 数据库会话
|
||||
tenant_id: 租户 ID
|
||||
|
||||
Returns:
|
||||
SlotStateAggregator: 槽位状态聚合器实例
|
||||
"""
|
||||
return SlotStateAggregator(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
|
@ -0,0 +1,355 @@
|
|||
"""
|
||||
Slot Strategy Executor.
|
||||
[AC-MRS-07-UPGRADE] 槽位提取策略链执行器
|
||||
|
||||
按顺序执行提取策略链,直到成功提取并通过校验。
|
||||
支持失败分类和详细日志追踪。
|
||||
"""
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Any, Callable
|
||||
|
||||
from app.models.entities import ExtractFailureType
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ExtractStrategyType(str, Enum):
|
||||
"""提取策略类型"""
|
||||
RULE = "rule"
|
||||
LLM = "llm"
|
||||
USER_INPUT = "user_input"
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyStepResult:
|
||||
"""单步策略执行结果"""
|
||||
strategy: str
|
||||
success: bool
|
||||
value: Any = None
|
||||
failure_type: ExtractFailureType | None = None
|
||||
failure_reason: str = ""
|
||||
execution_time_ms: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class StrategyChainResult:
|
||||
"""策略链执行结果"""
|
||||
slot_key: str
|
||||
success: bool
|
||||
final_value: Any = None
|
||||
final_strategy: str | None = None
|
||||
steps: list[StrategyStepResult] = field(default_factory=list)
|
||||
total_execution_time_ms: float = 0.0
|
||||
ask_back_prompt: str | None = None
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
"""转换为字典"""
|
||||
return {
|
||||
"slot_key": self.slot_key,
|
||||
"success": self.success,
|
||||
"final_value": self.final_value,
|
||||
"final_strategy": self.final_strategy,
|
||||
"steps": [
|
||||
{
|
||||
"strategy": step.strategy,
|
||||
"success": step.success,
|
||||
"value": step.value if step.success else None,
|
||||
"failure_type": step.failure_type.value if step.failure_type else None,
|
||||
"failure_reason": step.failure_reason,
|
||||
"execution_time_ms": step.execution_time_ms,
|
||||
}
|
||||
for step in self.steps
|
||||
],
|
||||
"total_execution_time_ms": self.total_execution_time_ms,
|
||||
"ask_back_prompt": self.ask_back_prompt,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractContext:
|
||||
"""提取上下文"""
|
||||
tenant_id: str
|
||||
slot_key: str
|
||||
user_input: str
|
||||
slot_type: str
|
||||
validation_rule: str | None = None
|
||||
history: list[dict[str, str]] | None = None
|
||||
session_id: str | None = None
|
||||
|
||||
|
||||
class SlotStrategyExecutor:
|
||||
"""
|
||||
[AC-MRS-07-UPGRADE] 槽位提取策略链执行器
|
||||
|
||||
职责:
|
||||
1. 按策略链顺序执行提取
|
||||
2. 某一步成功且校验通过 -> 停止并返回结果
|
||||
3. 当前策略失败 -> 记录失败原因,继续下一策略
|
||||
4. 全部失败 -> 返回结构化失败结果
|
||||
5. 提供可追踪日志(slot_key、strategy、reason)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rule_extractor: Callable[[ExtractContext], Any] | None = None,
|
||||
llm_extractor: Callable[[ExtractContext], Any] | None = None,
|
||||
user_input_extractor: Callable[[ExtractContext], Any] | None = None,
|
||||
):
|
||||
"""
|
||||
初始化执行器
|
||||
|
||||
Args:
|
||||
rule_extractor: 规则提取器函数
|
||||
llm_extractor: LLM提取器函数
|
||||
user_input_extractor: 用户输入提取器函数
|
||||
"""
|
||||
self._extractors: dict[str, Callable[[ExtractContext], Any]] = {
|
||||
ExtractStrategyType.RULE.value: rule_extractor or self._default_rule_extract,
|
||||
ExtractStrategyType.LLM.value: llm_extractor or self._default_llm_extract,
|
||||
ExtractStrategyType.USER_INPUT.value: user_input_extractor or self._default_user_input_extract,
|
||||
}
|
||||
|
||||
async def execute_chain(
|
||||
self,
|
||||
strategies: list[str],
|
||||
context: ExtractContext,
|
||||
ask_back_prompt: str | None = None,
|
||||
) -> StrategyChainResult:
|
||||
"""
|
||||
执行提取策略链
|
||||
|
||||
Args:
|
||||
strategies: 策略链,如 ["user_input", "rule", "llm"]
|
||||
context: 提取上下文
|
||||
ask_back_prompt: 追问提示语(全部失败时使用)
|
||||
|
||||
Returns:
|
||||
StrategyChainResult: 执行结果
|
||||
"""
|
||||
import time
|
||||
|
||||
start_time = time.time()
|
||||
steps: list[StrategyStepResult] = []
|
||||
|
||||
logger.info(
|
||||
f"[SlotStrategyExecutor] Starting strategy chain for slot '{context.slot_key}': "
|
||||
f"strategies={strategies}, tenant={context.tenant_id}"
|
||||
)
|
||||
|
||||
for idx, strategy in enumerate(strategies):
|
||||
step_start = time.time()
|
||||
|
||||
logger.info(
|
||||
f"[SlotStrategyExecutor] Executing step {idx + 1}/{len(strategies)}: "
|
||||
f"slot_key={context.slot_key}, strategy={strategy}"
|
||||
)
|
||||
|
||||
step_result = await self._execute_single_strategy(strategy, context)
|
||||
step_result.execution_time_ms = (time.time() - step_start) * 1000
|
||||
steps.append(step_result)
|
||||
|
||||
if step_result.success:
|
||||
total_time = (time.time() - start_time) * 1000
|
||||
logger.info(
|
||||
f"[SlotStrategyExecutor] Strategy chain succeeded at step {idx + 1}: "
|
||||
f"slot_key={context.slot_key}, strategy={strategy}, "
|
||||
f"total_time_ms={total_time:.2f}"
|
||||
)
|
||||
return StrategyChainResult(
|
||||
slot_key=context.slot_key,
|
||||
success=True,
|
||||
final_value=step_result.value,
|
||||
final_strategy=strategy,
|
||||
steps=steps,
|
||||
total_execution_time_ms=total_time,
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"[SlotStrategyExecutor] Step {idx + 1} failed: "
|
||||
f"slot_key={context.slot_key}, strategy={strategy}, "
|
||||
f"failure_type={step_result.failure_type}, "
|
||||
f"reason={step_result.failure_reason}"
|
||||
)
|
||||
|
||||
# 全部策略失败
|
||||
total_time = (time.time() - start_time) * 1000
|
||||
logger.warning(
|
||||
f"[SlotStrategyExecutor] All strategies failed for slot '{context.slot_key}': "
|
||||
f"attempted={len(strategies)}, total_time_ms={total_time:.2f}"
|
||||
)
|
||||
|
||||
return StrategyChainResult(
|
||||
slot_key=context.slot_key,
|
||||
success=False,
|
||||
final_value=None,
|
||||
final_strategy=None,
|
||||
steps=steps,
|
||||
total_execution_time_ms=total_time,
|
||||
ask_back_prompt=ask_back_prompt,
|
||||
)
|
||||
|
||||
async def _execute_single_strategy(
|
||||
self,
|
||||
strategy: str,
|
||||
context: ExtractContext,
|
||||
) -> StrategyStepResult:
|
||||
"""
|
||||
执行单个提取策略
|
||||
|
||||
Args:
|
||||
strategy: 策略类型
|
||||
context: 提取上下文
|
||||
|
||||
Returns:
|
||||
StrategyStepResult: 单步执行结果
|
||||
"""
|
||||
extractor = self._extractors.get(strategy)
|
||||
|
||||
if not extractor:
|
||||
return StrategyStepResult(
|
||||
strategy=strategy,
|
||||
success=False,
|
||||
failure_type=ExtractFailureType.EXTRACT_RUNTIME_ERROR,
|
||||
failure_reason=f"Unknown strategy: {strategy}",
|
||||
)
|
||||
|
||||
try:
|
||||
# 执行提取
|
||||
value = await extractor(context)
|
||||
|
||||
# 检查结果是否为空
|
||||
if value is None or value == "":
|
||||
return StrategyStepResult(
|
||||
strategy=strategy,
|
||||
success=False,
|
||||
failure_type=ExtractFailureType.EXTRACT_EMPTY,
|
||||
failure_reason="Extracted value is empty",
|
||||
)
|
||||
|
||||
# 执行校验(如果有校验规则)
|
||||
if context.validation_rule:
|
||||
is_valid, error_msg = self._validate_value(value, context)
|
||||
if not is_valid:
|
||||
return StrategyStepResult(
|
||||
strategy=strategy,
|
||||
success=False,
|
||||
failure_type=ExtractFailureType.EXTRACT_VALIDATION_FAIL,
|
||||
failure_reason=f"Validation failed: {error_msg}",
|
||||
)
|
||||
|
||||
return StrategyStepResult(
|
||||
strategy=strategy,
|
||||
success=True,
|
||||
value=value,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"[SlotStrategyExecutor] Runtime error in strategy '{strategy}' "
|
||||
f"for slot '{context.slot_key}': {e}"
|
||||
)
|
||||
return StrategyStepResult(
|
||||
strategy=strategy,
|
||||
success=False,
|
||||
failure_type=ExtractFailureType.EXTRACT_RUNTIME_ERROR,
|
||||
failure_reason=f"Runtime error: {str(e)}",
|
||||
)
|
||||
|
||||
def _validate_value(self, value: Any, context: ExtractContext) -> tuple[bool, str]:
|
||||
"""
|
||||
校验提取的值
|
||||
|
||||
Args:
|
||||
value: 提取的值
|
||||
context: 提取上下文
|
||||
|
||||
Returns:
|
||||
Tuple of (是否通过, 错误信息)
|
||||
"""
|
||||
import re
|
||||
|
||||
validation_rule = context.validation_rule
|
||||
if not validation_rule:
|
||||
return True, ""
|
||||
|
||||
try:
|
||||
# 尝试作为正则表达式校验
|
||||
if validation_rule.startswith("^") or validation_rule.endswith("$"):
|
||||
if re.match(validation_rule, str(value)):
|
||||
return True, ""
|
||||
return False, f"Value '{value}' does not match pattern '{validation_rule}'"
|
||||
|
||||
# 其他校验规则可以在这里扩展
|
||||
return True, ""
|
||||
|
||||
except re.error as e:
|
||||
logger.warning(f"[SlotStrategyExecutor] Invalid validation rule pattern: {e}")
|
||||
return True, "" # 正则错误时放行
|
||||
|
||||
async def _default_rule_extract(self, context: ExtractContext) -> Any:
|
||||
"""默认规则提取实现(占位)"""
|
||||
# 实际项目中应该调用 VariableExtractor 或其他规则引擎
|
||||
logger.debug(f"[SlotStrategyExecutor] Default rule extract for '{context.slot_key}'")
|
||||
return None
|
||||
|
||||
async def _default_llm_extract(self, context: ExtractContext) -> Any:
|
||||
"""默认LLM提取实现(占位)"""
|
||||
logger.debug(f"[SlotStrategyExecutor] Default LLM extract for '{context.slot_key}'")
|
||||
return None
|
||||
|
||||
async def _default_user_input_extract(self, context: ExtractContext) -> Any:
|
||||
"""默认用户输入提取实现"""
|
||||
# user_input 策略通常表示需要向用户询问,这里返回空表示需要追问
|
||||
logger.debug(f"[SlotStrategyExecutor] User input required for '{context.slot_key}'")
|
||||
return None
|
||||
|
||||
|
||||
# 便捷函数
|
||||
async def execute_extract_strategies(
|
||||
strategies: list[str],
|
||||
tenant_id: str,
|
||||
slot_key: str,
|
||||
user_input: str,
|
||||
slot_type: str = "string",
|
||||
validation_rule: str | None = None,
|
||||
ask_back_prompt: str | None = None,
|
||||
history: list[dict[str, str]] | None = None,
|
||||
rule_extractor: Callable[[ExtractContext], Any] | None = None,
|
||||
llm_extractor: Callable[[ExtractContext], Any] | None = None,
|
||||
) -> StrategyChainResult:
|
||||
"""
|
||||
便捷函数:执行提取策略链
|
||||
|
||||
Args:
|
||||
strategies: 策略链
|
||||
tenant_id: 租户ID
|
||||
slot_key: 槽位键名
|
||||
user_input: 用户输入
|
||||
slot_type: 槽位类型
|
||||
validation_rule: 校验规则
|
||||
ask_back_prompt: 追问提示语
|
||||
history: 对话历史
|
||||
rule_extractor: 规则提取器
|
||||
llm_extractor: LLM提取器
|
||||
|
||||
Returns:
|
||||
StrategyChainResult: 执行结果
|
||||
"""
|
||||
executor = SlotStrategyExecutor(
|
||||
rule_extractor=rule_extractor,
|
||||
llm_extractor=llm_extractor,
|
||||
)
|
||||
|
||||
context = ExtractContext(
|
||||
tenant_id=tenant_id,
|
||||
slot_key=slot_key,
|
||||
user_input=user_input,
|
||||
slot_type=slot_type,
|
||||
validation_rule=validation_rule,
|
||||
history=history,
|
||||
)
|
||||
|
||||
return await executor.execute_chain(strategies, context, ask_back_prompt)
|
||||
|
|
@ -0,0 +1,572 @@
|
|||
"""
|
||||
Slot Validation Service.
|
||||
槽位校验规则 runtime 生效服务
|
||||
|
||||
提供槽位值的运行时校验能力,支持:
|
||||
1. 正则表达式校验
|
||||
2. JSON Schema 校验
|
||||
3. 类型校验(string/number/boolean/enum/array_enum)
|
||||
4. 必填校验
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import jsonschema
|
||||
from jsonschema.exceptions import ValidationError as JsonSchemaValidationError
|
||||
|
||||
from app.models.entities import SlotDefinition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# 错误码定义
|
||||
class SlotValidationErrorCode:
|
||||
"""槽位校验错误码"""
|
||||
|
||||
SLOT_REQUIRED_MISSING = "SLOT_REQUIRED_MISSING"
|
||||
SLOT_TYPE_INVALID = "SLOT_TYPE_INVALID"
|
||||
SLOT_REGEX_MISMATCH = "SLOT_REGEX_MISMATCH"
|
||||
SLOT_JSON_SCHEMA_MISMATCH = "SLOT_JSON_SCHEMA_MISMATCH"
|
||||
SLOT_VALIDATION_RULE_INVALID = "SLOT_VALIDATION_RULE_INVALID"
|
||||
SLOT_ENUM_INVALID = "SLOT_ENUM_INVALID"
|
||||
SLOT_ARRAY_ENUM_INVALID = "SLOT_ARRAY_ENUM_INVALID"
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""
|
||||
单个槽位校验结果
|
||||
|
||||
Attributes:
|
||||
ok: 校验是否通过
|
||||
normalized_value: 归一化后的值(如类型转换后)
|
||||
error_code: 错误码(校验失败时)
|
||||
error_message: 错误描述(校验失败时)
|
||||
ask_back_prompt: 追问提示语(校验失败且配置了 ask_back_prompt 时)
|
||||
"""
|
||||
|
||||
ok: bool
|
||||
normalized_value: Any | None = None
|
||||
error_code: str | None = None
|
||||
error_message: str | None = None
|
||||
ask_back_prompt: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SlotValidationError:
|
||||
"""
|
||||
槽位校验错误详情
|
||||
|
||||
Attributes:
|
||||
slot_key: 槽位键名
|
||||
error_code: 错误码
|
||||
error_message: 错误描述
|
||||
ask_back_prompt: 追问提示语
|
||||
"""
|
||||
|
||||
slot_key: str
|
||||
error_code: str
|
||||
error_message: str
|
||||
ask_back_prompt: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BatchValidationResult:
|
||||
"""
|
||||
批量槽位校验结果
|
||||
|
||||
Attributes:
|
||||
ok: 是否全部校验通过
|
||||
errors: 校验错误列表
|
||||
validated_values: 校验通过的值字典
|
||||
"""
|
||||
|
||||
ok: bool
|
||||
errors: list[SlotValidationError] = field(default_factory=list)
|
||||
validated_values: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class SlotValidationService:
|
||||
"""
|
||||
槽位校验服务
|
||||
|
||||
负责在槽位值写回前执行校验,支持:
|
||||
- 正则表达式校验
|
||||
- JSON Schema 校验
|
||||
- 类型校验
|
||||
- 必填校验
|
||||
"""
|
||||
|
||||
# 支持的槽位类型
|
||||
VALID_TYPES = ["string", "number", "boolean", "enum", "array_enum"]
|
||||
|
||||
def __init__(self):
|
||||
"""初始化槽位校验服务"""
|
||||
self._schema_cache: dict[str, dict] = {}
|
||||
|
||||
def validate_slot_value(
|
||||
self,
|
||||
slot_def: dict[str, Any] | SlotDefinition,
|
||||
value: Any,
|
||||
tenant_id: str | None = None,
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
校验单个槽位值
|
||||
|
||||
校验顺序:
|
||||
1. 必填校验(如果 required=true 且值为空)
|
||||
2. 类型校验
|
||||
3. validation_rule 校验(正则或 JSON Schema)
|
||||
|
||||
Args:
|
||||
slot_def: 槽位定义(dict 或 SlotDefinition 对象)
|
||||
value: 待校验的值
|
||||
tenant_id: 租户 ID(用于日志记录)
|
||||
|
||||
Returns:
|
||||
ValidationResult: 校验结果
|
||||
"""
|
||||
# 统一转换为 dict 处理
|
||||
if isinstance(slot_def, SlotDefinition):
|
||||
slot_dict = {
|
||||
"slot_key": slot_def.slot_key,
|
||||
"type": slot_def.type,
|
||||
"required": slot_def.required,
|
||||
"validation_rule": slot_def.validation_rule,
|
||||
"ask_back_prompt": slot_def.ask_back_prompt,
|
||||
}
|
||||
else:
|
||||
slot_dict = slot_def
|
||||
|
||||
slot_key = slot_dict.get("slot_key", "unknown")
|
||||
slot_type = slot_dict.get("type", "string")
|
||||
required = slot_dict.get("required", False)
|
||||
validation_rule = slot_dict.get("validation_rule")
|
||||
ask_back_prompt = slot_dict.get("ask_back_prompt")
|
||||
|
||||
# 1. 必填校验
|
||||
if required and self._is_empty_value(value):
|
||||
logger.info(
|
||||
f"[SlotValidation] Required slot missing: "
|
||||
f"tenant_id={tenant_id}, slot_key={slot_key}"
|
||||
)
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_REQUIRED_MISSING,
|
||||
error_message=f"槽位 '{slot_key}' 为必填项",
|
||||
ask_back_prompt=ask_back_prompt,
|
||||
)
|
||||
|
||||
# 如果值为空且非必填,跳过后续校验
|
||||
if self._is_empty_value(value):
|
||||
return ValidationResult(ok=True, normalized_value=value)
|
||||
|
||||
# 2. 类型校验
|
||||
type_result = self._validate_type(slot_dict, value, tenant_id)
|
||||
if not type_result.ok:
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=type_result.error_code,
|
||||
error_message=type_result.error_message,
|
||||
ask_back_prompt=ask_back_prompt,
|
||||
)
|
||||
|
||||
normalized_value = type_result.normalized_value
|
||||
|
||||
# 3. validation_rule 校验
|
||||
if validation_rule and str(validation_rule).strip():
|
||||
rule_result = self._validate_rule(
|
||||
slot_dict, normalized_value, tenant_id
|
||||
)
|
||||
if not rule_result.ok:
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=rule_result.error_code,
|
||||
error_message=rule_result.error_message,
|
||||
ask_back_prompt=ask_back_prompt,
|
||||
)
|
||||
normalized_value = rule_result.normalized_value or normalized_value
|
||||
|
||||
logger.debug(
|
||||
f"[SlotValidation] Slot validation passed: "
|
||||
f"tenant_id={tenant_id}, slot_key={slot_key}, type={slot_type}"
|
||||
)
|
||||
|
||||
return ValidationResult(ok=True, normalized_value=normalized_value)
|
||||
|
||||
def validate_slots(
|
||||
self,
|
||||
slot_defs: list[dict[str, Any] | SlotDefinition],
|
||||
values: dict[str, Any],
|
||||
tenant_id: str | None = None,
|
||||
) -> BatchValidationResult:
|
||||
"""
|
||||
批量校验多个槽位值
|
||||
|
||||
Args:
|
||||
slot_defs: 槽位定义列表
|
||||
values: 槽位值字典 {slot_key: value}
|
||||
tenant_id: 租户 ID(用于日志记录)
|
||||
|
||||
Returns:
|
||||
BatchValidationResult: 批量校验结果
|
||||
"""
|
||||
errors: list[SlotValidationError] = []
|
||||
validated_values: dict[str, Any] = {}
|
||||
|
||||
# 构建 slot_def 映射
|
||||
slot_def_map: dict[str, dict[str, Any] | SlotDefinition] = {}
|
||||
for slot_def in slot_defs:
|
||||
if isinstance(slot_def, SlotDefinition):
|
||||
slot_def_map[slot_def.slot_key] = slot_def
|
||||
else:
|
||||
slot_def_map[slot_def.get("slot_key", "")] = slot_def
|
||||
|
||||
# 校验每个提供的值
|
||||
for slot_key, value in values.items():
|
||||
slot_def = slot_def_map.get(slot_key)
|
||||
if not slot_def:
|
||||
# 未定义槽位,跳过校验(允许动态槽位)
|
||||
validated_values[slot_key] = value
|
||||
continue
|
||||
|
||||
result = self.validate_slot_value(slot_def, value, tenant_id)
|
||||
|
||||
if result.ok:
|
||||
validated_values[slot_key] = result.normalized_value
|
||||
else:
|
||||
errors.append(
|
||||
SlotValidationError(
|
||||
slot_key=slot_key,
|
||||
error_code=result.error_code or "UNKNOWN_ERROR",
|
||||
error_message=result.error_message or "校验失败",
|
||||
ask_back_prompt=result.ask_back_prompt,
|
||||
)
|
||||
)
|
||||
|
||||
# 检查必填槽位是否缺失
|
||||
for slot_def in slot_defs:
|
||||
if isinstance(slot_def, SlotDefinition):
|
||||
slot_key = slot_def.slot_key
|
||||
required = slot_def.required
|
||||
ask_back_prompt = slot_def.ask_back_prompt
|
||||
else:
|
||||
slot_key = slot_def.get("slot_key", "")
|
||||
required = slot_def.get("required", False)
|
||||
ask_back_prompt = slot_def.get("ask_back_prompt")
|
||||
|
||||
if required and slot_key not in values:
|
||||
# 检查是否已经有该错误
|
||||
if not any(e.slot_key == slot_key for e in errors):
|
||||
errors.append(
|
||||
SlotValidationError(
|
||||
slot_key=slot_key,
|
||||
error_code=SlotValidationErrorCode.SLOT_REQUIRED_MISSING,
|
||||
error_message=f"槽位 '{slot_key}' 为必填项",
|
||||
ask_back_prompt=ask_back_prompt,
|
||||
)
|
||||
)
|
||||
|
||||
return BatchValidationResult(
|
||||
ok=len(errors) == 0,
|
||||
errors=errors,
|
||||
validated_values=validated_values,
|
||||
)
|
||||
|
||||
def _is_empty_value(self, value: Any) -> bool:
|
||||
"""判断值是否为空"""
|
||||
if value is None:
|
||||
return True
|
||||
if isinstance(value, str) and not value.strip():
|
||||
return True
|
||||
if isinstance(value, list) and len(value) == 0:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _validate_type(
|
||||
self,
|
||||
slot_def: dict[str, Any],
|
||||
value: Any,
|
||||
tenant_id: str | None,
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
类型校验
|
||||
|
||||
Args:
|
||||
slot_def: 槽位定义字典
|
||||
value: 待校验的值
|
||||
tenant_id: 租户 ID
|
||||
|
||||
Returns:
|
||||
ValidationResult: 校验结果
|
||||
"""
|
||||
slot_key = slot_def.get("slot_key", "unknown")
|
||||
slot_type = slot_def.get("type", "string")
|
||||
|
||||
if slot_type not in self.VALID_TYPES:
|
||||
logger.warning(
|
||||
f"[SlotValidation] Unknown slot type: "
|
||||
f"tenant_id={tenant_id}, slot_key={slot_key}, type={slot_type}"
|
||||
)
|
||||
# 未知类型不阻止,只记录警告
|
||||
return ValidationResult(ok=True, normalized_value=value)
|
||||
|
||||
try:
|
||||
if slot_type == "string":
|
||||
if not isinstance(value, str):
|
||||
# 尝试转换为字符串
|
||||
normalized = str(value)
|
||||
return ValidationResult(ok=True, normalized_value=normalized)
|
||||
return ValidationResult(ok=True, normalized_value=value)
|
||||
|
||||
elif slot_type == "number":
|
||||
if isinstance(value, bool):
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
|
||||
error_message=f"槽位 '{slot_key}' 类型应为数字,但得到布尔值",
|
||||
)
|
||||
if isinstance(value, int | float):
|
||||
return ValidationResult(ok=True, normalized_value=value)
|
||||
# 尝试转换为数字
|
||||
if isinstance(value, str):
|
||||
try:
|
||||
if "." in value:
|
||||
normalized = float(value)
|
||||
else:
|
||||
normalized = int(value)
|
||||
return ValidationResult(ok=True, normalized_value=normalized)
|
||||
except ValueError:
|
||||
pass
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
|
||||
error_message=f"槽位 '{slot_key}' 类型应为数字",
|
||||
)
|
||||
|
||||
elif slot_type == "boolean":
|
||||
if isinstance(value, bool):
|
||||
return ValidationResult(ok=True, normalized_value=value)
|
||||
if isinstance(value, str):
|
||||
lower_val = value.lower()
|
||||
if lower_val in ("true", "1", "yes", "是", "真"):
|
||||
return ValidationResult(ok=True, normalized_value=True)
|
||||
if lower_val in ("false", "0", "no", "否", "假"):
|
||||
return ValidationResult(ok=True, normalized_value=False)
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
|
||||
error_message=f"槽位 '{slot_key}' 类型应为布尔值",
|
||||
)
|
||||
|
||||
elif slot_type == "enum":
|
||||
if not isinstance(value, str):
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
|
||||
error_message=f"槽位 '{slot_key}' 类型应为字符串(枚举)",
|
||||
)
|
||||
# 如果有选项定义,校验值是否在选项中
|
||||
options = slot_def.get("options") or []
|
||||
if options and value not in options:
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_ENUM_INVALID,
|
||||
error_message=f"槽位 '{slot_key}' 的值 '{value}' 不在允许选项 {options} 中",
|
||||
)
|
||||
return ValidationResult(ok=True, normalized_value=value)
|
||||
|
||||
elif slot_type == "array_enum":
|
||||
if not isinstance(value, list):
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
|
||||
error_message=f"槽位 '{slot_key}' 类型应为数组",
|
||||
)
|
||||
# 校验数组元素
|
||||
options = slot_def.get("options") or []
|
||||
for item in value:
|
||||
if not isinstance(item, str):
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_ARRAY_ENUM_INVALID,
|
||||
error_message=f"槽位 '{slot_key}' 的数组元素应为字符串",
|
||||
)
|
||||
if options and item not in options:
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_ARRAY_ENUM_INVALID,
|
||||
error_message=f"槽位 '{slot_key}' 的值 '{item}' 不在允许选项 {options} 中",
|
||||
)
|
||||
return ValidationResult(ok=True, normalized_value=value)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[SlotValidation] Type validation error: "
|
||||
f"tenant_id={tenant_id}, slot_key={slot_key}, error={e}"
|
||||
)
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_TYPE_INVALID,
|
||||
error_message=f"槽位 '{slot_key}' 类型校验异常: {str(e)}",
|
||||
)
|
||||
|
||||
return ValidationResult(ok=True, normalized_value=value)
|
||||
|
||||
def _validate_rule(
|
||||
self,
|
||||
slot_def: dict[str, Any],
|
||||
value: Any,
|
||||
tenant_id: str | None,
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
校验规则校验(正则或 JSON Schema)
|
||||
|
||||
判定逻辑:
|
||||
1. 如果规则是 JSON 对象字符串(以 { 或 [ 开头),按 JSON Schema 处理
|
||||
2. 否则按正则表达式处理
|
||||
|
||||
Args:
|
||||
slot_def: 槽位定义字典
|
||||
value: 待校验的值
|
||||
tenant_id: 租户 ID
|
||||
|
||||
Returns:
|
||||
ValidationResult: 校验结果
|
||||
"""
|
||||
slot_key = slot_def.get("slot_key", "unknown")
|
||||
validation_rule = str(slot_def.get("validation_rule", "")).strip()
|
||||
|
||||
if not validation_rule:
|
||||
return ValidationResult(ok=True, normalized_value=value)
|
||||
|
||||
# 判定是 JSON Schema 还是正则表达式
|
||||
# JSON Schema 通常以 { 或 [ 开头
|
||||
is_json_schema = validation_rule.strip().startswith(("{", "["))
|
||||
|
||||
if is_json_schema:
|
||||
return self._validate_json_schema(
|
||||
slot_key, validation_rule, value, tenant_id
|
||||
)
|
||||
else:
|
||||
return self._validate_regex(
|
||||
slot_key, validation_rule, value, tenant_id
|
||||
)
|
||||
|
||||
def _validate_regex(
|
||||
self,
|
||||
slot_key: str,
|
||||
pattern: str,
|
||||
value: Any,
|
||||
tenant_id: str | None,
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
正则表达式校验
|
||||
|
||||
Args:
|
||||
slot_key: 槽位键名
|
||||
pattern: 正则表达式
|
||||
value: 待校验的值
|
||||
tenant_id: 租户 ID
|
||||
|
||||
Returns:
|
||||
ValidationResult: 校验结果
|
||||
"""
|
||||
try:
|
||||
# 将值转为字符串进行匹配
|
||||
str_value = str(value) if value is not None else ""
|
||||
|
||||
if not re.search(pattern, str_value):
|
||||
logger.info(
|
||||
f"[SlotValidation] Regex mismatch: "
|
||||
f"tenant_id={tenant_id}, slot_key={slot_key}"
|
||||
)
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_REGEX_MISMATCH,
|
||||
error_message=f"槽位 '{slot_key}' 的值不符合格式要求",
|
||||
)
|
||||
|
||||
return ValidationResult(ok=True, normalized_value=value)
|
||||
|
||||
except re.error as e:
|
||||
logger.warning(
|
||||
f"[SlotValidation] Invalid regex pattern: "
|
||||
f"tenant_id={tenant_id}, slot_key={slot_key}, pattern={pattern}, error={e}"
|
||||
)
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID,
|
||||
error_message=f"槽位 '{slot_key}' 的校验规则配置无效(非法正则)",
|
||||
)
|
||||
|
||||
def _validate_json_schema(
|
||||
self,
|
||||
slot_key: str,
|
||||
schema_str: str,
|
||||
value: Any,
|
||||
tenant_id: str | None,
|
||||
) -> ValidationResult:
|
||||
"""
|
||||
JSON Schema 校验
|
||||
|
||||
Args:
|
||||
slot_key: 槽位键名
|
||||
schema_str: JSON Schema 字符串
|
||||
value: 待校验的值
|
||||
tenant_id: 租户 ID
|
||||
|
||||
Returns:
|
||||
ValidationResult: 校验结果
|
||||
"""
|
||||
try:
|
||||
# 解析 JSON Schema
|
||||
schema = self._schema_cache.get(schema_str)
|
||||
if schema is None:
|
||||
schema = json.loads(schema_str)
|
||||
self._schema_cache[schema_str] = schema
|
||||
|
||||
# 执行校验
|
||||
jsonschema.validate(instance=value, schema=schema)
|
||||
return ValidationResult(ok=True, normalized_value=value)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
logger.warning(
|
||||
f"[SlotValidation] Invalid JSON schema: "
|
||||
f"tenant_id={tenant_id}, slot_key={slot_key}, error={e}"
|
||||
)
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID,
|
||||
error_message=f"槽位 '{slot_key}' 的校验规则配置无效(非法 JSON)",
|
||||
)
|
||||
|
||||
except JsonSchemaValidationError as e:
|
||||
logger.info(
|
||||
f"[SlotValidation] JSON schema mismatch: "
|
||||
f"tenant_id={tenant_id}, slot_key={slot_key}, error={e.message}"
|
||||
)
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_JSON_SCHEMA_MISMATCH,
|
||||
error_message=f"槽位 '{slot_key}' 的值不符合格式要求: {e.message}",
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"[SlotValidation] JSON schema validation error: "
|
||||
f"tenant_id={tenant_id}, slot_key={slot_key}, error={e}"
|
||||
)
|
||||
return ValidationResult(
|
||||
ok=False,
|
||||
error_code=SlotValidationErrorCode.SLOT_VALIDATION_RULE_INVALID,
|
||||
error_message=f"槽位 '{slot_key}' 的校验规则执行异常: {str(e)}",
|
||||
)
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
"""清除 JSON Schema 缓存"""
|
||||
self._schema_cache.clear()
|
||||
|
|
@ -0,0 +1,463 @@
|
|||
"""
|
||||
Scene Slot Bundle Service.
|
||||
[AC-SCENE-SLOT-01] 场景-槽位映射配置服务
|
||||
[AC-SCENE-SLOT-03] 支持缓存失效
|
||||
|
||||
职责:
|
||||
1. 场景槽位包的 CRUD 操作
|
||||
2. 槽位引用有效性校验
|
||||
3. 运行时场景槽位包查询
|
||||
4. 缓存失效管理
|
||||
"""
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.entities import (
|
||||
SceneSlotBundle,
|
||||
SceneSlotBundleCreate,
|
||||
SceneSlotBundleStatus,
|
||||
SceneSlotBundleUpdate,
|
||||
)
|
||||
from app.services.cache.scene_slot_bundle_cache import get_scene_slot_bundle_cache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SceneSlotBundleService:
|
||||
"""
|
||||
[AC-SCENE-SLOT-01] 场景槽位包服务
|
||||
"""
|
||||
|
||||
def __init__(self, session: AsyncSession):
|
||||
self._session = session
|
||||
|
||||
async def list_bundles(
|
||||
self,
|
||||
tenant_id: str,
|
||||
status: str | None = None,
|
||||
) -> list[SceneSlotBundle]:
|
||||
"""
|
||||
列出场景槽位包
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
status: 状态过滤(可选)
|
||||
|
||||
Returns:
|
||||
场景槽位包列表
|
||||
"""
|
||||
stmt = select(SceneSlotBundle).where(
|
||||
SceneSlotBundle.tenant_id == tenant_id
|
||||
)
|
||||
|
||||
if status:
|
||||
stmt = stmt.where(SceneSlotBundle.status == status)
|
||||
|
||||
stmt = stmt.order_by(SceneSlotBundle.created_at.desc())
|
||||
|
||||
result = await self._session.execute(stmt)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def get_bundle(
|
||||
self,
|
||||
tenant_id: str,
|
||||
bundle_id: str,
|
||||
) -> SceneSlotBundle | None:
|
||||
"""
|
||||
获取单个场景槽位包
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
bundle_id: 槽位包 ID
|
||||
|
||||
Returns:
|
||||
场景槽位包或 None
|
||||
"""
|
||||
try:
|
||||
stmt = select(SceneSlotBundle).where(
|
||||
SceneSlotBundle.tenant_id == tenant_id,
|
||||
SceneSlotBundle.id == uuid.UUID(bundle_id),
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
async def get_bundle_by_scene_key(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
) -> SceneSlotBundle | None:
|
||||
"""
|
||||
根据场景标识获取槽位包
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
|
||||
Returns:
|
||||
场景槽位包或 None
|
||||
"""
|
||||
stmt = select(SceneSlotBundle).where(
|
||||
SceneSlotBundle.tenant_id == tenant_id,
|
||||
SceneSlotBundle.scene_key == scene_key,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def get_active_bundle_by_scene(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
) -> SceneSlotBundle | None:
|
||||
"""
|
||||
获取活跃状态的场景槽位包(运行时使用)
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
|
||||
Returns:
|
||||
活跃状态的场景槽位包或 None
|
||||
"""
|
||||
stmt = select(SceneSlotBundle).where(
|
||||
SceneSlotBundle.tenant_id == tenant_id,
|
||||
SceneSlotBundle.scene_key == scene_key,
|
||||
SceneSlotBundle.status == SceneSlotBundleStatus.ACTIVE.value,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def create_bundle(
|
||||
self,
|
||||
tenant_id: str,
|
||||
bundle_create: SceneSlotBundleCreate,
|
||||
) -> SceneSlotBundle:
|
||||
"""
|
||||
创建场景槽位包
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
bundle_create: 创建请求数据
|
||||
|
||||
Returns:
|
||||
创建的场景槽位包
|
||||
|
||||
Raises:
|
||||
ValueError: 校验失败
|
||||
"""
|
||||
validation_errors = await self._validate_bundle_data(
|
||||
tenant_id=tenant_id,
|
||||
required_slots=bundle_create.required_slots,
|
||||
optional_slots=bundle_create.optional_slots,
|
||||
slot_priority=bundle_create.slot_priority,
|
||||
)
|
||||
|
||||
if validation_errors:
|
||||
raise ValueError("; ".join(validation_errors))
|
||||
|
||||
existing = await self.get_bundle_by_scene_key(
|
||||
tenant_id=tenant_id,
|
||||
scene_key=bundle_create.scene_key,
|
||||
)
|
||||
|
||||
if existing:
|
||||
raise ValueError(f"场景标识 '{bundle_create.scene_key}' 已存在")
|
||||
|
||||
bundle = SceneSlotBundle(
|
||||
tenant_id=tenant_id,
|
||||
scene_key=bundle_create.scene_key,
|
||||
scene_name=bundle_create.scene_name,
|
||||
description=bundle_create.description,
|
||||
required_slots=bundle_create.required_slots or [],
|
||||
optional_slots=bundle_create.optional_slots or [],
|
||||
slot_priority=bundle_create.slot_priority,
|
||||
completion_threshold=bundle_create.completion_threshold,
|
||||
ask_back_order=bundle_create.ask_back_order,
|
||||
status=bundle_create.status,
|
||||
)
|
||||
|
||||
self._session.add(bundle)
|
||||
await self._session.flush()
|
||||
|
||||
logger.info(
|
||||
f"[AC-SCENE-SLOT-01] Created scene slot bundle: "
|
||||
f"tenant={tenant_id}, scene_key={bundle.scene_key}, "
|
||||
f"required={len(bundle.required_slots)}, optional={len(bundle.optional_slots)}"
|
||||
)
|
||||
|
||||
return bundle
|
||||
|
||||
async def update_bundle(
|
||||
self,
|
||||
tenant_id: str,
|
||||
bundle_id: str,
|
||||
bundle_update: SceneSlotBundleUpdate,
|
||||
) -> SceneSlotBundle | None:
|
||||
"""
|
||||
更新场景槽位包
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
bundle_id: 槽位包 ID
|
||||
bundle_update: 更新请求数据
|
||||
|
||||
Returns:
|
||||
更新后的场景槽位包或 None
|
||||
|
||||
Raises:
|
||||
ValueError: 校验失败
|
||||
"""
|
||||
bundle = await self.get_bundle(tenant_id, bundle_id)
|
||||
|
||||
if not bundle:
|
||||
return None
|
||||
|
||||
required_slots = bundle_update.required_slots if bundle_update.required_slots is not None else bundle.required_slots
|
||||
optional_slots = bundle_update.optional_slots if bundle_update.optional_slots is not None else bundle.optional_slots
|
||||
slot_priority = bundle_update.slot_priority if bundle_update.slot_priority is not None else bundle.slot_priority
|
||||
|
||||
validation_errors = await self._validate_bundle_data(
|
||||
tenant_id=tenant_id,
|
||||
required_slots=required_slots,
|
||||
optional_slots=optional_slots,
|
||||
slot_priority=slot_priority,
|
||||
)
|
||||
|
||||
if validation_errors:
|
||||
raise ValueError("; ".join(validation_errors))
|
||||
|
||||
if bundle_update.scene_name is not None:
|
||||
bundle.scene_name = bundle_update.scene_name
|
||||
if bundle_update.description is not None:
|
||||
bundle.description = bundle_update.description
|
||||
if bundle_update.required_slots is not None:
|
||||
bundle.required_slots = bundle_update.required_slots
|
||||
if bundle_update.optional_slots is not None:
|
||||
bundle.optional_slots = bundle_update.optional_slots
|
||||
if bundle_update.slot_priority is not None:
|
||||
bundle.slot_priority = bundle_update.slot_priority
|
||||
if bundle_update.completion_threshold is not None:
|
||||
bundle.completion_threshold = bundle_update.completion_threshold
|
||||
if bundle_update.ask_back_order is not None:
|
||||
bundle.ask_back_order = bundle_update.ask_back_order
|
||||
if bundle_update.status is not None:
|
||||
bundle.status = bundle_update.status
|
||||
|
||||
bundle.version += 1
|
||||
|
||||
await self._session.flush()
|
||||
|
||||
# [AC-SCENE-SLOT-03] 使缓存失效
|
||||
await self._invalidate_cache(tenant_id, bundle.scene_key)
|
||||
|
||||
logger.info(
|
||||
f"[AC-SCENE-SLOT-01] Updated scene slot bundle: "
|
||||
f"tenant={tenant_id}, scene_key={bundle.scene_key}, version={bundle.version}"
|
||||
)
|
||||
|
||||
return bundle
|
||||
|
||||
async def delete_bundle(
|
||||
self,
|
||||
tenant_id: str,
|
||||
bundle_id: str,
|
||||
) -> bool:
|
||||
"""
|
||||
删除场景槽位包
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
bundle_id: 槽位包 ID
|
||||
|
||||
Returns:
|
||||
是否成功删除
|
||||
"""
|
||||
bundle = await self.get_bundle(tenant_id, bundle_id)
|
||||
|
||||
if not bundle:
|
||||
return False
|
||||
|
||||
scene_key = bundle.scene_key
|
||||
|
||||
await self._session.delete(bundle)
|
||||
await self._session.flush()
|
||||
|
||||
# [AC-SCENE-SLOT-03] 使缓存失效
|
||||
await self._invalidate_cache(tenant_id, scene_key)
|
||||
|
||||
logger.info(
|
||||
f"[AC-SCENE-SLOT-01] Deleted scene slot bundle: "
|
||||
f"tenant={tenant_id}, scene_key={scene_key}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def _validate_bundle_data(
|
||||
self,
|
||||
tenant_id: str,
|
||||
required_slots: list[str],
|
||||
optional_slots: list[str],
|
||||
slot_priority: list[str] | None,
|
||||
) -> list[str]:
|
||||
"""
|
||||
校验槽位包数据
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
required_slots: 必填槽位列表
|
||||
optional_slots: 可选槽位列表
|
||||
slot_priority: 优先级列表
|
||||
|
||||
Returns:
|
||||
错误信息列表(空列表表示校验通过)
|
||||
"""
|
||||
errors = []
|
||||
|
||||
required_set = set(required_slots or [])
|
||||
optional_set = set(optional_slots or [])
|
||||
|
||||
overlap = required_set & optional_set
|
||||
if overlap:
|
||||
errors.append(f"必填和可选槽位存在交叉: {list(overlap)}")
|
||||
|
||||
if slot_priority:
|
||||
priority_set = set(slot_priority)
|
||||
all_slots = required_set | optional_set
|
||||
|
||||
unknown_in_priority = priority_set - all_slots
|
||||
if unknown_in_priority:
|
||||
errors.append(f"优先级列表包含未定义的槽位: {list(unknown_in_priority)}")
|
||||
|
||||
valid_slots = await self._validate_slot_keys(
|
||||
tenant_id=tenant_id,
|
||||
slot_keys=list(required_set | optional_set),
|
||||
)
|
||||
|
||||
invalid_slots = (required_set | optional_set) - valid_slots
|
||||
if invalid_slots:
|
||||
errors.append(f"以下槽位不存在: {list(invalid_slots)}")
|
||||
|
||||
return errors
|
||||
|
||||
async def _validate_slot_keys(
|
||||
self,
|
||||
tenant_id: str,
|
||||
slot_keys: list[str],
|
||||
) -> set[str]:
|
||||
"""
|
||||
校验槽位键名是否存在
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
slot_keys: 待校验的槽位键名列表
|
||||
|
||||
Returns:
|
||||
有效的槽位键名集合
|
||||
"""
|
||||
if not slot_keys:
|
||||
return set()
|
||||
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
|
||||
slot_service = SlotDefinitionService(self._session)
|
||||
all_slots = await slot_service.list_slot_definitions(tenant_id)
|
||||
|
||||
valid_keys = {slot.slot_key for slot in all_slots}
|
||||
|
||||
return set(slot_keys) & valid_keys
|
||||
|
||||
async def get_bundle_with_slot_details(
|
||||
self,
|
||||
tenant_id: str,
|
||||
bundle_id: str,
|
||||
) -> dict[str, Any] | None:
|
||||
"""
|
||||
获取场景槽位包及其槽位详情
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
bundle_id: 槽位包 ID
|
||||
|
||||
Returns:
|
||||
包含槽位详情的字典或 None
|
||||
"""
|
||||
bundle = await self.get_bundle(tenant_id, bundle_id)
|
||||
|
||||
if not bundle:
|
||||
return None
|
||||
|
||||
from app.services.slot_definition_service import SlotDefinitionService
|
||||
|
||||
slot_service = SlotDefinitionService(self._session)
|
||||
all_slots = await slot_service.list_slot_definitions(tenant_id)
|
||||
|
||||
slot_map = {slot.slot_key: slot for slot in all_slots}
|
||||
|
||||
required_slot_details = []
|
||||
for slot_key in bundle.required_slots:
|
||||
if slot_key in slot_map:
|
||||
slot = slot_map[slot_key]
|
||||
required_slot_details.append({
|
||||
"slot_key": slot.slot_key,
|
||||
"type": slot.type,
|
||||
"required": slot.required,
|
||||
"ask_back_prompt": slot.ask_back_prompt,
|
||||
"linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
|
||||
})
|
||||
|
||||
optional_slot_details = []
|
||||
for slot_key in bundle.optional_slots:
|
||||
if slot_key in slot_map:
|
||||
slot = slot_map[slot_key]
|
||||
optional_slot_details.append({
|
||||
"slot_key": slot.slot_key,
|
||||
"type": slot.type,
|
||||
"required": slot.required,
|
||||
"ask_back_prompt": slot.ask_back_prompt,
|
||||
"linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
|
||||
})
|
||||
|
||||
return {
|
||||
"id": str(bundle.id),
|
||||
"tenant_id": str(bundle.tenant_id),
|
||||
"scene_key": bundle.scene_key,
|
||||
"scene_name": bundle.scene_name,
|
||||
"description": bundle.description,
|
||||
"required_slots": bundle.required_slots,
|
||||
"optional_slots": bundle.optional_slots,
|
||||
"required_slot_details": required_slot_details,
|
||||
"optional_slot_details": optional_slot_details,
|
||||
"slot_priority": bundle.slot_priority,
|
||||
"completion_threshold": bundle.completion_threshold,
|
||||
"ask_back_order": bundle.ask_back_order,
|
||||
"status": bundle.status,
|
||||
"version": bundle.version,
|
||||
"created_at": bundle.created_at.isoformat() if bundle.created_at else None,
|
||||
"updated_at": bundle.updated_at.isoformat() if bundle.updated_at else None,
|
||||
}
|
||||
|
||||
async def _invalidate_cache(
|
||||
self,
|
||||
tenant_id: str,
|
||||
scene_key: str,
|
||||
) -> None:
|
||||
"""
|
||||
[AC-SCENE-SLOT-03] 使缓存失效
|
||||
|
||||
Args:
|
||||
tenant_id: 租户 ID
|
||||
scene_key: 场景标识
|
||||
"""
|
||||
try:
|
||||
cache = get_scene_slot_bundle_cache()
|
||||
await cache.invalidate_on_update(tenant_id, scene_key)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-SCENE-SLOT-03] Failed to invalidate cache: {e}"
|
||||
)
|
||||
Loading…
Reference in New Issue