feat: implement RoleBasedFieldProvider service for role-based field queries [AC-MRS-04,05,10]
This commit is contained in:
parent
68e5adaa28
commit
0db2971c73
|
|
@ -0,0 +1,299 @@
|
||||||
|
"""
|
||||||
|
Role Based Field Provider Service.
|
||||||
|
[AC-MRS-04, AC-MRS-05, AC-MRS-10] 基于角色的字段提供者服务
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import select, cast
|
||||||
|
from sqlalchemy.dialects.postgresql import JSONB
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.entities import (
|
||||||
|
FieldRole,
|
||||||
|
MetadataFieldDefinition,
|
||||||
|
MetadataFieldStatus,
|
||||||
|
SlotDefinition,
|
||||||
|
)
|
||||||
|
from app.schemas.metadata import VALID_FIELD_ROLES
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class InvalidRoleError(Exception):
|
||||||
|
"""[AC-MRS-05] 无效角色异常"""
|
||||||
|
|
||||||
|
def __init__(self, role: str):
|
||||||
|
self.role = role
|
||||||
|
self.valid_roles = VALID_FIELD_ROLES
|
||||||
|
super().__init__(
|
||||||
|
f"Invalid role '{role}'. Valid roles are: {', '.join(self.valid_roles)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RoleBasedFieldProvider:
|
||||||
|
"""
|
||||||
|
[AC-MRS-04, AC-MRS-05, AC-MRS-10] 基于角色的字段提供者
|
||||||
|
|
||||||
|
提供按角色查询字段定义的能力,供工具链按需消费
|
||||||
|
"""
|
||||||
|
|
||||||
|
CACHE_KEY_PREFIX = "field_roles"
|
||||||
|
CACHE_TTL = 300
|
||||||
|
|
||||||
|
def __init__(self, session: AsyncSession):
|
||||||
|
self._session = session
|
||||||
|
|
||||||
|
def _validate_role(self, role: str) -> str:
|
||||||
|
"""
|
||||||
|
[AC-MRS-05] 验证角色值是否有效
|
||||||
|
|
||||||
|
Args:
|
||||||
|
role: 角色字符串
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
验证通过的角色字符串
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidRoleError: 如果角色无效
|
||||||
|
"""
|
||||||
|
if role not in VALID_FIELD_ROLES:
|
||||||
|
raise InvalidRoleError(role)
|
||||||
|
return role
|
||||||
|
|
||||||
|
async def get_fields_by_role(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
role: str,
|
||||||
|
include_deprecated: bool = False,
|
||||||
|
) -> list[MetadataFieldDefinition]:
|
||||||
|
"""
|
||||||
|
[AC-MRS-04] 按角色获取字段定义
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: 租户 ID
|
||||||
|
role: 字段角色 (resource_filter/slot/prompt_var/routing_signal)
|
||||||
|
include_deprecated: 是否包含已废弃字段
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字段定义列表
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
InvalidRoleError: 如果角色无效
|
||||||
|
"""
|
||||||
|
self._validate_role(role)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-MRS-04] Getting fields by role: tenant={tenant_id}, "
|
||||||
|
f"role={role}, include_deprecated={include_deprecated}"
|
||||||
|
)
|
||||||
|
|
||||||
|
stmt = select(MetadataFieldDefinition).where(
|
||||||
|
MetadataFieldDefinition.tenant_id == tenant_id,
|
||||||
|
cast(MetadataFieldDefinition.field_roles, JSONB).op('?')(role),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not include_deprecated:
|
||||||
|
stmt = stmt.where(
|
||||||
|
MetadataFieldDefinition.status == MetadataFieldStatus.ACTIVE.value
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
stmt = stmt.where(
|
||||||
|
MetadataFieldDefinition.status.in_([
|
||||||
|
MetadataFieldStatus.ACTIVE.value,
|
||||||
|
MetadataFieldStatus.DEPRECATED.value,
|
||||||
|
])
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
fields = list(result.scalars().all())
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-MRS-04] Found {len(fields)} fields for role={role}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return fields
|
||||||
|
|
||||||
|
async def get_field_keys_by_role(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
role: str,
|
||||||
|
include_deprecated: bool = False,
|
||||||
|
) -> list[str]:
|
||||||
|
"""
|
||||||
|
按角色获取字段键名列表
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: 租户 ID
|
||||||
|
role: 字段角色
|
||||||
|
include_deprecated: 是否包含已废弃字段
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
字段键名列表
|
||||||
|
"""
|
||||||
|
fields = await self.get_fields_by_role(tenant_id, role, include_deprecated)
|
||||||
|
return [f.field_key for f in fields]
|
||||||
|
|
||||||
|
async def get_slot_definitions_by_role(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
role: str = "slot",
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
[AC-MRS-10] 按角色获取槽位定义及关联字段信息
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: 租户 ID
|
||||||
|
role: 字段角色,默认为 slot
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
槽位定义列表,包含关联字段信息
|
||||||
|
"""
|
||||||
|
self._validate_role(role)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-MRS-10] Getting slot definitions by role: tenant={tenant_id}, role={role}"
|
||||||
|
)
|
||||||
|
|
||||||
|
fields = await self.get_fields_by_role(tenant_id, role)
|
||||||
|
field_ids = [f.id for f in fields]
|
||||||
|
field_map = {str(f.id): f for f in fields}
|
||||||
|
|
||||||
|
stmt = select(SlotDefinition).where(
|
||||||
|
SlotDefinition.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
all_slots = list(result.scalars().all())
|
||||||
|
|
||||||
|
slot_with_fields = []
|
||||||
|
for slot in all_slots:
|
||||||
|
slot_data = {
|
||||||
|
"id": str(slot.id),
|
||||||
|
"tenant_id": slot.tenant_id,
|
||||||
|
"slot_key": slot.slot_key,
|
||||||
|
"type": slot.type,
|
||||||
|
"required": slot.required,
|
||||||
|
"extract_strategy": slot.extract_strategy,
|
||||||
|
"validation_rule": slot.validation_rule,
|
||||||
|
"ask_back_prompt": slot.ask_back_prompt,
|
||||||
|
"default_value": slot.default_value,
|
||||||
|
"linked_field_id": str(slot.linked_field_id) if slot.linked_field_id else None,
|
||||||
|
"created_at": slot.created_at.isoformat() if slot.created_at else None,
|
||||||
|
"updated_at": slot.updated_at.isoformat() if slot.updated_at else None,
|
||||||
|
"linked_field": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
if slot.linked_field_id and str(slot.linked_field_id) in field_map:
|
||||||
|
linked_field = field_map[str(slot.linked_field_id)]
|
||||||
|
slot_data["linked_field"] = {
|
||||||
|
"id": str(linked_field.id),
|
||||||
|
"field_key": linked_field.field_key,
|
||||||
|
"label": linked_field.label,
|
||||||
|
"type": linked_field.type,
|
||||||
|
"required": linked_field.required,
|
||||||
|
"options": linked_field.options,
|
||||||
|
"default_value": linked_field.default_value,
|
||||||
|
"scope": linked_field.scope,
|
||||||
|
"is_filterable": linked_field.is_filterable,
|
||||||
|
"is_rank_feature": linked_field.is_rank_feature,
|
||||||
|
"field_roles": linked_field.field_roles,
|
||||||
|
"status": linked_field.status,
|
||||||
|
}
|
||||||
|
|
||||||
|
slot_with_fields.append(slot_data)
|
||||||
|
|
||||||
|
for field in fields:
|
||||||
|
field_id_str = str(field.id)
|
||||||
|
has_slot = any(
|
||||||
|
s.get("linked_field_id") == field_id_str
|
||||||
|
for s in slot_with_fields
|
||||||
|
)
|
||||||
|
if not has_slot and role == "slot":
|
||||||
|
slot_with_fields.append({
|
||||||
|
"id": None,
|
||||||
|
"tenant_id": field.tenant_id,
|
||||||
|
"slot_key": field.field_key,
|
||||||
|
"type": field.type,
|
||||||
|
"required": field.required,
|
||||||
|
"extract_strategy": None,
|
||||||
|
"validation_rule": None,
|
||||||
|
"ask_back_prompt": None,
|
||||||
|
"default_value": field.default_value,
|
||||||
|
"linked_field_id": str(field.id),
|
||||||
|
"created_at": None,
|
||||||
|
"updated_at": None,
|
||||||
|
"linked_field": {
|
||||||
|
"id": str(field.id),
|
||||||
|
"field_key": field.field_key,
|
||||||
|
"label": field.label,
|
||||||
|
"type": field.type,
|
||||||
|
"required": field.required,
|
||||||
|
"options": field.options,
|
||||||
|
"default_value": field.default_value,
|
||||||
|
"scope": field.scope,
|
||||||
|
"is_filterable": field.is_filterable,
|
||||||
|
"is_rank_feature": field.is_rank_feature,
|
||||||
|
"field_roles": field.field_roles,
|
||||||
|
"status": field.status,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-MRS-10] Found {len(slot_with_fields)} slot definitions for role={role}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return slot_with_fields
|
||||||
|
|
||||||
|
async def get_resource_filter_fields(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
) -> list[MetadataFieldDefinition]:
|
||||||
|
"""
|
||||||
|
[AC-MRS-11] 获取资源过滤角色字段
|
||||||
|
供 kb_search_dynamic 工具使用
|
||||||
|
"""
|
||||||
|
return await self.get_fields_by_role(
|
||||||
|
tenant_id,
|
||||||
|
FieldRole.RESOURCE_FILTER.value
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_slot_fields(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
) -> list[MetadataFieldDefinition]:
|
||||||
|
"""
|
||||||
|
[AC-MRS-12] 获取槽位角色字段
|
||||||
|
供 memory_recall 工具使用
|
||||||
|
"""
|
||||||
|
return await self.get_fields_by_role(
|
||||||
|
tenant_id,
|
||||||
|
FieldRole.SLOT.value
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_routing_signal_fields(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
) -> list[MetadataFieldDefinition]:
|
||||||
|
"""
|
||||||
|
[AC-MRS-13] 获取路由信号角色字段
|
||||||
|
供 intent_hint/high_risk_check 工具使用
|
||||||
|
"""
|
||||||
|
return await self.get_fields_by_role(
|
||||||
|
tenant_id,
|
||||||
|
FieldRole.ROUTING_SIGNAL.value
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_prompt_var_fields(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
) -> list[MetadataFieldDefinition]:
|
||||||
|
"""
|
||||||
|
[AC-MRS-14] 获取提示词变量角色字段
|
||||||
|
供 template_engine 使用
|
||||||
|
"""
|
||||||
|
return await self.get_fields_by_role(
|
||||||
|
tenant_id,
|
||||||
|
FieldRole.PROMPT_VAR.value
|
||||||
|
)
|
||||||
Loading…
Reference in New Issue