feat: implement RoleBasedFieldProvider service for role-based field queries [AC-MRS-04,05,10]

This commit is contained in:
MerCry 2026-03-05 17:12:06 +08:00
parent 68e5adaa28
commit 0db2971c73
1 changed files with 299 additions and 0 deletions

View File

@ -0,0 +1,299 @@
"""
Role Based Field Provider Service.
[AC-MRS-04, AC-MRS-05, AC-MRS-10] 基于角色的字段提供者服务
"""
import logging
from typing import Any
from sqlalchemy import select, cast
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import (
FieldRole,
MetadataFieldDefinition,
MetadataFieldStatus,
SlotDefinition,
)
from app.schemas.metadata import VALID_FIELD_ROLES
logger = logging.getLogger(__name__)
class InvalidRoleError(Exception):
"""[AC-MRS-05] 无效角色异常"""
def __init__(self, role: str):
self.role = role
self.valid_roles = VALID_FIELD_ROLES
super().__init__(
f"Invalid role '{role}'. Valid roles are: {', '.join(self.valid_roles)}"
)
class RoleBasedFieldProvider:
"""
[AC-MRS-04, AC-MRS-05, AC-MRS-10] 基于角色的字段提供者
提供按角色查询字段定义的能力供工具链按需消费
"""
CACHE_KEY_PREFIX = "field_roles"
CACHE_TTL = 300
def __init__(self, session: AsyncSession):
self._session = session
def _validate_role(self, role: str) -> str:
"""
[AC-MRS-05] 验证角色值是否有效
Args:
role: 角色字符串
Returns:
验证通过的角色字符串
Raises:
InvalidRoleError: 如果角色无效
"""
if role not in VALID_FIELD_ROLES:
raise InvalidRoleError(role)
return role
async def get_fields_by_role(
self,
tenant_id: str,
role: str,
include_deprecated: bool = False,
) -> list[MetadataFieldDefinition]:
"""
[AC-MRS-04] 按角色获取字段定义
Args:
tenant_id: 租户 ID
role: 字段角色 (resource_filter/slot/prompt_var/routing_signal)
include_deprecated: 是否包含已废弃字段
Returns:
字段定义列表
Raises:
InvalidRoleError: 如果角色无效
"""
self._validate_role(role)
logger.info(
f"[AC-MRS-04] Getting fields by role: tenant={tenant_id}, "
f"role={role}, include_deprecated={include_deprecated}"
)
stmt = select(MetadataFieldDefinition).where(
MetadataFieldDefinition.tenant_id == tenant_id,
cast(MetadataFieldDefinition.field_roles, JSONB).op('?')(role),
)
if not include_deprecated:
stmt = stmt.where(
MetadataFieldDefinition.status == MetadataFieldStatus.ACTIVE.value
)
else:
stmt = stmt.where(
MetadataFieldDefinition.status.in_([
MetadataFieldStatus.ACTIVE.value,
MetadataFieldStatus.DEPRECATED.value,
])
)
result = await self._session.execute(stmt)
fields = list(result.scalars().all())
logger.info(
f"[AC-MRS-04] Found {len(fields)} fields for role={role}"
)
return fields
async def get_field_keys_by_role(
self,
tenant_id: str,
role: str,
include_deprecated: bool = False,
) -> list[str]:
"""
按角色获取字段键名列表
Args:
tenant_id: 租户 ID
role: 字段角色
include_deprecated: 是否包含已废弃字段
Returns:
字段键名列表
"""
fields = await self.get_fields_by_role(tenant_id, role, include_deprecated)
return [f.field_key for f in fields]
async def get_slot_definitions_by_role(
self,
tenant_id: str,
role: str = "slot",
) -> list[dict[str, Any]]:
"""
[AC-MRS-10] 按角色获取槽位定义及关联字段信息
Args:
tenant_id: 租户 ID
role: 字段角色默认为 slot
Returns:
槽位定义列表包含关联字段信息
"""
self._validate_role(role)
logger.info(
f"[AC-MRS-10] Getting slot definitions by role: tenant={tenant_id}, role={role}"
)
fields = await self.get_fields_by_role(tenant_id, role)
field_ids = [f.id for f in fields]
field_map = {str(f.id): f for f in fields}
stmt = select(SlotDefinition).where(
SlotDefinition.tenant_id == tenant_id,
)
result = await self._session.execute(stmt)
all_slots = list(result.scalars().all())
slot_with_fields = []
for slot in all_slots:
slot_data = {
"id": str(slot.id),
"tenant_id": slot.tenant_id,
"slot_key": slot.slot_key,
"type": slot.type,
"required": slot.required,
"extract_strategy": slot.extract_strategy,
"validation_rule": slot.validation_rule,
"ask_back_prompt": slot.ask_back_prompt,
"default_value": slot.default_value,
"linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
"created_at": slot.created_at.isoformat() if slot.created_at else None,
"updated_at": slot.updated_at.isoformat() if slot.updated_at else None,
"linked_field": None,
}
if slot.linked_field_id and str(slot.linked_field_id) in field_map:
linked_field = field_map[str(slot.linked_field_id)]
slot_data["linked_field"] = {
"id": str(linked_field.id),
"field_key": linked_field.field_key,
"label": linked_field.label,
"type": linked_field.type,
"required": linked_field.required,
"options": linked_field.options,
"default_value": linked_field.default_value,
"scope": linked_field.scope,
"is_filterable": linked_field.is_filterable,
"is_rank_feature": linked_field.is_rank_feature,
"field_roles": linked_field.field_roles,
"status": linked_field.status,
}
slot_with_fields.append(slot_data)
for field in fields:
field_id_str = str(field.id)
has_slot = any(
s.get("linked_field_id") == field_id_str
for s in slot_with_fields
)
if not has_slot and role == "slot":
slot_with_fields.append({
"id": None,
"tenant_id": field.tenant_id,
"slot_key": field.field_key,
"type": field.type,
"required": field.required,
"extract_strategy": None,
"validation_rule": None,
"ask_back_prompt": None,
"default_value": field.default_value,
"linked_field_id": str(field.id),
"created_at": None,
"updated_at": None,
"linked_field": {
"id": str(field.id),
"field_key": field.field_key,
"label": field.label,
"type": field.type,
"required": field.required,
"options": field.options,
"default_value": field.default_value,
"scope": field.scope,
"is_filterable": field.is_filterable,
"is_rank_feature": field.is_rank_feature,
"field_roles": field.field_roles,
"status": field.status,
},
})
logger.info(
f"[AC-MRS-10] Found {len(slot_with_fields)} slot definitions for role={role}"
)
return slot_with_fields
async def get_resource_filter_fields(
self,
tenant_id: str,
) -> list[MetadataFieldDefinition]:
"""
[AC-MRS-11] 获取资源过滤角色字段
kb_search_dynamic 工具使用
"""
return await self.get_fields_by_role(
tenant_id,
FieldRole.RESOURCE_FILTER.value
)
async def get_slot_fields(
self,
tenant_id: str,
) -> list[MetadataFieldDefinition]:
"""
[AC-MRS-12] 获取槽位角色字段
memory_recall 工具使用
"""
return await self.get_fields_by_role(
tenant_id,
FieldRole.SLOT.value
)
async def get_routing_signal_fields(
self,
tenant_id: str,
) -> list[MetadataFieldDefinition]:
"""
[AC-MRS-13] 获取路由信号角色字段
intent_hint/high_risk_check 工具使用
"""
return await self.get_fields_by_role(
tenant_id,
FieldRole.ROUTING_SIGNAL.value
)
async def get_prompt_var_fields(
self,
tenant_id: str,
) -> list[MetadataFieldDefinition]:
"""
[AC-MRS-14] 获取提示词变量角色字段
template_engine 使用
"""
return await self.get_fields_by_role(
tenant_id,
FieldRole.PROMPT_VAR.value
)