feat: refactor template_engine to only consume prompt_var role fields [AC-MRS-14]

This commit is contained in:
MerCry 2026-03-05 17:19:53 +08:00
parent 6e7c162195
commit 662ba2b101
1 changed files with 33 additions and 2 deletions

View File

@ -1,6 +1,7 @@
""" """
Template Engine for Intent-Driven Script Flow. Template Engine for Intent-Driven Script Flow.
[AC-IDS-06] Template mode script generation with variable filling. [AC-IDS-06] Template mode script generation with variable filling.
[AC-MRS-14] 只消费 field_roles 包含 prompt_var 的字段
""" """
import asyncio import asyncio
@ -8,40 +9,53 @@ import logging
import re import re
from typing import Any from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import FieldRole
from app.services.mid.role_based_field_provider import RoleBasedFieldProvider
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class TemplateEngine: class TemplateEngine:
""" """
[AC-IDS-06] Template script engine. [AC-IDS-06] Template script engine.
[AC-MRS-14] 只消费 field_roles 包含 prompt_var 的字段
Fills template variables using context or LLM generation. Fills template variables using context or LLM generation.
""" """
VARIABLE_PATTERN = re.compile(r'\{(\w+)\}') VARIABLE_PATTERN = re.compile(r'\{(\w+)\}')
DEFAULT_TIMEOUT = 5.0 DEFAULT_TIMEOUT = 5.0
def __init__(self, llm_client: Any = None): def __init__(self, llm_client: Any = None, session: AsyncSession | None = None):
""" """
Initialize TemplateEngine. Initialize TemplateEngine.
Args: Args:
llm_client: LLM client for variable generation (optional) llm_client: LLM client for variable generation (optional)
session: Database session for role-based field provider (optional)
""" """
self._llm_client = llm_client self._llm_client = llm_client
self._session = session
self._role_provider = RoleBasedFieldProvider(session) if session else None
async def fill_template( async def fill_template(
self, self,
template: str, template: str,
context: dict[str, Any] | None, context: dict[str, Any] | None,
history: list[dict[str, str]] | None, history: list[dict[str, str]] | None,
tenant_id: str | None = None,
) -> str: ) -> str:
""" """
[AC-IDS-06] Fill template variables with context or LLM-generated values. [AC-IDS-06] Fill template variables with context or LLM-generated values.
[AC-MRS-14] 只消费 prompt_var 角色的字段
Args: Args:
template: Script template with {variable} placeholders template: Script template with {variable} placeholders
context: Session context with collected inputs context: Session context with collected inputs
history: Conversation history for context history: Conversation history for context
tenant_id: Tenant ID for role-based field filtering
Returns: Returns:
Filled template string Filled template string
@ -52,11 +66,28 @@ class TemplateEngine:
if not variables: if not variables:
return template return template
prompt_var_fields = []
if tenant_id and self._role_provider:
prompt_var_fields = await self._role_provider.get_prompt_var_field_keys(tenant_id)
logger.info(
f"[AC-MRS-14] Retrieved {len(prompt_var_fields)} prompt_var fields for tenant={tenant_id}: {prompt_var_fields}"
)
filtered_context = {}
if context:
if prompt_var_fields:
filtered_context = {k: v for k, v in context.items() if k in prompt_var_fields}
logger.info(
f"[AC-MRS-14] Applied prompt_var context: {list(filtered_context.keys())}"
)
else:
filtered_context = context
variable_values = {} variable_values = {}
for var in variables: for var in variables:
value = await self._generate_variable_value( value = await self._generate_variable_value(
variable_name=var, variable_name=var,
context=context, context=filtered_context,
history=history, history=history,
) )
variable_values[var] = value variable_values[var] = value