feat: refactor template_engine to only consume prompt_var role fields [AC-MRS-14]

This commit is contained in:
MerCry 2026-03-05 17:19:53 +08:00
parent 6e7c162195
commit 662ba2b101
1 changed files with 33 additions and 2 deletions

View File

@ -1,6 +1,7 @@
"""
Template Engine for Intent-Driven Script Flow.
[AC-IDS-06] Template mode script generation with variable filling.
[AC-MRS-14] 只消费 field_roles 包含 prompt_var 的字段
"""
import asyncio
@ -8,40 +9,53 @@ import logging
import re
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import FieldRole
from app.services.mid.role_based_field_provider import RoleBasedFieldProvider
logger = logging.getLogger(__name__)
class TemplateEngine:
"""
[AC-IDS-06] Template script engine.
[AC-MRS-14] 只消费 field_roles 包含 prompt_var 的字段
Fills template variables using context or LLM generation.
"""
VARIABLE_PATTERN = re.compile(r'\{(\w+)\}')
DEFAULT_TIMEOUT = 5.0
def __init__(self, llm_client: Any = None):
def __init__(self, llm_client: Any = None, session: AsyncSession | None = None):
"""
Initialize TemplateEngine.
Args:
llm_client: LLM client for variable generation (optional)
session: Database session for role-based field provider (optional)
"""
self._llm_client = llm_client
self._session = session
self._role_provider = RoleBasedFieldProvider(session) if session else None
async def fill_template(
self,
template: str,
context: dict[str, Any] | None,
history: list[dict[str, str]] | None,
tenant_id: str | None = None,
) -> str:
"""
[AC-IDS-06] Fill template variables with context or LLM-generated values.
[AC-MRS-14] 只消费 prompt_var 角色的字段
Args:
template: Script template with {variable} placeholders
context: Session context with collected inputs
history: Conversation history for context
tenant_id: Tenant ID for role-based field filtering
Returns:
Filled template string
@ -52,11 +66,28 @@ class TemplateEngine:
if not variables:
return template
prompt_var_fields = []
if tenant_id and self._role_provider:
prompt_var_fields = await self._role_provider.get_prompt_var_field_keys(tenant_id)
logger.info(
f"[AC-MRS-14] Retrieved {len(prompt_var_fields)} prompt_var fields for tenant={tenant_id}: {prompt_var_fields}"
)
filtered_context = {}
if context:
if prompt_var_fields:
filtered_context = {k: v for k, v in context.items() if k in prompt_var_fields}
logger.info(
f"[AC-MRS-14] Applied prompt_var context: {list(filtered_context.keys())}"
)
else:
filtered_context = context
variable_values = {}
for var in variables:
value = await self._generate_variable_value(
variable_name=var,
context=context,
context=filtered_context,
history=history,
)
variable_values[var] = value