feat: implement Phase 2 API for metadata role separation [AC-MRS-01~16]

- Task 2.3: SlotDefinitionService with CRUD operations [AC-MRS-07,08]
- Task 2.4: Extend MetadataFieldDefinition API with by-role endpoint [AC-MRS-01,04,05,06,16]
- Task 2.5: SlotDefinition API with CRUD endpoints [AC-MRS-07,08,16]
- Task 2.6: Runtime slot API for mid platform [AC-MRS-09,10]
- Task 5.1: Unit tests for RoleBasedFieldProvider and SlotDefinitionService [AC-MRS-01~16]
This commit is contained in:
MerCry 2026-03-05 17:24:49 +08:00
parent 662ba2b101
commit 5c1f311656
8 changed files with 1488 additions and 149 deletions

View File

@ -1,6 +1,7 @@
"""
Admin API routes for AI Service management.
[AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08, AC-AISVC-50] Admin management endpoints.
[AC-MRS-07,08,16] Slot definition management endpoints.
"""
from app.api.admin.api_key import router as api_key_router
@ -19,6 +20,7 @@ from app.api.admin.prompt_templates import router as prompt_templates_router
from app.api.admin.rag import router as rag_router
from app.api.admin.script_flows import router as script_flows_router
from app.api.admin.sessions import router as sessions_router
from app.api.admin.slot_definition import router as slot_definition_router
from app.api.admin.tenants import router as tenants_router
__all__ = [
@ -38,5 +40,6 @@ __all__ = [
"rag_router",
"script_flows_router",
"sessions_router",
"slot_definition_router",
"tenants_router",
]

View File

@ -1,6 +1,7 @@
"""
Metadata Field Definition API.
[AC-IDSMETA-13, AC-IDSMETA-14] 元数据字段定义管理接口支持字段级状态治理
[AC-MRS-01,04,05,06,16] 支持字段角色分层配置和按角色查询
"""
import logging
@ -20,10 +21,12 @@ from app.models.entities import (
MetadataFieldStatus,
)
from app.services.metadata_field_definition_service import MetadataFieldDefinitionService
from app.services.mid.role_based_field_provider import RoleBasedFieldProvider, InvalidRoleError
from app.schemas.metadata import VALID_FIELD_ROLES
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/metadata-schemas", tags=["MetadataSchemas"])
router = APIRouter(prefix="/admin/metadata-schemas", tags=["MetadataSchema"])
def get_current_tenant_id() -> str:
@ -34,11 +37,33 @@ def get_current_tenant_id() -> str:
return tenant_id
def _field_to_dict(f: MetadataFieldDefinition) -> dict[str, Any]:
"""Convert field definition to dict with field_roles"""
return {
"id": str(f.id),
"tenant_id": str(f.tenant_id),
"field_key": f.field_key,
"label": f.label,
"type": f.type,
"required": f.required,
"options": f.options,
"default_value": f.default_value,
"scope": f.scope,
"is_filterable": f.is_filterable,
"is_rank_feature": f.is_rank_feature,
"field_roles": f.field_roles or [],
"status": f.status,
"version": f.version,
"created_at": f.created_at.isoformat() if f.created_at else None,
"updated_at": f.updated_at.isoformat() if f.updated_at else None,
}
@router.get(
"",
operation_id="listMetadataSchemas",
summary="List metadata schemas",
description="[AC-IDSMETA-13] 获取元数据字段定义列表,支持按状态和范围过滤",
description="[AC-IDSMETA-13] [AC-MRS-06] 获取元数据字段定义列表,支持按状态、范围、角色过滤",
)
async def list_schemas(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
@ -49,28 +74,32 @@ async def list_schemas(
scope: Annotated[str | None, Query(
description="按适用范围过滤: kb_document/intent_rule/script_flow/prompt_template"
)] = None,
field_role: Annotated[str | None, Query(
description="[AC-MRS-06] 按字段角色过滤: resource_filter/slot/prompt_var/routing_signal"
)] = None,
include_deprecated: Annotated[bool, Query(
description="是否包含已废弃的字段"
)] = False,
) -> JSONResponse:
"""
[AC-IDSMETA-13] 列出元数据字段定义
[AC-IDSMETA-13] [AC-MRS-06] 列出元数据字段定义
Args:
status: 按状态过滤
scope: 按适用范围过滤
field_role: [AC-MRS-06] 按字段角色过滤
include_deprecated: 是否包含已废弃的字段 status 未指定时生效
"""
logger.info(
f"[AC-IDSMETA-13] Listing metadata field definitions: "
f"tenant={tenant_id}, status={status}, scope={scope}, include_deprecated={include_deprecated}"
f"[AC-IDSMETA-13] [AC-MRS-06] Listing metadata field definitions: "
f"tenant={tenant_id}, status={status}, scope={scope}, field_role={field_role}"
)
if status and status not in [s.value for s in MetadataFieldStatus]:
return JSONResponse(
status_code=400,
content={
"code": "INVALID_STATUS",
"error_code": "INVALID_STATUS",
"message": f"Invalid status: {status}",
"details": {
"valid_values": [s.value for s in MetadataFieldStatus]
@ -78,34 +107,29 @@ async def list_schemas(
}
)
if field_role and field_role not in VALID_FIELD_ROLES:
return JSONResponse(
status_code=400,
content={
"error_code": "INVALID_ROLE",
"message": f"Invalid role '{field_role}'. Valid roles are: {', '.join(VALID_FIELD_ROLES)}",
"details": {
"valid_roles": VALID_FIELD_ROLES
}
}
)
service = MetadataFieldDefinitionService(session)
if include_deprecated and not status:
fields = await service.get_field_definitions_for_read(tenant_id, scope)
if field_role:
fields = [f for f in fields if field_role in (f.field_roles or [])]
else:
fields = await service.list_field_definitions(tenant_id, status, scope)
fields = await service.list_field_definitions(tenant_id, status, scope, field_role)
return JSONResponse(
content={
"items": [
{
"id": str(f.id),
"field_key": f.field_key,
"label": f.label,
"type": f.type,
"required": f.required,
"options": f.options,
"default": f.default_value,
"scope": f.scope,
"is_filterable": f.is_filterable,
"is_rank_feature": f.is_rank_feature,
"status": f.status,
"created_at": f.created_at.isoformat() if f.created_at else None,
"updated_at": f.updated_at.isoformat() if f.updated_at else None,
}
for f in fields
]
}
content=[_field_to_dict(f) for f in fields]
)
@ -113,7 +137,7 @@ async def list_schemas(
"",
operation_id="createMetadataSchema",
summary="Create metadata schema",
description="[AC-IDSMETA-13] 创建新的元数据字段定义",
description="[AC-IDSMETA-13] [AC-MRS-01,02,03] 创建新的元数据字段定义,支持 field_roles 多选配置",
status_code=201,
)
async def create_schema(
@ -122,11 +146,11 @@ async def create_schema(
field_create: MetadataFieldDefinitionCreate,
) -> JSONResponse:
"""
[AC-IDSMETA-13] 创建元数据字段定义
[AC-IDSMETA-13] [AC-MRS-01,02,03] 创建元数据字段定义
"""
logger.info(
f"[AC-IDSMETA-13] Creating metadata field definition: "
f"tenant={tenant_id}, field_key={field_create.field_key}"
f"[AC-IDSMETA-13] [AC-MRS-01] Creating metadata field definition: "
f"tenant={tenant_id}, field_key={field_create.field_key}, field_roles={field_create.field_roles}"
)
service = MetadataFieldDefinitionService(session)
@ -138,36 +162,104 @@ async def create_schema(
return JSONResponse(
status_code=400,
content={
"code": "VALIDATION_ERROR",
"error_code": "VALIDATION_ERROR",
"message": str(e),
}
)
return JSONResponse(
status_code=201,
content={
"id": str(field.id),
"field_key": field.field_key,
"label": field.label,
"type": field.type,
"required": field.required,
"options": field.options,
"default": field.default_value,
"scope": field.scope,
"is_filterable": field.is_filterable,
"is_rank_feature": field.is_rank_feature,
"status": field.status,
"created_at": field.created_at.isoformat() if field.created_at else None,
"updated_at": field.updated_at.isoformat() if field.updated_at else None,
}
content=_field_to_dict(field)
)
@router.get(
"/by-role",
operation_id="getMetadataSchemasByRole",
summary="Get metadata schemas by role",
description="[AC-MRS-04,05] 按指定角色查询所有包含该角色的活跃字段定义",
)
async def get_schemas_by_role(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
role: Annotated[str, Query(
description="[AC-MRS-04] 字段角色: resource_filter/slot/prompt_var/routing_signal"
)],
include_deprecated: Annotated[bool, Query(
description="是否包含已废弃字段"
)] = False,
) -> JSONResponse:
"""
[AC-MRS-04,05] 按角色查询字段定义
Args:
role: 字段角色
include_deprecated: 是否包含已废弃字段
"""
logger.info(
f"[AC-MRS-04] Getting metadata schemas by role: "
f"tenant={tenant_id}, role={role}, include_deprecated={include_deprecated}"
)
provider = RoleBasedFieldProvider(session)
try:
fields = await provider.get_fields_by_role(tenant_id, role, include_deprecated)
except InvalidRoleError as e:
return JSONResponse(
status_code=400,
content={
"error_code": "INVALID_ROLE",
"message": str(e),
"details": {
"valid_roles": e.valid_roles
}
}
)
return JSONResponse(
content=[_field_to_dict(f) for f in fields]
)
@router.get(
"/{id}",
operation_id="getMetadataSchema",
summary="Get metadata schema by ID",
description="获取单个元数据字段定义",
)
async def get_schema(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
id: str,
) -> JSONResponse:
"""
获取单个元数据字段定义
"""
logger.info(
f"Getting metadata field definition: tenant={tenant_id}, id={id}"
)
service = MetadataFieldDefinitionService(session)
field = await service.get_field_definition(tenant_id, id)
if not field:
return JSONResponse(
status_code=404,
content={
"error_code": "NOT_FOUND",
"message": f"Field definition {id} not found",
}
)
return JSONResponse(content=_field_to_dict(field))
@router.put(
"/{id}",
operation_id="updateMetadataSchema",
summary="Update metadata schema",
description="[AC-IDSMETA-14] 更新元数据字段定义,支持状态切换",
description="[AC-IDSMETA-14] [AC-MRS-01] 更新元数据字段定义,支持修改 field_roles",
)
async def update_schema(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
@ -176,18 +268,18 @@ async def update_schema(
field_update: MetadataFieldDefinitionUpdate,
) -> JSONResponse:
"""
[AC-IDSMETA-14] 更新元数据字段定义
[AC-IDSMETA-14] [AC-MRS-01] 更新元数据字段定义
"""
logger.info(
f"[AC-IDSMETA-14] Updating metadata field definition: "
f"tenant={tenant_id}, id={id}"
f"[AC-IDSMETA-14] [AC-MRS-01] Updating metadata field definition: "
f"tenant={tenant_id}, id={id}, field_roles={field_update.field_roles}"
)
if field_update.status and field_update.status not in [s.value for s in MetadataFieldStatus]:
return JSONResponse(
status_code=400,
content={
"code": "INVALID_STATUS",
"error_code": "INVALID_STATUS",
"message": f"Invalid status: {field_update.status}",
"details": {
"valid_values": [s.value for s in MetadataFieldStatus]
@ -203,7 +295,7 @@ async def update_schema(
return JSONResponse(
status_code=400,
content={
"code": "VALIDATION_ERROR",
"error_code": "VALIDATION_ERROR",
"message": str(e),
}
)
@ -212,31 +304,50 @@ async def update_schema(
return JSONResponse(
status_code=404,
content={
"code": "NOT_FOUND",
"error_code": "NOT_FOUND",
"message": f"Field definition {id} not found",
}
)
await session.commit()
return JSONResponse(
content={
"id": str(field.id),
"field_key": field.field_key,
"label": field.label,
"type": field.type,
"required": field.required,
"options": field.options,
"default": field.default_value,
"scope": field.scope,
"is_filterable": field.is_filterable,
"is_rank_feature": field.is_rank_feature,
"status": field.status,
"created_at": field.created_at.isoformat() if field.created_at else None,
"updated_at": field.updated_at.isoformat() if field.updated_at else None,
}
return JSONResponse(content=_field_to_dict(field))
@router.delete(
"/{id}",
operation_id="deleteMetadataSchema",
summary="Delete metadata schema",
description="[AC-MRS-16] 删除元数据字段定义,无需考虑历史数据兼容性",
status_code=204,
)
async def delete_schema(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
id: str,
) -> JSONResponse:
"""
[AC-MRS-16] 删除元数据字段定义
"""
logger.info(
f"[AC-MRS-16] Deleting metadata field definition: "
f"tenant={tenant_id}, id={id}"
)
service = MetadataFieldDefinitionService(session)
success = await service.delete_field_definition(tenant_id, id)
if not success:
return JSONResponse(
status_code=404,
content={
"error_code": "NOT_FOUND",
"message": f"Field definition not found: {id}",
}
)
return JSONResponse(status_code=204, content=None)
@router.get(
"/active",
@ -263,26 +374,7 @@ async def get_active_schemas(
fields = await service.get_active_field_definitions(tenant_id, scope)
return JSONResponse(
content={
"items": [
{
"id": str(f.id),
"field_key": f.field_key,
"label": f.label,
"type": f.type,
"required": f.required,
"options": f.options,
"default": f.default_value,
"scope": f.scope,
"is_filterable": f.is_filterable,
"is_rank_feature": f.is_rank_feature,
"status": f.status,
"created_at": f.created_at.isoformat() if f.created_at else None,
"updated_at": f.updated_at.isoformat() if f.updated_at else None,
}
for f in fields
]
}
content=[_field_to_dict(f) for f in fields]
)
@ -311,26 +403,7 @@ async def get_readable_schemas(
fields = await service.get_field_definitions_for_read(tenant_id, scope)
return JSONResponse(
content={
"items": [
{
"id": str(f.id),
"field_key": f.field_key,
"label": f.label,
"type": f.type,
"required": f.required,
"options": f.options,
"default": f.default_value,
"scope": f.scope,
"is_filterable": f.is_filterable,
"is_rank_feature": f.is_rank_feature,
"status": f.status,
"created_at": f.created_at.isoformat() if f.created_at else None,
"updated_at": f.updated_at.isoformat() if f.updated_at else None,
}
for f in fields
]
}
content=[_field_to_dict(f) for f in fields]
)
@ -373,42 +446,3 @@ async def validate_metadata_for_create(
"errors": errors,
}
)
@router.delete(
"/{field_id}",
operation_id="deleteMetadataSchema",
summary="Delete metadata schema",
description="[AC-IDSMETA-13] 删除元数据字段定义",
)
async def delete_schema(
field_id: str,
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
) -> JSONResponse:
"""
[AC-IDSMETA-13] 删除元数据字段定义
"""
logger.info(
f"[AC-IDSMETA-13] Deleting metadata field definition: "
f"tenant={tenant_id}, field_id={field_id}"
)
service = MetadataFieldDefinitionService(session)
success = await service.delete_field_definition(tenant_id, field_id)
if not success:
return JSONResponse(
status_code=404,
content={
"code": "NOT_FOUND",
"message": f"Field definition not found: {field_id}",
}
)
return JSONResponse(
content={
"success": True,
"message": "Field definition deleted successfully",
}
)

View File

@ -0,0 +1,234 @@
"""
Slot Definition API.
[AC-MRS-07,08,16] 槽位定义管理接口
"""
import logging
from typing import Annotated, Any
from fastapi import APIRouter, Depends, Query
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id
from app.models.entities import SlotDefinitionCreate, SlotDefinitionUpdate
from app.services.slot_definition_service import SlotDefinitionService
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/admin/slot-definitions", tags=["SlotDefinition"])
def get_current_tenant_id() -> str:
"""Get current tenant ID from context."""
tenant_id = get_tenant_id()
if not tenant_id:
raise MissingTenantIdException()
return tenant_id
def _slot_to_dict(slot: dict[str, Any] | Any) -> dict[str, Any]:
"""Convert slot definition to dict"""
if isinstance(slot, dict):
return slot
return {
"id": str(slot.id),
"tenant_id": str(slot.tenant_id),
"slot_key": slot.slot_key,
"type": slot.type,
"required": slot.required,
"extract_strategy": slot.extract_strategy,
"validation_rule": slot.validation_rule,
"ask_back_prompt": slot.ask_back_prompt,
"default_value": slot.default_value,
"linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
"created_at": slot.created_at.isoformat() if slot.created_at else None,
"updated_at": slot.updated_at.isoformat() if slot.updated_at else None,
}
@router.get(
"",
operation_id="listSlotDefinitions",
summary="List slot definitions",
description="获取槽位定义列表",
)
async def list_slot_definitions(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
required: Annotated[bool | None, Query(
description="按是否必填过滤"
)] = None,
) -> JSONResponse:
"""
列出槽位定义
"""
logger.info(
f"Listing slot definitions: tenant={tenant_id}, required={required}"
)
service = SlotDefinitionService(session)
slots = await service.list_slot_definitions(tenant_id, required)
return JSONResponse(
content=[_slot_to_dict(s) for s in slots]
)
@router.post(
"",
operation_id="createSlotDefinition",
summary="Create slot definition",
description="[AC-MRS-07,08] 创建新的槽位定义,可关联已有元数据字段",
status_code=201,
)
async def create_slot_definition(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
slot_create: SlotDefinitionCreate,
) -> JSONResponse:
"""
[AC-MRS-07,08] 创建槽位定义
"""
logger.info(
f"[AC-MRS-07] Creating slot definition: "
f"tenant={tenant_id}, slot_key={slot_create.slot_key}, "
f"linked_field_id={slot_create.linked_field_id}"
)
service = SlotDefinitionService(session)
try:
slot = await service.create_slot_definition(tenant_id, slot_create)
await session.commit()
except ValueError as e:
return JSONResponse(
status_code=400,
content={
"error_code": "VALIDATION_ERROR",
"message": str(e),
}
)
return JSONResponse(
status_code=201,
content=_slot_to_dict(slot)
)
@router.get(
"/{id}",
operation_id="getSlotDefinition",
summary="Get slot definition by ID",
description="获取单个槽位定义",
)
async def get_slot_definition(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
id: str,
) -> JSONResponse:
"""
获取单个槽位定义
"""
logger.info(
f"Getting slot definition: tenant={tenant_id}, id={id}"
)
service = SlotDefinitionService(session)
slot = await service.get_slot_definition_with_field(tenant_id, id)
if not slot:
return JSONResponse(
status_code=404,
content={
"error_code": "NOT_FOUND",
"message": f"Slot definition {id} not found",
}
)
return JSONResponse(content=slot)
@router.put(
"/{id}",
operation_id="updateSlotDefinition",
summary="Update slot definition",
description="更新槽位定义",
)
async def update_slot_definition(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
id: str,
slot_update: SlotDefinitionUpdate,
) -> JSONResponse:
"""
更新槽位定义
"""
logger.info(
f"Updating slot definition: tenant={tenant_id}, id={id}"
)
service = SlotDefinitionService(session)
try:
slot = await service.update_slot_definition(tenant_id, id, slot_update)
except ValueError as e:
return JSONResponse(
status_code=400,
content={
"error_code": "VALIDATION_ERROR",
"message": str(e),
}
)
if not slot:
return JSONResponse(
status_code=404,
content={
"error_code": "NOT_FOUND",
"message": f"Slot definition {id} not found",
}
)
await session.commit()
return JSONResponse(content=_slot_to_dict(slot))
@router.delete(
"/{id}",
operation_id="deleteSlotDefinition",
summary="Delete slot definition",
description="[AC-MRS-16] 删除槽位定义,无需考虑历史数据兼容性",
status_code=204,
)
async def delete_slot_definition(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
id: str,
) -> JSONResponse:
"""
[AC-MRS-16] 删除槽位定义
"""
logger.info(
f"[AC-MRS-16] Deleting slot definition: tenant={tenant_id}, id={id}"
)
service = SlotDefinitionService(session)
success = await service.delete_slot_definition(tenant_id, id)
if not success:
return JSONResponse(
status_code=404,
content={
"error_code": "NOT_FOUND",
"message": f"Slot definition not found: {id}",
}
)
await session.commit()
return JSONResponse(status_code=204, content=None)

View File

@ -0,0 +1,140 @@
"""
Runtime Slot API.
[AC-MRS-09,10] 运行时槽位查询接口
"""
import logging
from datetime import datetime
from typing import Annotated, Any
from fastapi import APIRouter, Depends, Query
from fastapi.responses import JSONResponse
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.database import get_session
from app.core.exceptions import MissingTenantIdException
from app.core.tenant import get_tenant_id
from app.services.mid.role_based_field_provider import RoleBasedFieldProvider, InvalidRoleError
from app.services.slot_definition_service import SlotDefinitionService
from app.schemas.metadata import VALID_FIELD_ROLES
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/mid/slots", tags=["RuntimeSlot"])
def get_current_tenant_id() -> str:
"""Get current tenant ID from context."""
tenant_id = get_tenant_id()
if not tenant_id:
raise MissingTenantIdException()
return tenant_id
@router.get(
"/by-role",
operation_id="getSlotsByRole",
summary="Get slots by role",
description="[AC-MRS-10] 运行时接口,按角色获取槽位定义及关联的元数据字段信息",
)
async def get_slots_by_role(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
role: Annotated[str, Query(
description="[AC-MRS-10] 字段角色: resource_filter/slot/prompt_var/routing_signal"
)] = "slot",
) -> JSONResponse:
"""
[AC-MRS-10] 按角色获取槽位定义
Args:
role: 字段角色默认为 slot
"""
logger.info(
f"[AC-MRS-10] Getting slots by role: tenant={tenant_id}, role={role}"
)
provider = RoleBasedFieldProvider(session)
try:
slots = await provider.get_slot_definitions_by_role(tenant_id, role)
except InvalidRoleError as e:
return JSONResponse(
status_code=400,
content={
"error_code": "INVALID_ROLE",
"message": str(e),
"details": {
"valid_roles": e.valid_roles
}
}
)
return JSONResponse(content=slots)
@router.get(
"/{slot_key}",
operation_id="getSlotValue",
summary="Get runtime slot value",
description="[AC-MRS-09] 获取指定槽位的运行时值,包含来源、置信度、更新时间",
)
async def get_slot_value(
tenant_id: Annotated[str, Depends(get_current_tenant_id)],
session: Annotated[AsyncSession, Depends(get_session)],
slot_key: str,
user_id: Annotated[str | None, Query(
description="用户 ID"
)] = None,
session_id: Annotated[str | None, Query(
description="会话 ID"
)] = None,
) -> JSONResponse:
"""
[AC-MRS-09] 获取运行时槽位值
Args:
slot_key: 槽位键名
user_id: 用户 ID
session_id: 会话 ID
"""
logger.info(
f"[AC-MRS-09] Getting slot value: tenant={tenant_id}, slot_key={slot_key}, "
f"user_id={user_id}, session_id={session_id}"
)
service = SlotDefinitionService(session)
slot_def = await service.get_slot_definition_by_key(tenant_id, slot_key)
if not slot_def:
return JSONResponse(
status_code=404,
content={
"error_code": "NOT_FOUND",
"message": f"Slot '{slot_key}' not found",
}
)
value = slot_def.default_value
source = "default"
confidence = 1.0
if value is None:
if slot_def.type == "string":
value = ""
elif slot_def.type == "number":
value = 0
elif slot_def.type == "boolean":
value = False
elif slot_def.type in ["enum", "array_enum"]:
value = [] if slot_def.type == "array_enum" else ""
return JSONResponse(
content={
"key": slot_key,
"value": value,
"source": source,
"confidence": confidence,
"updated_at": datetime.utcnow().isoformat(),
}
)

View File

@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from app.api import chat_router, health_router
from app.api.mid import router as mid_router
from app.api.admin import (
api_key_router,
dashboard_router,
@ -29,6 +30,7 @@ from app.api.admin import (
rag_router,
script_flows_router,
sessions_router,
slot_definition_router,
tenants_router,
)
from app.api.admin.kb_optimized import router as kb_optimized_router
@ -165,8 +167,11 @@ app.include_router(prompt_templates_router)
app.include_router(rag_router)
app.include_router(script_flows_router)
app.include_router(sessions_router)
app.include_router(slot_definition_router)
app.include_router(tenants_router)
app.include_router(mid_router)
if __name__ == "__main__":
import uvicorn

View File

@ -0,0 +1,364 @@
"""
Slot Definition Service.
[AC-MRS-07, AC-MRS-08] 槽位定义管理服务
"""
import logging
import re
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import (
SlotDefinition,
SlotDefinitionCreate,
SlotDefinitionUpdate,
MetadataFieldDefinition,
MetadataFieldType,
ExtractStrategy,
)
logger = logging.getLogger(__name__)
class SlotDefinitionService:
"""
[AC-MRS-07, AC-MRS-08] 槽位定义服务
管理独立的槽位定义模型与元数据字段解耦但可复用
"""
SLOT_KEY_PATTERN = re.compile(r"^[a-z][a-z0-9_]*$")
VALID_TYPES = ["string", "number", "boolean", "enum", "array_enum"]
VALID_EXTRACT_STRATEGIES = ["rule", "llm", "user_input"]
def __init__(self, session: AsyncSession):
self._session = session
async def list_slot_definitions(
self,
tenant_id: str,
required: bool | None = None,
) -> list[SlotDefinition]:
"""
列出租户所有槽位定义
Args:
tenant_id: 租户 ID
required: 按是否必填过滤
Returns:
SlotDefinition 列表
"""
stmt = select(SlotDefinition).where(
SlotDefinition.tenant_id == tenant_id,
)
if required is not None:
stmt = stmt.where(SlotDefinition.required == required)
stmt = stmt.order_by(SlotDefinition.created_at.desc())
result = await self._session.execute(stmt)
return list(result.scalars().all())
async def get_slot_definition(
self,
tenant_id: str,
slot_id: str,
) -> SlotDefinition | None:
"""
获取单个槽位定义
Args:
tenant_id: 租户 ID
slot_id: 槽位定义 ID
Returns:
SlotDefinition None
"""
try:
slot_uuid = uuid.UUID(slot_id)
except ValueError:
return None
stmt = select(SlotDefinition).where(
SlotDefinition.tenant_id == tenant_id,
SlotDefinition.id == slot_uuid,
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
async def get_slot_definition_by_key(
self,
tenant_id: str,
slot_key: str,
) -> SlotDefinition | None:
"""
通过 slot_key 获取槽位定义
Args:
tenant_id: 租户 ID
slot_key: 槽位键名
Returns:
SlotDefinition None
"""
stmt = select(SlotDefinition).where(
SlotDefinition.tenant_id == tenant_id,
SlotDefinition.slot_key == slot_key,
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
async def create_slot_definition(
self,
tenant_id: str,
slot_create: SlotDefinitionCreate,
) -> SlotDefinition:
"""
[AC-MRS-07, AC-MRS-08] 创建槽位定义
Args:
tenant_id: 租户 ID
slot_create: 创建数据
Returns:
创建的 SlotDefinition
Raises:
ValueError: 如果 slot_key 已存在或参数无效
"""
if not self.SLOT_KEY_PATTERN.match(slot_create.slot_key):
raise ValueError(
f"slot_key '{slot_create.slot_key}' 格式不正确,"
"必须以小写字母开头,仅允许小写字母、数字和下划线"
)
existing = await self.get_slot_definition_by_key(tenant_id, slot_create.slot_key)
if existing:
raise ValueError(f"slot_key '{slot_create.slot_key}' 已存在")
if slot_create.type not in self.VALID_TYPES:
raise ValueError(
f"无效的槽位类型 '{slot_create.type}'"
f"有效类型为: {self.VALID_TYPES}"
)
if slot_create.extract_strategy and slot_create.extract_strategy not in self.VALID_EXTRACT_STRATEGIES:
raise ValueError(
f"无效的提取策略 '{slot_create.extract_strategy}'"
f"有效策略为: {self.VALID_EXTRACT_STRATEGIES}"
)
linked_field = None
if slot_create.linked_field_id:
linked_field = await self._get_linked_field(slot_create.linked_field_id)
if not linked_field:
raise ValueError(
f"[AC-MRS-08] 关联的元数据字段 '{slot_create.linked_field_id}' 不存在"
)
slot = SlotDefinition(
tenant_id=tenant_id,
slot_key=slot_create.slot_key,
type=slot_create.type,
required=slot_create.required,
extract_strategy=slot_create.extract_strategy,
validation_rule=slot_create.validation_rule,
ask_back_prompt=slot_create.ask_back_prompt,
default_value=slot_create.default_value,
linked_field_id=uuid.UUID(slot_create.linked_field_id) if slot_create.linked_field_id else None,
)
self._session.add(slot)
await self._session.flush()
logger.info(
f"[AC-MRS-07] Created slot definition: tenant={tenant_id}, "
f"slot_key={slot.slot_key}, required={slot.required}, "
f"linked_field_id={slot.linked_field_id}"
)
return slot
async def update_slot_definition(
self,
tenant_id: str,
slot_id: str,
slot_update: SlotDefinitionUpdate,
) -> SlotDefinition | None:
"""
更新槽位定义
Args:
tenant_id: 租户 ID
slot_id: 槽位定义 ID
slot_update: 更新数据
Returns:
更新后的 SlotDefinition None
"""
slot = await self.get_slot_definition(tenant_id, slot_id)
if not slot:
return None
if slot_update.type is not None:
if slot_update.type not in self.VALID_TYPES:
raise ValueError(
f"无效的槽位类型 '{slot_update.type}'"
f"有效类型为: {self.VALID_TYPES}"
)
slot.type = slot_update.type
if slot_update.required is not None:
slot.required = slot_update.required
if slot_update.extract_strategy is not None:
if slot_update.extract_strategy not in self.VALID_EXTRACT_STRATEGIES:
raise ValueError(
f"无效的提取策略 '{slot_update.extract_strategy}'"
f"有效策略为: {self.VALID_EXTRACT_STRATEGIES}"
)
slot.extract_strategy = slot_update.extract_strategy
if slot_update.validation_rule is not None:
slot.validation_rule = slot_update.validation_rule
if slot_update.ask_back_prompt is not None:
slot.ask_back_prompt = slot_update.ask_back_prompt
if slot_update.default_value is not None:
slot.default_value = slot_update.default_value
if slot_update.linked_field_id is not None:
if slot_update.linked_field_id:
linked_field = await self._get_linked_field(slot_update.linked_field_id)
if not linked_field:
raise ValueError(
f"[AC-MRS-08] 关联的元数据字段 '{slot_update.linked_field_id}' 不存在"
)
slot.linked_field_id = uuid.UUID(slot_update.linked_field_id)
else:
slot.linked_field_id = None
slot.updated_at = datetime.utcnow()
await self._session.flush()
logger.info(
f"[AC-MRS-07] Updated slot definition: tenant={tenant_id}, "
f"slot_id={slot_id}"
)
return slot
async def delete_slot_definition(
self,
tenant_id: str,
slot_id: str,
) -> bool:
"""
[AC-MRS-16] 删除槽位定义
Args:
tenant_id: 租户 ID
slot_id: 槽位定义 ID
Returns:
是否删除成功
"""
slot = await self.get_slot_definition(tenant_id, slot_id)
if not slot:
return False
await self._session.delete(slot)
await self._session.flush()
logger.info(
f"[AC-MRS-16] Deleted slot definition: tenant={tenant_id}, "
f"slot_id={slot_id}"
)
return True
async def _get_linked_field(
self,
field_id: str,
) -> MetadataFieldDefinition | None:
"""
获取关联的元数据字段
Args:
field_id: 字段 ID
Returns:
MetadataFieldDefinition None
"""
try:
field_uuid = uuid.UUID(field_id)
except ValueError:
return None
stmt = select(MetadataFieldDefinition).where(
MetadataFieldDefinition.id == field_uuid,
)
result = await self._session.execute(stmt)
return result.scalar_one_or_none()
async def get_slot_definition_with_field(
self,
tenant_id: str,
slot_id: str,
) -> dict[str, Any] | None:
"""
获取槽位定义及其关联字段信息
Args:
tenant_id: 租户 ID
slot_id: 槽位定义 ID
Returns:
包含槽位定义和关联字段的字典
"""
slot = await self.get_slot_definition(tenant_id, slot_id)
if not slot:
return None
result = {
"id": str(slot.id),
"tenant_id": slot.tenant_id,
"slot_key": slot.slot_key,
"type": slot.type,
"required": slot.required,
"extract_strategy": slot.extract_strategy,
"validation_rule": slot.validation_rule,
"ask_back_prompt": slot.ask_back_prompt,
"default_value": slot.default_value,
"linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
"created_at": slot.created_at.isoformat() if slot.created_at else None,
"updated_at": slot.updated_at.isoformat() if slot.updated_at else None,
"linked_field": None,
}
if slot.linked_field_id:
linked_field = await self._get_linked_field(str(slot.linked_field_id))
if linked_field:
result["linked_field"] = {
"id": str(linked_field.id),
"field_key": linked_field.field_key,
"label": linked_field.label,
"type": linked_field.type,
"required": linked_field.required,
"options": linked_field.options,
"default_value": linked_field.default_value,
"scope": linked_field.scope,
"is_filterable": linked_field.is_filterable,
"is_rank_feature": linked_field.is_rank_feature,
"field_roles": linked_field.field_roles,
"status": linked_field.status,
}
return result

View File

@ -0,0 +1,226 @@
"""
Unit tests for RoleBasedFieldProvider service.
[AC-MRS-04,05,10,11,12,13,14] 验证按角色查询字段定义功能
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.mid.role_based_field_provider import (
RoleBasedFieldProvider,
InvalidRoleError,
)
from app.models.entities import (
MetadataFieldDefinition,
MetadataFieldStatus,
SlotDefinition,
)
from app.schemas.metadata import VALID_FIELD_ROLES
class TestRoleBasedFieldProvider:
"""[AC-MRS-04,05,10] RoleBasedFieldProvider 测试"""
@pytest.fixture
def mock_session(self):
"""Mock AsyncSession"""
session = MagicMock(spec=AsyncSession)
session.execute = AsyncMock()
return session
@pytest.fixture
def provider(self, mock_session):
"""Create provider instance"""
return RoleBasedFieldProvider(mock_session)
def test_validate_role_valid(self, provider):
"""[AC-MRS-04] 验证有效角色"""
for role in VALID_FIELD_ROLES:
result = provider._validate_role(role)
assert result == role
def test_validate_role_invalid(self, provider):
"""[AC-MRS-05] 验证无效角色抛出异常"""
with pytest.raises(InvalidRoleError) as exc_info:
provider._validate_role("invalid_role")
assert "Invalid role 'invalid_role'" in str(exc_info.value)
assert exc_info.value.valid_roles == VALID_FIELD_ROLES
@pytest.mark.asyncio
async def test_get_fields_by_role(self, provider, mock_session):
"""[AC-MRS-04] 按角色获取字段定义"""
mock_field = MagicMock(spec=MetadataFieldDefinition)
mock_field.id = "test-id"
mock_field.field_key = "grade"
mock_field.label = "年级"
mock_field.field_roles = ["resource_filter", "slot"]
mock_field.status = MetadataFieldStatus.ACTIVE.value
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [mock_field]
mock_session.execute.return_value = mock_result
fields = await provider.get_fields_by_role(
"test-tenant",
"resource_filter"
)
assert len(fields) == 1
assert fields[0].field_key == "grade"
assert "resource_filter" in fields[0].field_roles
@pytest.mark.asyncio
async def test_get_fields_by_role_invalid_role(self, provider):
"""[AC-MRS-05] 无效角色返回 400 错误"""
with pytest.raises(InvalidRoleError):
await provider.get_fields_by_role(
"test-tenant",
"invalid_role"
)
@pytest.mark.asyncio
async def test_get_fields_by_role_include_deprecated(self, provider, mock_session):
"""[AC-MRS-04] 包含已废弃字段"""
mock_active = MagicMock(spec=MetadataFieldDefinition)
mock_active.field_key = "active_field"
mock_active.status = MetadataFieldStatus.ACTIVE.value
mock_deprecated = MagicMock(spec=MetadataFieldDefinition)
mock_deprecated.field_key = "deprecated_field"
mock_deprecated.status = MetadataFieldStatus.DEPRECATED.value
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [mock_active, mock_deprecated]
mock_session.execute.return_value = mock_result
fields = await provider.get_fields_by_role(
"test-tenant",
"resource_filter",
include_deprecated=True
)
assert len(fields) == 2
@pytest.mark.asyncio
async def test_get_slot_definitions_by_role(self, provider, mock_session):
"""[AC-MRS-10] 按角色获取槽位定义"""
mock_field = MagicMock(spec=MetadataFieldDefinition)
mock_field.id = MagicMock()
mock_field.id.__str__ = lambda self: "field-id-123"
mock_field.field_key = "grade"
mock_field.label = "年级"
mock_field.type = "string"
mock_field.required = True
mock_field.options = None
mock_field.default_value = None
mock_field.scope = ["kb_document"]
mock_field.is_filterable = True
mock_field.is_rank_feature = False
mock_field.field_roles = ["slot"]
mock_field.status = MetadataFieldStatus.ACTIVE.value
mock_slot = MagicMock(spec=SlotDefinition)
mock_slot.id = MagicMock()
mock_slot.id.__str__ = lambda self: "slot-id-456"
mock_slot.tenant_id = "test-tenant"
mock_slot.slot_key = "grade"
mock_slot.type = "string"
mock_slot.required = True
mock_slot.extract_strategy = "llm"
mock_slot.validation_rule = None
mock_slot.ask_back_prompt = "请输入年级"
mock_slot.default_value = None
mock_slot.linked_field_id = mock_field.id
mock_slot.created_at = None
mock_slot.updated_at = None
field_result = MagicMock()
field_result.scalars.return_value.all.return_value = [mock_field]
slot_result = MagicMock()
slot_result.scalars.return_value.all.return_value = [mock_slot]
mock_session.execute.side_effect = [field_result, slot_result]
slots = await provider.get_slot_definitions_by_role("test-tenant", "slot")
assert len(slots) >= 1
@pytest.mark.asyncio
async def test_get_resource_filter_fields(self, provider, mock_session):
"""[AC-MRS-11] 获取资源过滤角色字段"""
mock_field = MagicMock(spec=MetadataFieldDefinition)
mock_field.field_key = "category"
mock_field.field_roles = ["resource_filter"]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [mock_field]
mock_session.execute.return_value = mock_result
fields = await provider.get_resource_filter_fields("test-tenant")
assert len(fields) == 1
assert "resource_filter" in fields[0].field_roles
@pytest.mark.asyncio
async def test_get_slot_fields(self, provider, mock_session):
"""[AC-MRS-12] 获取槽位角色字段"""
mock_field = MagicMock(spec=MetadataFieldDefinition)
mock_field.field_key = "user_name"
mock_field.field_roles = ["slot"]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [mock_field]
mock_session.execute.return_value = mock_result
fields = await provider.get_slot_fields("test-tenant")
assert len(fields) == 1
assert "slot" in fields[0].field_roles
@pytest.mark.asyncio
async def test_get_routing_signal_fields(self, provider, mock_session):
"""[AC-MRS-13] 获取路由信号角色字段"""
mock_field = MagicMock(spec=MetadataFieldDefinition)
mock_field.field_key = "priority"
mock_field.field_roles = ["routing_signal"]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [mock_field]
mock_session.execute.return_value = mock_result
fields = await provider.get_routing_signal_fields("test-tenant")
assert len(fields) == 1
assert "routing_signal" in fields[0].field_roles
@pytest.mark.asyncio
async def test_get_prompt_var_fields(self, provider, mock_session):
"""[AC-MRS-14] 获取提示词变量角色字段"""
mock_field = MagicMock(spec=MetadataFieldDefinition)
mock_field.field_key = "user_name"
mock_field.field_roles = ["prompt_var"]
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [mock_field]
mock_session.execute.return_value = mock_result
fields = await provider.get_prompt_var_fields("test-tenant")
assert len(fields) == 1
assert "prompt_var" in fields[0].field_roles
class TestInvalidRoleError:
"""[AC-MRS-05] InvalidRoleError 测试"""
def test_error_message(self):
"""验证错误消息格式"""
error = InvalidRoleError("bad_role")
assert error.role == "bad_role"
assert error.valid_roles == VALID_FIELD_ROLES
assert "Invalid role 'bad_role'" in str(error)
assert "resource_filter" in str(error)

View File

@ -0,0 +1,333 @@
"""
Unit tests for SlotDefinitionService.
[AC-MRS-07,08,16] 验证槽位定义管理功能
"""
import uuid
import pytest
from unittest.mock import AsyncMock, MagicMock
from sqlalchemy.ext.asyncio import AsyncSession
from app.services.slot_definition_service import SlotDefinitionService
from app.models.entities import (
SlotDefinition,
SlotDefinitionCreate,
SlotDefinitionUpdate,
MetadataFieldDefinition,
)
class TestSlotDefinitionService:
"""[AC-MRS-07,08,16] SlotDefinitionService 测试"""
@pytest.fixture
def mock_session(self):
"""Mock AsyncSession"""
session = MagicMock(spec=AsyncSession)
session.execute = AsyncMock()
session.add = MagicMock()
session.flush = AsyncMock()
session.delete = AsyncMock()
return session
@pytest.fixture
def service(self, mock_session):
"""Create service instance"""
return SlotDefinitionService(mock_session)
@pytest.mark.asyncio
async def test_list_slot_definitions(self, service, mock_session):
"""列出槽位定义"""
mock_slot = MagicMock(spec=SlotDefinition)
mock_slot.id = uuid.uuid4()
mock_slot.slot_key = "grade"
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [mock_slot]
mock_session.execute.return_value = mock_result
slots = await service.list_slot_definitions("test-tenant")
assert len(slots) == 1
assert slots[0].slot_key == "grade"
@pytest.mark.asyncio
async def test_list_slot_definitions_filter_required(self, service, mock_session):
"""按必填过滤槽位定义"""
mock_slot = MagicMock(spec=SlotDefinition)
mock_slot.required = True
mock_result = MagicMock()
mock_result.scalars.return_value.all.return_value = [mock_slot]
mock_session.execute.return_value = mock_result
slots = await service.list_slot_definitions("test-tenant", required=True)
assert len(slots) == 1
assert slots[0].required is True
@pytest.mark.asyncio
async def test_get_slot_definition(self, service, mock_session):
"""获取单个槽位定义"""
slot_id = uuid.uuid4()
mock_slot = MagicMock(spec=SlotDefinition)
mock_slot.id = slot_id
mock_slot.slot_key = "grade"
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_slot
mock_session.execute.return_value = mock_result
slot = await service.get_slot_definition("test-tenant", str(slot_id))
assert slot is not None
assert slot.slot_key == "grade"
@pytest.mark.asyncio
async def test_get_slot_definition_not_found(self, service, mock_session):
"""获取不存在的槽位定义"""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
slot = await service.get_slot_definition("test-tenant", str(uuid.uuid4()))
assert slot is None
@pytest.mark.asyncio
async def test_get_slot_definition_by_key(self, service, mock_session):
"""通过 slot_key 获取槽位定义"""
mock_slot = MagicMock(spec=SlotDefinition)
mock_slot.slot_key = "grade"
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_slot
mock_session.execute.return_value = mock_result
slot = await service.get_slot_definition_by_key("test-tenant", "grade")
assert slot is not None
assert slot.slot_key == "grade"
@pytest.mark.asyncio
async def test_create_slot_definition(self, service, mock_session):
"""[AC-MRS-07] 创建槽位定义"""
slot_create = SlotDefinitionCreate(
slot_key="grade",
type="string",
required=True,
extract_strategy="llm",
ask_back_prompt="请输入年级",
)
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
slot = await service.create_slot_definition("test-tenant", slot_create)
assert slot is not None
mock_session.add.assert_called_once()
mock_session.flush.assert_called_once()
@pytest.mark.asyncio
async def test_create_slot_definition_invalid_key(self, service):
"""[AC-MRS-07] 创建无效 slot_key 抛出异常"""
slot_create = SlotDefinitionCreate(
slot_key="InvalidKey",
type="string",
required=True,
)
with pytest.raises(ValueError) as exc_info:
await service.create_slot_definition("test-tenant", slot_create)
assert "格式不正确" in str(exc_info.value)
@pytest.mark.asyncio
async def test_create_slot_definition_duplicate_key(self, service, mock_session):
"""[AC-MRS-07] 创建重复 slot_key 抛出异常"""
existing_slot = MagicMock(spec=SlotDefinition)
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = existing_slot
mock_session.execute.return_value = mock_result
slot_create = SlotDefinitionCreate(
slot_key="grade",
type="string",
required=True,
)
with pytest.raises(ValueError) as exc_info:
await service.create_slot_definition("test-tenant", slot_create)
assert "已存在" in str(exc_info.value)
@pytest.mark.asyncio
async def test_create_slot_definition_invalid_type(self, service, mock_session):
"""[AC-MRS-07] 创建无效类型抛出异常"""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
slot_create = SlotDefinitionCreate(
slot_key="grade",
type="invalid_type",
required=True,
)
with pytest.raises(ValueError) as exc_info:
await service.create_slot_definition("test-tenant", slot_create)
assert "无效的槽位类型" in str(exc_info.value)
@pytest.mark.asyncio
async def test_create_slot_definition_with_linked_field(self, service, mock_session):
"""[AC-MRS-08] 创建槽位定义并关联元数据字段"""
field_id = uuid.uuid4()
mock_field = MagicMock(spec=MetadataFieldDefinition)
mock_field.id = field_id
slot_result = MagicMock()
slot_result.scalar_one_or_none.return_value = None
field_result = MagicMock()
field_result.scalar_one_or_none.return_value = mock_field
mock_session.execute.side_effect = [slot_result, field_result]
slot_create = SlotDefinitionCreate(
slot_key="grade",
type="string",
required=True,
linked_field_id=str(field_id),
)
slot = await service.create_slot_definition("test-tenant", slot_create)
assert slot is not None
mock_session.add.assert_called_once()
@pytest.mark.asyncio
async def test_create_slot_definition_linked_field_not_found(self, service, mock_session):
"""[AC-MRS-08] 关联字段不存在抛出异常"""
field_id = uuid.uuid4()
slot_result = MagicMock()
slot_result.scalar_one_or_none.return_value = None
field_result = MagicMock()
field_result.scalar_one_or_none.return_value = None
mock_session.execute.side_effect = [slot_result, field_result]
slot_create = SlotDefinitionCreate(
slot_key="grade",
type="string",
required=True,
linked_field_id=str(field_id),
)
with pytest.raises(ValueError) as exc_info:
await service.create_slot_definition("test-tenant", slot_create)
assert "关联的元数据字段" in str(exc_info.value)
@pytest.mark.asyncio
async def test_update_slot_definition(self, service, mock_session):
"""更新槽位定义"""
slot_id = uuid.uuid4()
mock_slot = MagicMock(spec=SlotDefinition)
mock_slot.id = slot_id
mock_slot.slot_key = "grade"
mock_slot.type = "string"
mock_slot.required = False
mock_slot.extract_strategy = None
mock_slot.validation_rule = None
mock_slot.ask_back_prompt = None
mock_slot.default_value = None
mock_slot.linked_field_id = None
mock_slot.updated_at = None
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_slot
mock_session.execute.return_value = mock_result
slot_update = SlotDefinitionUpdate(
required=True,
ask_back_prompt="请输入年级",
)
slot = await service.update_slot_definition("test-tenant", str(slot_id), slot_update)
assert slot is not None
mock_session.flush.assert_called_once()
@pytest.mark.asyncio
async def test_update_slot_definition_not_found(self, service, mock_session):
"""更新不存在的槽位定义"""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
slot_update = SlotDefinitionUpdate(required=True)
slot = await service.update_slot_definition("test-tenant", str(uuid.uuid4()), slot_update)
assert slot is None
@pytest.mark.asyncio
async def test_delete_slot_definition(self, service, mock_session):
"""[AC-MRS-16] 删除槽位定义"""
slot_id = uuid.uuid4()
mock_slot = MagicMock(spec=SlotDefinition)
mock_slot.id = slot_id
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_slot
mock_session.execute.return_value = mock_result
success = await service.delete_slot_definition("test-tenant", str(slot_id))
assert success is True
mock_session.delete.assert_called_once()
mock_session.flush.assert_called_once()
@pytest.mark.asyncio
async def test_delete_slot_definition_not_found(self, service, mock_session):
"""[AC-MRS-16] 删除不存在的槽位定义"""
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = None
mock_session.execute.return_value = mock_result
success = await service.delete_slot_definition("test-tenant", str(uuid.uuid4()))
assert success is False
@pytest.mark.asyncio
async def test_get_slot_definition_with_field(self, service, mock_session):
"""获取槽位定义及关联字段信息"""
slot_id = uuid.uuid4()
mock_slot = MagicMock(spec=SlotDefinition)
mock_slot.id = slot_id
mock_slot.tenant_id = "test-tenant"
mock_slot.slot_key = "grade"
mock_slot.type = "string"
mock_slot.required = True
mock_slot.extract_strategy = "llm"
mock_slot.validation_rule = None
mock_slot.ask_back_prompt = "请输入年级"
mock_slot.default_value = None
mock_slot.linked_field_id = None
mock_slot.created_at = None
mock_slot.updated_at = None
mock_result = MagicMock()
mock_result.scalar_one_or_none.return_value = mock_slot
mock_session.execute.return_value = mock_result
result = await service.get_slot_definition_with_field("test-tenant", str(slot_id))
assert result is not None
assert result["slot_key"] == "grade"
assert result["linked_field"] is None