diff --git a/ai-service/app/api/admin/__init__.py b/ai-service/app/api/admin/__init__.py index 2091088..60e120a 100644 --- a/ai-service/app/api/admin/__init__.py +++ b/ai-service/app/api/admin/__init__.py @@ -1,6 +1,7 @@ """ Admin API routes for AI Service management. [AC-ASA-01, AC-ASA-02, AC-ASA-05, AC-ASA-07, AC-ASA-08, AC-AISVC-50] Admin management endpoints. +[AC-MRS-07,08,16] Slot definition management endpoints. """ from app.api.admin.api_key import router as api_key_router @@ -19,6 +20,7 @@ from app.api.admin.prompt_templates import router as prompt_templates_router from app.api.admin.rag import router as rag_router from app.api.admin.script_flows import router as script_flows_router from app.api.admin.sessions import router as sessions_router +from app.api.admin.slot_definition import router as slot_definition_router from app.api.admin.tenants import router as tenants_router __all__ = [ @@ -38,5 +40,6 @@ __all__ = [ "rag_router", "script_flows_router", "sessions_router", + "slot_definition_router", "tenants_router", ] diff --git a/ai-service/app/api/admin/metadata_field_definition.py b/ai-service/app/api/admin/metadata_field_definition.py index c3f5978..846c162 100644 --- a/ai-service/app/api/admin/metadata_field_definition.py +++ b/ai-service/app/api/admin/metadata_field_definition.py @@ -1,6 +1,7 @@ """ Metadata Field Definition API. [AC-IDSMETA-13, AC-IDSMETA-14] 元数据字段定义管理接口,支持字段级状态治理。 +[AC-MRS-01,04,05,06,16] 支持字段角色分层配置和按角色查询。 """ import logging @@ -20,10 +21,12 @@ from app.models.entities import ( MetadataFieldStatus, ) from app.services.metadata_field_definition_service import MetadataFieldDefinitionService +from app.services.mid.role_based_field_provider import RoleBasedFieldProvider, InvalidRoleError +from app.schemas.metadata import VALID_FIELD_ROLES logger = logging.getLogger(__name__) -router = APIRouter(prefix="/admin/metadata-schemas", tags=["MetadataSchemas"]) +router = APIRouter(prefix="/admin/metadata-schemas", tags=["MetadataSchema"]) def get_current_tenant_id() -> str: @@ -34,11 +37,33 @@ def get_current_tenant_id() -> str: return tenant_id +def _field_to_dict(f: MetadataFieldDefinition) -> dict[str, Any]: + """Convert field definition to dict with field_roles""" + return { + "id": str(f.id), + "tenant_id": str(f.tenant_id), + "field_key": f.field_key, + "label": f.label, + "type": f.type, + "required": f.required, + "options": f.options, + "default_value": f.default_value, + "scope": f.scope, + "is_filterable": f.is_filterable, + "is_rank_feature": f.is_rank_feature, + "field_roles": f.field_roles or [], + "status": f.status, + "version": f.version, + "created_at": f.created_at.isoformat() if f.created_at else None, + "updated_at": f.updated_at.isoformat() if f.updated_at else None, + } + + @router.get( "", operation_id="listMetadataSchemas", summary="List metadata schemas", - description="[AC-IDSMETA-13] 获取元数据字段定义列表,支持按状态和范围过滤", + description="[AC-IDSMETA-13] [AC-MRS-06] 获取元数据字段定义列表,支持按状态、范围、角色过滤", ) async def list_schemas( tenant_id: Annotated[str, Depends(get_current_tenant_id)], @@ -49,28 +74,32 @@ async def list_schemas( scope: Annotated[str | None, Query( description="按适用范围过滤: kb_document/intent_rule/script_flow/prompt_template" )] = None, + field_role: Annotated[str | None, Query( + description="[AC-MRS-06] 按字段角色过滤: resource_filter/slot/prompt_var/routing_signal" + )] = None, include_deprecated: Annotated[bool, Query( description="是否包含已废弃的字段" )] = False, ) -> JSONResponse: """ - [AC-IDSMETA-13] 列出元数据字段定义 + [AC-IDSMETA-13] [AC-MRS-06] 列出元数据字段定义 Args: status: 按状态过滤 scope: 按适用范围过滤 + field_role: [AC-MRS-06] 按字段角色过滤 include_deprecated: 是否包含已废弃的字段(当 status 未指定时生效) """ logger.info( - f"[AC-IDSMETA-13] Listing metadata field definitions: " - f"tenant={tenant_id}, status={status}, scope={scope}, include_deprecated={include_deprecated}" + f"[AC-IDSMETA-13] [AC-MRS-06] Listing metadata field definitions: " + f"tenant={tenant_id}, status={status}, scope={scope}, field_role={field_role}" ) if status and status not in [s.value for s in MetadataFieldStatus]: return JSONResponse( status_code=400, content={ - "code": "INVALID_STATUS", + "error_code": "INVALID_STATUS", "message": f"Invalid status: {status}", "details": { "valid_values": [s.value for s in MetadataFieldStatus] @@ -78,34 +107,29 @@ async def list_schemas( } ) + if field_role and field_role not in VALID_FIELD_ROLES: + return JSONResponse( + status_code=400, + content={ + "error_code": "INVALID_ROLE", + "message": f"Invalid role '{field_role}'. Valid roles are: {', '.join(VALID_FIELD_ROLES)}", + "details": { + "valid_roles": VALID_FIELD_ROLES + } + } + ) + service = MetadataFieldDefinitionService(session) if include_deprecated and not status: fields = await service.get_field_definitions_for_read(tenant_id, scope) + if field_role: + fields = [f for f in fields if field_role in (f.field_roles or [])] else: - fields = await service.list_field_definitions(tenant_id, status, scope) + fields = await service.list_field_definitions(tenant_id, status, scope, field_role) return JSONResponse( - content={ - "items": [ - { - "id": str(f.id), - "field_key": f.field_key, - "label": f.label, - "type": f.type, - "required": f.required, - "options": f.options, - "default": f.default_value, - "scope": f.scope, - "is_filterable": f.is_filterable, - "is_rank_feature": f.is_rank_feature, - "status": f.status, - "created_at": f.created_at.isoformat() if f.created_at else None, - "updated_at": f.updated_at.isoformat() if f.updated_at else None, - } - for f in fields - ] - } + content=[_field_to_dict(f) for f in fields] ) @@ -113,7 +137,7 @@ async def list_schemas( "", operation_id="createMetadataSchema", summary="Create metadata schema", - description="[AC-IDSMETA-13] 创建新的元数据字段定义", + description="[AC-IDSMETA-13] [AC-MRS-01,02,03] 创建新的元数据字段定义,支持 field_roles 多选配置", status_code=201, ) async def create_schema( @@ -122,11 +146,11 @@ async def create_schema( field_create: MetadataFieldDefinitionCreate, ) -> JSONResponse: """ - [AC-IDSMETA-13] 创建元数据字段定义 + [AC-IDSMETA-13] [AC-MRS-01,02,03] 创建元数据字段定义 """ logger.info( - f"[AC-IDSMETA-13] Creating metadata field definition: " - f"tenant={tenant_id}, field_key={field_create.field_key}" + f"[AC-IDSMETA-13] [AC-MRS-01] Creating metadata field definition: " + f"tenant={tenant_id}, field_key={field_create.field_key}, field_roles={field_create.field_roles}" ) service = MetadataFieldDefinitionService(session) @@ -138,36 +162,104 @@ async def create_schema( return JSONResponse( status_code=400, content={ - "code": "VALIDATION_ERROR", + "error_code": "VALIDATION_ERROR", "message": str(e), } ) return JSONResponse( status_code=201, - content={ - "id": str(field.id), - "field_key": field.field_key, - "label": field.label, - "type": field.type, - "required": field.required, - "options": field.options, - "default": field.default_value, - "scope": field.scope, - "is_filterable": field.is_filterable, - "is_rank_feature": field.is_rank_feature, - "status": field.status, - "created_at": field.created_at.isoformat() if field.created_at else None, - "updated_at": field.updated_at.isoformat() if field.updated_at else None, - } + content=_field_to_dict(field) ) +@router.get( + "/by-role", + operation_id="getMetadataSchemasByRole", + summary="Get metadata schemas by role", + description="[AC-MRS-04,05] 按指定角色查询所有包含该角色的活跃字段定义", +) +async def get_schemas_by_role( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], + role: Annotated[str, Query( + description="[AC-MRS-04] 字段角色: resource_filter/slot/prompt_var/routing_signal" + )], + include_deprecated: Annotated[bool, Query( + description="是否包含已废弃字段" + )] = False, +) -> JSONResponse: + """ + [AC-MRS-04,05] 按角色查询字段定义 + + Args: + role: 字段角色 + include_deprecated: 是否包含已废弃字段 + """ + logger.info( + f"[AC-MRS-04] Getting metadata schemas by role: " + f"tenant={tenant_id}, role={role}, include_deprecated={include_deprecated}" + ) + + provider = RoleBasedFieldProvider(session) + + try: + fields = await provider.get_fields_by_role(tenant_id, role, include_deprecated) + except InvalidRoleError as e: + return JSONResponse( + status_code=400, + content={ + "error_code": "INVALID_ROLE", + "message": str(e), + "details": { + "valid_roles": e.valid_roles + } + } + ) + + return JSONResponse( + content=[_field_to_dict(f) for f in fields] + ) + + +@router.get( + "/{id}", + operation_id="getMetadataSchema", + summary="Get metadata schema by ID", + description="获取单个元数据字段定义", +) +async def get_schema( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], + id: str, +) -> JSONResponse: + """ + 获取单个元数据字段定义 + """ + logger.info( + f"Getting metadata field definition: tenant={tenant_id}, id={id}" + ) + + service = MetadataFieldDefinitionService(session) + field = await service.get_field_definition(tenant_id, id) + + if not field: + return JSONResponse( + status_code=404, + content={ + "error_code": "NOT_FOUND", + "message": f"Field definition {id} not found", + } + ) + + return JSONResponse(content=_field_to_dict(field)) + + @router.put( "/{id}", operation_id="updateMetadataSchema", summary="Update metadata schema", - description="[AC-IDSMETA-14] 更新元数据字段定义,支持状态切换", + description="[AC-IDSMETA-14] [AC-MRS-01] 更新元数据字段定义,支持修改 field_roles", ) async def update_schema( tenant_id: Annotated[str, Depends(get_current_tenant_id)], @@ -176,18 +268,18 @@ async def update_schema( field_update: MetadataFieldDefinitionUpdate, ) -> JSONResponse: """ - [AC-IDSMETA-14] 更新元数据字段定义 + [AC-IDSMETA-14] [AC-MRS-01] 更新元数据字段定义 """ logger.info( - f"[AC-IDSMETA-14] Updating metadata field definition: " - f"tenant={tenant_id}, id={id}" + f"[AC-IDSMETA-14] [AC-MRS-01] Updating metadata field definition: " + f"tenant={tenant_id}, id={id}, field_roles={field_update.field_roles}" ) if field_update.status and field_update.status not in [s.value for s in MetadataFieldStatus]: return JSONResponse( status_code=400, content={ - "code": "INVALID_STATUS", + "error_code": "INVALID_STATUS", "message": f"Invalid status: {field_update.status}", "details": { "valid_values": [s.value for s in MetadataFieldStatus] @@ -203,7 +295,7 @@ async def update_schema( return JSONResponse( status_code=400, content={ - "code": "VALIDATION_ERROR", + "error_code": "VALIDATION_ERROR", "message": str(e), } ) @@ -212,30 +304,49 @@ async def update_schema( return JSONResponse( status_code=404, content={ - "code": "NOT_FOUND", + "error_code": "NOT_FOUND", "message": f"Field definition {id} not found", } ) await session.commit() - return JSONResponse( - content={ - "id": str(field.id), - "field_key": field.field_key, - "label": field.label, - "type": field.type, - "required": field.required, - "options": field.options, - "default": field.default_value, - "scope": field.scope, - "is_filterable": field.is_filterable, - "is_rank_feature": field.is_rank_feature, - "status": field.status, - "created_at": field.created_at.isoformat() if field.created_at else None, - "updated_at": field.updated_at.isoformat() if field.updated_at else None, - } + return JSONResponse(content=_field_to_dict(field)) + + +@router.delete( + "/{id}", + operation_id="deleteMetadataSchema", + summary="Delete metadata schema", + description="[AC-MRS-16] 删除元数据字段定义,无需考虑历史数据兼容性", + status_code=204, +) +async def delete_schema( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], + id: str, +) -> JSONResponse: + """ + [AC-MRS-16] 删除元数据字段定义 + """ + logger.info( + f"[AC-MRS-16] Deleting metadata field definition: " + f"tenant={tenant_id}, id={id}" ) + + service = MetadataFieldDefinitionService(session) + success = await service.delete_field_definition(tenant_id, id) + + if not success: + return JSONResponse( + status_code=404, + content={ + "error_code": "NOT_FOUND", + "message": f"Field definition not found: {id}", + } + ) + + return JSONResponse(status_code=204, content=None) @router.get( @@ -263,26 +374,7 @@ async def get_active_schemas( fields = await service.get_active_field_definitions(tenant_id, scope) return JSONResponse( - content={ - "items": [ - { - "id": str(f.id), - "field_key": f.field_key, - "label": f.label, - "type": f.type, - "required": f.required, - "options": f.options, - "default": f.default_value, - "scope": f.scope, - "is_filterable": f.is_filterable, - "is_rank_feature": f.is_rank_feature, - "status": f.status, - "created_at": f.created_at.isoformat() if f.created_at else None, - "updated_at": f.updated_at.isoformat() if f.updated_at else None, - } - for f in fields - ] - } + content=[_field_to_dict(f) for f in fields] ) @@ -311,26 +403,7 @@ async def get_readable_schemas( fields = await service.get_field_definitions_for_read(tenant_id, scope) return JSONResponse( - content={ - "items": [ - { - "id": str(f.id), - "field_key": f.field_key, - "label": f.label, - "type": f.type, - "required": f.required, - "options": f.options, - "default": f.default_value, - "scope": f.scope, - "is_filterable": f.is_filterable, - "is_rank_feature": f.is_rank_feature, - "status": f.status, - "created_at": f.created_at.isoformat() if f.created_at else None, - "updated_at": f.updated_at.isoformat() if f.updated_at else None, - } - for f in fields - ] - } + content=[_field_to_dict(f) for f in fields] ) @@ -373,42 +446,3 @@ async def validate_metadata_for_create( "errors": errors, } ) - - -@router.delete( - "/{field_id}", - operation_id="deleteMetadataSchema", - summary="Delete metadata schema", - description="[AC-IDSMETA-13] 删除元数据字段定义", -) -async def delete_schema( - field_id: str, - tenant_id: Annotated[str, Depends(get_current_tenant_id)], - session: Annotated[AsyncSession, Depends(get_session)], -) -> JSONResponse: - """ - [AC-IDSMETA-13] 删除元数据字段定义 - """ - logger.info( - f"[AC-IDSMETA-13] Deleting metadata field definition: " - f"tenant={tenant_id}, field_id={field_id}" - ) - - service = MetadataFieldDefinitionService(session) - success = await service.delete_field_definition(tenant_id, field_id) - - if not success: - return JSONResponse( - status_code=404, - content={ - "code": "NOT_FOUND", - "message": f"Field definition not found: {field_id}", - } - ) - - return JSONResponse( - content={ - "success": True, - "message": "Field definition deleted successfully", - } - ) diff --git a/ai-service/app/api/admin/slot_definition.py b/ai-service/app/api/admin/slot_definition.py new file mode 100644 index 0000000..090fdda --- /dev/null +++ b/ai-service/app/api/admin/slot_definition.py @@ -0,0 +1,234 @@ +""" +Slot Definition API. +[AC-MRS-07,08,16] 槽位定义管理接口 +""" + +import logging +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, Query +from fastapi.responses import JSONResponse +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_session +from app.core.exceptions import MissingTenantIdException +from app.core.tenant import get_tenant_id +from app.models.entities import SlotDefinitionCreate, SlotDefinitionUpdate +from app.services.slot_definition_service import SlotDefinitionService + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/admin/slot-definitions", tags=["SlotDefinition"]) + + +def get_current_tenant_id() -> str: + """Get current tenant ID from context.""" + tenant_id = get_tenant_id() + if not tenant_id: + raise MissingTenantIdException() + return tenant_id + + +def _slot_to_dict(slot: dict[str, Any] | Any) -> dict[str, Any]: + """Convert slot definition to dict""" + if isinstance(slot, dict): + return slot + + return { + "id": str(slot.id), + "tenant_id": str(slot.tenant_id), + "slot_key": slot.slot_key, + "type": slot.type, + "required": slot.required, + "extract_strategy": slot.extract_strategy, + "validation_rule": slot.validation_rule, + "ask_back_prompt": slot.ask_back_prompt, + "default_value": slot.default_value, + "linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None, + "created_at": slot.created_at.isoformat() if slot.created_at else None, + "updated_at": slot.updated_at.isoformat() if slot.updated_at else None, + } + + +@router.get( + "", + operation_id="listSlotDefinitions", + summary="List slot definitions", + description="获取槽位定义列表", +) +async def list_slot_definitions( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], + required: Annotated[bool | None, Query( + description="按是否必填过滤" + )] = None, +) -> JSONResponse: + """ + 列出槽位定义 + """ + logger.info( + f"Listing slot definitions: tenant={tenant_id}, required={required}" + ) + + service = SlotDefinitionService(session) + slots = await service.list_slot_definitions(tenant_id, required) + + return JSONResponse( + content=[_slot_to_dict(s) for s in slots] + ) + + +@router.post( + "", + operation_id="createSlotDefinition", + summary="Create slot definition", + description="[AC-MRS-07,08] 创建新的槽位定义,可关联已有元数据字段", + status_code=201, +) +async def create_slot_definition( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], + slot_create: SlotDefinitionCreate, +) -> JSONResponse: + """ + [AC-MRS-07,08] 创建槽位定义 + """ + logger.info( + f"[AC-MRS-07] Creating slot definition: " + f"tenant={tenant_id}, slot_key={slot_create.slot_key}, " + f"linked_field_id={slot_create.linked_field_id}" + ) + + service = SlotDefinitionService(session) + + try: + slot = await service.create_slot_definition(tenant_id, slot_create) + await session.commit() + except ValueError as e: + return JSONResponse( + status_code=400, + content={ + "error_code": "VALIDATION_ERROR", + "message": str(e), + } + ) + + return JSONResponse( + status_code=201, + content=_slot_to_dict(slot) + ) + + +@router.get( + "/{id}", + operation_id="getSlotDefinition", + summary="Get slot definition by ID", + description="获取单个槽位定义", +) +async def get_slot_definition( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], + id: str, +) -> JSONResponse: + """ + 获取单个槽位定义 + """ + logger.info( + f"Getting slot definition: tenant={tenant_id}, id={id}" + ) + + service = SlotDefinitionService(session) + slot = await service.get_slot_definition_with_field(tenant_id, id) + + if not slot: + return JSONResponse( + status_code=404, + content={ + "error_code": "NOT_FOUND", + "message": f"Slot definition {id} not found", + } + ) + + return JSONResponse(content=slot) + + +@router.put( + "/{id}", + operation_id="updateSlotDefinition", + summary="Update slot definition", + description="更新槽位定义", +) +async def update_slot_definition( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], + id: str, + slot_update: SlotDefinitionUpdate, +) -> JSONResponse: + """ + 更新槽位定义 + """ + logger.info( + f"Updating slot definition: tenant={tenant_id}, id={id}" + ) + + service = SlotDefinitionService(session) + + try: + slot = await service.update_slot_definition(tenant_id, id, slot_update) + except ValueError as e: + return JSONResponse( + status_code=400, + content={ + "error_code": "VALIDATION_ERROR", + "message": str(e), + } + ) + + if not slot: + return JSONResponse( + status_code=404, + content={ + "error_code": "NOT_FOUND", + "message": f"Slot definition {id} not found", + } + ) + + await session.commit() + + return JSONResponse(content=_slot_to_dict(slot)) + + +@router.delete( + "/{id}", + operation_id="deleteSlotDefinition", + summary="Delete slot definition", + description="[AC-MRS-16] 删除槽位定义,无需考虑历史数据兼容性", + status_code=204, +) +async def delete_slot_definition( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], + id: str, +) -> JSONResponse: + """ + [AC-MRS-16] 删除槽位定义 + """ + logger.info( + f"[AC-MRS-16] Deleting slot definition: tenant={tenant_id}, id={id}" + ) + + service = SlotDefinitionService(session) + success = await service.delete_slot_definition(tenant_id, id) + + if not success: + return JSONResponse( + status_code=404, + content={ + "error_code": "NOT_FOUND", + "message": f"Slot definition not found: {id}", + } + ) + + await session.commit() + + return JSONResponse(status_code=204, content=None) diff --git a/ai-service/app/api/mid/slots.py b/ai-service/app/api/mid/slots.py new file mode 100644 index 0000000..eeca453 --- /dev/null +++ b/ai-service/app/api/mid/slots.py @@ -0,0 +1,140 @@ +""" +Runtime Slot API. +[AC-MRS-09,10] 运行时槽位查询接口 +""" + +import logging +from datetime import datetime +from typing import Annotated, Any + +from fastapi import APIRouter, Depends, Query +from fastapi.responses import JSONResponse +from sqlalchemy.ext.asyncio import AsyncSession + +from app.core.database import get_session +from app.core.exceptions import MissingTenantIdException +from app.core.tenant import get_tenant_id +from app.services.mid.role_based_field_provider import RoleBasedFieldProvider, InvalidRoleError +from app.services.slot_definition_service import SlotDefinitionService +from app.schemas.metadata import VALID_FIELD_ROLES + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/mid/slots", tags=["RuntimeSlot"]) + + +def get_current_tenant_id() -> str: + """Get current tenant ID from context.""" + tenant_id = get_tenant_id() + if not tenant_id: + raise MissingTenantIdException() + return tenant_id + + +@router.get( + "/by-role", + operation_id="getSlotsByRole", + summary="Get slots by role", + description="[AC-MRS-10] 运行时接口,按角色获取槽位定义及关联的元数据字段信息", +) +async def get_slots_by_role( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], + role: Annotated[str, Query( + description="[AC-MRS-10] 字段角色: resource_filter/slot/prompt_var/routing_signal" + )] = "slot", +) -> JSONResponse: + """ + [AC-MRS-10] 按角色获取槽位定义 + + Args: + role: 字段角色,默认为 slot + """ + logger.info( + f"[AC-MRS-10] Getting slots by role: tenant={tenant_id}, role={role}" + ) + + provider = RoleBasedFieldProvider(session) + + try: + slots = await provider.get_slot_definitions_by_role(tenant_id, role) + except InvalidRoleError as e: + return JSONResponse( + status_code=400, + content={ + "error_code": "INVALID_ROLE", + "message": str(e), + "details": { + "valid_roles": e.valid_roles + } + } + ) + + return JSONResponse(content=slots) + + +@router.get( + "/{slot_key}", + operation_id="getSlotValue", + summary="Get runtime slot value", + description="[AC-MRS-09] 获取指定槽位的运行时值,包含来源、置信度、更新时间", +) +async def get_slot_value( + tenant_id: Annotated[str, Depends(get_current_tenant_id)], + session: Annotated[AsyncSession, Depends(get_session)], + slot_key: str, + user_id: Annotated[str | None, Query( + description="用户 ID" + )] = None, + session_id: Annotated[str | None, Query( + description="会话 ID" + )] = None, +) -> JSONResponse: + """ + [AC-MRS-09] 获取运行时槽位值 + + Args: + slot_key: 槽位键名 + user_id: 用户 ID + session_id: 会话 ID + """ + logger.info( + f"[AC-MRS-09] Getting slot value: tenant={tenant_id}, slot_key={slot_key}, " + f"user_id={user_id}, session_id={session_id}" + ) + + service = SlotDefinitionService(session) + slot_def = await service.get_slot_definition_by_key(tenant_id, slot_key) + + if not slot_def: + return JSONResponse( + status_code=404, + content={ + "error_code": "NOT_FOUND", + "message": f"Slot '{slot_key}' not found", + } + ) + + value = slot_def.default_value + source = "default" + confidence = 1.0 + + if value is None: + if slot_def.type == "string": + value = "" + elif slot_def.type == "number": + value = 0 + elif slot_def.type == "boolean": + value = False + elif slot_def.type in ["enum", "array_enum"]: + value = [] if slot_def.type == "array_enum" else "" + + return JSONResponse( + content={ + "key": slot_key, + "value": value, + "source": source, + "confidence": confidence, + "updated_at": datetime.utcnow().isoformat(), + } + ) diff --git a/ai-service/app/main.py b/ai-service/app/main.py index c0bc5c2..54fd2f8 100644 --- a/ai-service/app/main.py +++ b/ai-service/app/main.py @@ -12,6 +12,7 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from app.api import chat_router, health_router +from app.api.mid import router as mid_router from app.api.admin import ( api_key_router, dashboard_router, @@ -29,6 +30,7 @@ from app.api.admin import ( rag_router, script_flows_router, sessions_router, + slot_definition_router, tenants_router, ) from app.api.admin.kb_optimized import router as kb_optimized_router @@ -165,8 +167,11 @@ app.include_router(prompt_templates_router) app.include_router(rag_router) app.include_router(script_flows_router) app.include_router(sessions_router) +app.include_router(slot_definition_router) app.include_router(tenants_router) +app.include_router(mid_router) + if __name__ == "__main__": import uvicorn diff --git a/ai-service/app/services/slot_definition_service.py b/ai-service/app/services/slot_definition_service.py new file mode 100644 index 0000000..cf6f2d5 --- /dev/null +++ b/ai-service/app/services/slot_definition_service.py @@ -0,0 +1,364 @@ +""" +Slot Definition Service. +[AC-MRS-07, AC-MRS-08] 槽位定义管理服务 +""" + +import logging +import re +import uuid +from datetime import datetime +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.entities import ( + SlotDefinition, + SlotDefinitionCreate, + SlotDefinitionUpdate, + MetadataFieldDefinition, + MetadataFieldType, + ExtractStrategy, +) + +logger = logging.getLogger(__name__) + + +class SlotDefinitionService: + """ + [AC-MRS-07, AC-MRS-08] 槽位定义服务 + + 管理独立的槽位定义模型,与元数据字段解耦但可复用 + """ + + SLOT_KEY_PATTERN = re.compile(r"^[a-z][a-z0-9_]*$") + VALID_TYPES = ["string", "number", "boolean", "enum", "array_enum"] + VALID_EXTRACT_STRATEGIES = ["rule", "llm", "user_input"] + + def __init__(self, session: AsyncSession): + self._session = session + + async def list_slot_definitions( + self, + tenant_id: str, + required: bool | None = None, + ) -> list[SlotDefinition]: + """ + 列出租户所有槽位定义 + + Args: + tenant_id: 租户 ID + required: 按是否必填过滤 + + Returns: + SlotDefinition 列表 + """ + stmt = select(SlotDefinition).where( + SlotDefinition.tenant_id == tenant_id, + ) + + if required is not None: + stmt = stmt.where(SlotDefinition.required == required) + + stmt = stmt.order_by(SlotDefinition.created_at.desc()) + + result = await self._session.execute(stmt) + return list(result.scalars().all()) + + async def get_slot_definition( + self, + tenant_id: str, + slot_id: str, + ) -> SlotDefinition | None: + """ + 获取单个槽位定义 + + Args: + tenant_id: 租户 ID + slot_id: 槽位定义 ID + + Returns: + SlotDefinition 或 None + """ + try: + slot_uuid = uuid.UUID(slot_id) + except ValueError: + return None + + stmt = select(SlotDefinition).where( + SlotDefinition.tenant_id == tenant_id, + SlotDefinition.id == slot_uuid, + ) + result = await self._session.execute(stmt) + return result.scalar_one_or_none() + + async def get_slot_definition_by_key( + self, + tenant_id: str, + slot_key: str, + ) -> SlotDefinition | None: + """ + 通过 slot_key 获取槽位定义 + + Args: + tenant_id: 租户 ID + slot_key: 槽位键名 + + Returns: + SlotDefinition 或 None + """ + stmt = select(SlotDefinition).where( + SlotDefinition.tenant_id == tenant_id, + SlotDefinition.slot_key == slot_key, + ) + result = await self._session.execute(stmt) + return result.scalar_one_or_none() + + async def create_slot_definition( + self, + tenant_id: str, + slot_create: SlotDefinitionCreate, + ) -> SlotDefinition: + """ + [AC-MRS-07, AC-MRS-08] 创建槽位定义 + + Args: + tenant_id: 租户 ID + slot_create: 创建数据 + + Returns: + 创建的 SlotDefinition + + Raises: + ValueError: 如果 slot_key 已存在或参数无效 + """ + if not self.SLOT_KEY_PATTERN.match(slot_create.slot_key): + raise ValueError( + f"slot_key '{slot_create.slot_key}' 格式不正确," + "必须以小写字母开头,仅允许小写字母、数字和下划线" + ) + + existing = await self.get_slot_definition_by_key(tenant_id, slot_create.slot_key) + if existing: + raise ValueError(f"slot_key '{slot_create.slot_key}' 已存在") + + if slot_create.type not in self.VALID_TYPES: + raise ValueError( + f"无效的槽位类型 '{slot_create.type}'," + f"有效类型为: {self.VALID_TYPES}" + ) + + if slot_create.extract_strategy and slot_create.extract_strategy not in self.VALID_EXTRACT_STRATEGIES: + raise ValueError( + f"无效的提取策略 '{slot_create.extract_strategy}'," + f"有效策略为: {self.VALID_EXTRACT_STRATEGIES}" + ) + + linked_field = None + if slot_create.linked_field_id: + linked_field = await self._get_linked_field(slot_create.linked_field_id) + if not linked_field: + raise ValueError( + f"[AC-MRS-08] 关联的元数据字段 '{slot_create.linked_field_id}' 不存在" + ) + + slot = SlotDefinition( + tenant_id=tenant_id, + slot_key=slot_create.slot_key, + type=slot_create.type, + required=slot_create.required, + extract_strategy=slot_create.extract_strategy, + validation_rule=slot_create.validation_rule, + ask_back_prompt=slot_create.ask_back_prompt, + default_value=slot_create.default_value, + linked_field_id=uuid.UUID(slot_create.linked_field_id) if slot_create.linked_field_id else None, + ) + + self._session.add(slot) + await self._session.flush() + + logger.info( + f"[AC-MRS-07] Created slot definition: tenant={tenant_id}, " + f"slot_key={slot.slot_key}, required={slot.required}, " + f"linked_field_id={slot.linked_field_id}" + ) + + return slot + + async def update_slot_definition( + self, + tenant_id: str, + slot_id: str, + slot_update: SlotDefinitionUpdate, + ) -> SlotDefinition | None: + """ + 更新槽位定义 + + Args: + tenant_id: 租户 ID + slot_id: 槽位定义 ID + slot_update: 更新数据 + + Returns: + 更新后的 SlotDefinition 或 None + """ + slot = await self.get_slot_definition(tenant_id, slot_id) + if not slot: + return None + + if slot_update.type is not None: + if slot_update.type not in self.VALID_TYPES: + raise ValueError( + f"无效的槽位类型 '{slot_update.type}'," + f"有效类型为: {self.VALID_TYPES}" + ) + slot.type = slot_update.type + + if slot_update.required is not None: + slot.required = slot_update.required + + if slot_update.extract_strategy is not None: + if slot_update.extract_strategy not in self.VALID_EXTRACT_STRATEGIES: + raise ValueError( + f"无效的提取策略 '{slot_update.extract_strategy}'," + f"有效策略为: {self.VALID_EXTRACT_STRATEGIES}" + ) + slot.extract_strategy = slot_update.extract_strategy + + if slot_update.validation_rule is not None: + slot.validation_rule = slot_update.validation_rule + + if slot_update.ask_back_prompt is not None: + slot.ask_back_prompt = slot_update.ask_back_prompt + + if slot_update.default_value is not None: + slot.default_value = slot_update.default_value + + if slot_update.linked_field_id is not None: + if slot_update.linked_field_id: + linked_field = await self._get_linked_field(slot_update.linked_field_id) + if not linked_field: + raise ValueError( + f"[AC-MRS-08] 关联的元数据字段 '{slot_update.linked_field_id}' 不存在" + ) + slot.linked_field_id = uuid.UUID(slot_update.linked_field_id) + else: + slot.linked_field_id = None + + slot.updated_at = datetime.utcnow() + await self._session.flush() + + logger.info( + f"[AC-MRS-07] Updated slot definition: tenant={tenant_id}, " + f"slot_id={slot_id}" + ) + + return slot + + async def delete_slot_definition( + self, + tenant_id: str, + slot_id: str, + ) -> bool: + """ + [AC-MRS-16] 删除槽位定义 + + Args: + tenant_id: 租户 ID + slot_id: 槽位定义 ID + + Returns: + 是否删除成功 + """ + slot = await self.get_slot_definition(tenant_id, slot_id) + if not slot: + return False + + await self._session.delete(slot) + await self._session.flush() + + logger.info( + f"[AC-MRS-16] Deleted slot definition: tenant={tenant_id}, " + f"slot_id={slot_id}" + ) + + return True + + async def _get_linked_field( + self, + field_id: str, + ) -> MetadataFieldDefinition | None: + """ + 获取关联的元数据字段 + + Args: + field_id: 字段 ID + + Returns: + MetadataFieldDefinition 或 None + """ + try: + field_uuid = uuid.UUID(field_id) + except ValueError: + return None + + stmt = select(MetadataFieldDefinition).where( + MetadataFieldDefinition.id == field_uuid, + ) + result = await self._session.execute(stmt) + return result.scalar_one_or_none() + + async def get_slot_definition_with_field( + self, + tenant_id: str, + slot_id: str, + ) -> dict[str, Any] | None: + """ + 获取槽位定义及其关联字段信息 + + Args: + tenant_id: 租户 ID + slot_id: 槽位定义 ID + + Returns: + 包含槽位定义和关联字段的字典 + """ + slot = await self.get_slot_definition(tenant_id, slot_id) + if not slot: + return None + + result = { + "id": str(slot.id), + "tenant_id": slot.tenant_id, + "slot_key": slot.slot_key, + "type": slot.type, + "required": slot.required, + "extract_strategy": slot.extract_strategy, + "validation_rule": slot.validation_rule, + "ask_back_prompt": slot.ask_back_prompt, + "default_value": slot.default_value, + "linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None, + "created_at": slot.created_at.isoformat() if slot.created_at else None, + "updated_at": slot.updated_at.isoformat() if slot.updated_at else None, + "linked_field": None, + } + + if slot.linked_field_id: + linked_field = await self._get_linked_field(str(slot.linked_field_id)) + if linked_field: + result["linked_field"] = { + "id": str(linked_field.id), + "field_key": linked_field.field_key, + "label": linked_field.label, + "type": linked_field.type, + "required": linked_field.required, + "options": linked_field.options, + "default_value": linked_field.default_value, + "scope": linked_field.scope, + "is_filterable": linked_field.is_filterable, + "is_rank_feature": linked_field.is_rank_feature, + "field_roles": linked_field.field_roles, + "status": linked_field.status, + } + + return result diff --git a/ai-service/tests/test_role_based_field_provider.py b/ai-service/tests/test_role_based_field_provider.py new file mode 100644 index 0000000..a5458b5 --- /dev/null +++ b/ai-service/tests/test_role_based_field_provider.py @@ -0,0 +1,226 @@ +""" +Unit tests for RoleBasedFieldProvider service. +[AC-MRS-04,05,10,11,12,13,14] 验证按角色查询字段定义功能 +""" + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch +from sqlalchemy.ext.asyncio import AsyncSession + +from app.services.mid.role_based_field_provider import ( + RoleBasedFieldProvider, + InvalidRoleError, +) +from app.models.entities import ( + MetadataFieldDefinition, + MetadataFieldStatus, + SlotDefinition, +) +from app.schemas.metadata import VALID_FIELD_ROLES + + +class TestRoleBasedFieldProvider: + """[AC-MRS-04,05,10] RoleBasedFieldProvider 测试""" + + @pytest.fixture + def mock_session(self): + """Mock AsyncSession""" + session = MagicMock(spec=AsyncSession) + session.execute = AsyncMock() + return session + + @pytest.fixture + def provider(self, mock_session): + """Create provider instance""" + return RoleBasedFieldProvider(mock_session) + + def test_validate_role_valid(self, provider): + """[AC-MRS-04] 验证有效角色""" + for role in VALID_FIELD_ROLES: + result = provider._validate_role(role) + assert result == role + + def test_validate_role_invalid(self, provider): + """[AC-MRS-05] 验证无效角色抛出异常""" + with pytest.raises(InvalidRoleError) as exc_info: + provider._validate_role("invalid_role") + + assert "Invalid role 'invalid_role'" in str(exc_info.value) + assert exc_info.value.valid_roles == VALID_FIELD_ROLES + + @pytest.mark.asyncio + async def test_get_fields_by_role(self, provider, mock_session): + """[AC-MRS-04] 按角色获取字段定义""" + mock_field = MagicMock(spec=MetadataFieldDefinition) + mock_field.id = "test-id" + mock_field.field_key = "grade" + mock_field.label = "年级" + mock_field.field_roles = ["resource_filter", "slot"] + mock_field.status = MetadataFieldStatus.ACTIVE.value + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_field] + mock_session.execute.return_value = mock_result + + fields = await provider.get_fields_by_role( + "test-tenant", + "resource_filter" + ) + + assert len(fields) == 1 + assert fields[0].field_key == "grade" + assert "resource_filter" in fields[0].field_roles + + @pytest.mark.asyncio + async def test_get_fields_by_role_invalid_role(self, provider): + """[AC-MRS-05] 无效角色返回 400 错误""" + with pytest.raises(InvalidRoleError): + await provider.get_fields_by_role( + "test-tenant", + "invalid_role" + ) + + @pytest.mark.asyncio + async def test_get_fields_by_role_include_deprecated(self, provider, mock_session): + """[AC-MRS-04] 包含已废弃字段""" + mock_active = MagicMock(spec=MetadataFieldDefinition) + mock_active.field_key = "active_field" + mock_active.status = MetadataFieldStatus.ACTIVE.value + + mock_deprecated = MagicMock(spec=MetadataFieldDefinition) + mock_deprecated.field_key = "deprecated_field" + mock_deprecated.status = MetadataFieldStatus.DEPRECATED.value + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_active, mock_deprecated] + mock_session.execute.return_value = mock_result + + fields = await provider.get_fields_by_role( + "test-tenant", + "resource_filter", + include_deprecated=True + ) + + assert len(fields) == 2 + + @pytest.mark.asyncio + async def test_get_slot_definitions_by_role(self, provider, mock_session): + """[AC-MRS-10] 按角色获取槽位定义""" + mock_field = MagicMock(spec=MetadataFieldDefinition) + mock_field.id = MagicMock() + mock_field.id.__str__ = lambda self: "field-id-123" + mock_field.field_key = "grade" + mock_field.label = "年级" + mock_field.type = "string" + mock_field.required = True + mock_field.options = None + mock_field.default_value = None + mock_field.scope = ["kb_document"] + mock_field.is_filterable = True + mock_field.is_rank_feature = False + mock_field.field_roles = ["slot"] + mock_field.status = MetadataFieldStatus.ACTIVE.value + + mock_slot = MagicMock(spec=SlotDefinition) + mock_slot.id = MagicMock() + mock_slot.id.__str__ = lambda self: "slot-id-456" + mock_slot.tenant_id = "test-tenant" + mock_slot.slot_key = "grade" + mock_slot.type = "string" + mock_slot.required = True + mock_slot.extract_strategy = "llm" + mock_slot.validation_rule = None + mock_slot.ask_back_prompt = "请输入年级" + mock_slot.default_value = None + mock_slot.linked_field_id = mock_field.id + mock_slot.created_at = None + mock_slot.updated_at = None + + field_result = MagicMock() + field_result.scalars.return_value.all.return_value = [mock_field] + + slot_result = MagicMock() + slot_result.scalars.return_value.all.return_value = [mock_slot] + + mock_session.execute.side_effect = [field_result, slot_result] + + slots = await provider.get_slot_definitions_by_role("test-tenant", "slot") + + assert len(slots) >= 1 + + @pytest.mark.asyncio + async def test_get_resource_filter_fields(self, provider, mock_session): + """[AC-MRS-11] 获取资源过滤角色字段""" + mock_field = MagicMock(spec=MetadataFieldDefinition) + mock_field.field_key = "category" + mock_field.field_roles = ["resource_filter"] + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_field] + mock_session.execute.return_value = mock_result + + fields = await provider.get_resource_filter_fields("test-tenant") + + assert len(fields) == 1 + assert "resource_filter" in fields[0].field_roles + + @pytest.mark.asyncio + async def test_get_slot_fields(self, provider, mock_session): + """[AC-MRS-12] 获取槽位角色字段""" + mock_field = MagicMock(spec=MetadataFieldDefinition) + mock_field.field_key = "user_name" + mock_field.field_roles = ["slot"] + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_field] + mock_session.execute.return_value = mock_result + + fields = await provider.get_slot_fields("test-tenant") + + assert len(fields) == 1 + assert "slot" in fields[0].field_roles + + @pytest.mark.asyncio + async def test_get_routing_signal_fields(self, provider, mock_session): + """[AC-MRS-13] 获取路由信号角色字段""" + mock_field = MagicMock(spec=MetadataFieldDefinition) + mock_field.field_key = "priority" + mock_field.field_roles = ["routing_signal"] + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_field] + mock_session.execute.return_value = mock_result + + fields = await provider.get_routing_signal_fields("test-tenant") + + assert len(fields) == 1 + assert "routing_signal" in fields[0].field_roles + + @pytest.mark.asyncio + async def test_get_prompt_var_fields(self, provider, mock_session): + """[AC-MRS-14] 获取提示词变量角色字段""" + mock_field = MagicMock(spec=MetadataFieldDefinition) + mock_field.field_key = "user_name" + mock_field.field_roles = ["prompt_var"] + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_field] + mock_session.execute.return_value = mock_result + + fields = await provider.get_prompt_var_fields("test-tenant") + + assert len(fields) == 1 + assert "prompt_var" in fields[0].field_roles + + +class TestInvalidRoleError: + """[AC-MRS-05] InvalidRoleError 测试""" + + def test_error_message(self): + """验证错误消息格式""" + error = InvalidRoleError("bad_role") + + assert error.role == "bad_role" + assert error.valid_roles == VALID_FIELD_ROLES + assert "Invalid role 'bad_role'" in str(error) + assert "resource_filter" in str(error) diff --git a/ai-service/tests/test_slot_definition_service.py b/ai-service/tests/test_slot_definition_service.py new file mode 100644 index 0000000..b5c8ad2 --- /dev/null +++ b/ai-service/tests/test_slot_definition_service.py @@ -0,0 +1,333 @@ +""" +Unit tests for SlotDefinitionService. +[AC-MRS-07,08,16] 验证槽位定义管理功能 +""" + +import uuid +import pytest +from unittest.mock import AsyncMock, MagicMock +from sqlalchemy.ext.asyncio import AsyncSession + +from app.services.slot_definition_service import SlotDefinitionService +from app.models.entities import ( + SlotDefinition, + SlotDefinitionCreate, + SlotDefinitionUpdate, + MetadataFieldDefinition, +) + + +class TestSlotDefinitionService: + """[AC-MRS-07,08,16] SlotDefinitionService 测试""" + + @pytest.fixture + def mock_session(self): + """Mock AsyncSession""" + session = MagicMock(spec=AsyncSession) + session.execute = AsyncMock() + session.add = MagicMock() + session.flush = AsyncMock() + session.delete = AsyncMock() + return session + + @pytest.fixture + def service(self, mock_session): + """Create service instance""" + return SlotDefinitionService(mock_session) + + @pytest.mark.asyncio + async def test_list_slot_definitions(self, service, mock_session): + """列出槽位定义""" + mock_slot = MagicMock(spec=SlotDefinition) + mock_slot.id = uuid.uuid4() + mock_slot.slot_key = "grade" + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_slot] + mock_session.execute.return_value = mock_result + + slots = await service.list_slot_definitions("test-tenant") + + assert len(slots) == 1 + assert slots[0].slot_key == "grade" + + @pytest.mark.asyncio + async def test_list_slot_definitions_filter_required(self, service, mock_session): + """按必填过滤槽位定义""" + mock_slot = MagicMock(spec=SlotDefinition) + mock_slot.required = True + + mock_result = MagicMock() + mock_result.scalars.return_value.all.return_value = [mock_slot] + mock_session.execute.return_value = mock_result + + slots = await service.list_slot_definitions("test-tenant", required=True) + + assert len(slots) == 1 + assert slots[0].required is True + + @pytest.mark.asyncio + async def test_get_slot_definition(self, service, mock_session): + """获取单个槽位定义""" + slot_id = uuid.uuid4() + mock_slot = MagicMock(spec=SlotDefinition) + mock_slot.id = slot_id + mock_slot.slot_key = "grade" + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = mock_slot + mock_session.execute.return_value = mock_result + + slot = await service.get_slot_definition("test-tenant", str(slot_id)) + + assert slot is not None + assert slot.slot_key == "grade" + + @pytest.mark.asyncio + async def test_get_slot_definition_not_found(self, service, mock_session): + """获取不存在的槽位定义""" + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + mock_session.execute.return_value = mock_result + + slot = await service.get_slot_definition("test-tenant", str(uuid.uuid4())) + + assert slot is None + + @pytest.mark.asyncio + async def test_get_slot_definition_by_key(self, service, mock_session): + """通过 slot_key 获取槽位定义""" + mock_slot = MagicMock(spec=SlotDefinition) + mock_slot.slot_key = "grade" + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = mock_slot + mock_session.execute.return_value = mock_result + + slot = await service.get_slot_definition_by_key("test-tenant", "grade") + + assert slot is not None + assert slot.slot_key == "grade" + + @pytest.mark.asyncio + async def test_create_slot_definition(self, service, mock_session): + """[AC-MRS-07] 创建槽位定义""" + slot_create = SlotDefinitionCreate( + slot_key="grade", + type="string", + required=True, + extract_strategy="llm", + ask_back_prompt="请输入年级", + ) + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + mock_session.execute.return_value = mock_result + + slot = await service.create_slot_definition("test-tenant", slot_create) + + assert slot is not None + mock_session.add.assert_called_once() + mock_session.flush.assert_called_once() + + @pytest.mark.asyncio + async def test_create_slot_definition_invalid_key(self, service): + """[AC-MRS-07] 创建无效 slot_key 抛出异常""" + slot_create = SlotDefinitionCreate( + slot_key="InvalidKey", + type="string", + required=True, + ) + + with pytest.raises(ValueError) as exc_info: + await service.create_slot_definition("test-tenant", slot_create) + + assert "格式不正确" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_create_slot_definition_duplicate_key(self, service, mock_session): + """[AC-MRS-07] 创建重复 slot_key 抛出异常""" + existing_slot = MagicMock(spec=SlotDefinition) + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = existing_slot + mock_session.execute.return_value = mock_result + + slot_create = SlotDefinitionCreate( + slot_key="grade", + type="string", + required=True, + ) + + with pytest.raises(ValueError) as exc_info: + await service.create_slot_definition("test-tenant", slot_create) + + assert "已存在" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_create_slot_definition_invalid_type(self, service, mock_session): + """[AC-MRS-07] 创建无效类型抛出异常""" + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + mock_session.execute.return_value = mock_result + + slot_create = SlotDefinitionCreate( + slot_key="grade", + type="invalid_type", + required=True, + ) + + with pytest.raises(ValueError) as exc_info: + await service.create_slot_definition("test-tenant", slot_create) + + assert "无效的槽位类型" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_create_slot_definition_with_linked_field(self, service, mock_session): + """[AC-MRS-08] 创建槽位定义并关联元数据字段""" + field_id = uuid.uuid4() + mock_field = MagicMock(spec=MetadataFieldDefinition) + mock_field.id = field_id + + slot_result = MagicMock() + slot_result.scalar_one_or_none.return_value = None + + field_result = MagicMock() + field_result.scalar_one_or_none.return_value = mock_field + + mock_session.execute.side_effect = [slot_result, field_result] + + slot_create = SlotDefinitionCreate( + slot_key="grade", + type="string", + required=True, + linked_field_id=str(field_id), + ) + + slot = await service.create_slot_definition("test-tenant", slot_create) + + assert slot is not None + mock_session.add.assert_called_once() + + @pytest.mark.asyncio + async def test_create_slot_definition_linked_field_not_found(self, service, mock_session): + """[AC-MRS-08] 关联字段不存在抛出异常""" + field_id = uuid.uuid4() + + slot_result = MagicMock() + slot_result.scalar_one_or_none.return_value = None + + field_result = MagicMock() + field_result.scalar_one_or_none.return_value = None + + mock_session.execute.side_effect = [slot_result, field_result] + + slot_create = SlotDefinitionCreate( + slot_key="grade", + type="string", + required=True, + linked_field_id=str(field_id), + ) + + with pytest.raises(ValueError) as exc_info: + await service.create_slot_definition("test-tenant", slot_create) + + assert "关联的元数据字段" in str(exc_info.value) + + @pytest.mark.asyncio + async def test_update_slot_definition(self, service, mock_session): + """更新槽位定义""" + slot_id = uuid.uuid4() + mock_slot = MagicMock(spec=SlotDefinition) + mock_slot.id = slot_id + mock_slot.slot_key = "grade" + mock_slot.type = "string" + mock_slot.required = False + mock_slot.extract_strategy = None + mock_slot.validation_rule = None + mock_slot.ask_back_prompt = None + mock_slot.default_value = None + mock_slot.linked_field_id = None + mock_slot.updated_at = None + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = mock_slot + mock_session.execute.return_value = mock_result + + slot_update = SlotDefinitionUpdate( + required=True, + ask_back_prompt="请输入年级", + ) + + slot = await service.update_slot_definition("test-tenant", str(slot_id), slot_update) + + assert slot is not None + mock_session.flush.assert_called_once() + + @pytest.mark.asyncio + async def test_update_slot_definition_not_found(self, service, mock_session): + """更新不存在的槽位定义""" + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + mock_session.execute.return_value = mock_result + + slot_update = SlotDefinitionUpdate(required=True) + + slot = await service.update_slot_definition("test-tenant", str(uuid.uuid4()), slot_update) + + assert slot is None + + @pytest.mark.asyncio + async def test_delete_slot_definition(self, service, mock_session): + """[AC-MRS-16] 删除槽位定义""" + slot_id = uuid.uuid4() + mock_slot = MagicMock(spec=SlotDefinition) + mock_slot.id = slot_id + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = mock_slot + mock_session.execute.return_value = mock_result + + success = await service.delete_slot_definition("test-tenant", str(slot_id)) + + assert success is True + mock_session.delete.assert_called_once() + mock_session.flush.assert_called_once() + + @pytest.mark.asyncio + async def test_delete_slot_definition_not_found(self, service, mock_session): + """[AC-MRS-16] 删除不存在的槽位定义""" + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = None + mock_session.execute.return_value = mock_result + + success = await service.delete_slot_definition("test-tenant", str(uuid.uuid4())) + + assert success is False + + @pytest.mark.asyncio + async def test_get_slot_definition_with_field(self, service, mock_session): + """获取槽位定义及关联字段信息""" + slot_id = uuid.uuid4() + mock_slot = MagicMock(spec=SlotDefinition) + mock_slot.id = slot_id + mock_slot.tenant_id = "test-tenant" + mock_slot.slot_key = "grade" + mock_slot.type = "string" + mock_slot.required = True + mock_slot.extract_strategy = "llm" + mock_slot.validation_rule = None + mock_slot.ask_back_prompt = "请输入年级" + mock_slot.default_value = None + mock_slot.linked_field_id = None + mock_slot.created_at = None + mock_slot.updated_at = None + + mock_result = MagicMock() + mock_result.scalar_one_or_none.return_value = mock_slot + mock_session.execute.return_value = mock_result + + result = await service.get_slot_definition_with_field("test-tenant", str(slot_id)) + + assert result is not None + assert result["slot_key"] == "grade" + assert result["linked_field"] is None