diff --git a/ai-service/app/services/flow/template_engine.py b/ai-service/app/services/flow/template_engine.py index 5e48551..d098f17 100644 --- a/ai-service/app/services/flow/template_engine.py +++ b/ai-service/app/services/flow/template_engine.py @@ -1,6 +1,7 @@ """ Template Engine for Intent-Driven Script Flow. [AC-IDS-06] Template mode script generation with variable filling. +[AC-MRS-14] 只消费 field_roles 包含 prompt_var 的字段 """ import asyncio @@ -8,40 +9,53 @@ import logging import re from typing import Any +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.entities import FieldRole +from app.services.mid.role_based_field_provider import RoleBasedFieldProvider + logger = logging.getLogger(__name__) class TemplateEngine: """ [AC-IDS-06] Template script engine. + [AC-MRS-14] 只消费 field_roles 包含 prompt_var 的字段 + Fills template variables using context or LLM generation. """ VARIABLE_PATTERN = re.compile(r'\{(\w+)\}') DEFAULT_TIMEOUT = 5.0 - def __init__(self, llm_client: Any = None): + def __init__(self, llm_client: Any = None, session: AsyncSession | None = None): """ Initialize TemplateEngine. Args: llm_client: LLM client for variable generation (optional) + session: Database session for role-based field provider (optional) """ self._llm_client = llm_client + self._session = session + self._role_provider = RoleBasedFieldProvider(session) if session else None async def fill_template( self, template: str, context: dict[str, Any] | None, history: list[dict[str, str]] | None, + tenant_id: str | None = None, ) -> str: """ [AC-IDS-06] Fill template variables with context or LLM-generated values. + [AC-MRS-14] 只消费 prompt_var 角色的字段 Args: template: Script template with {variable} placeholders context: Session context with collected inputs history: Conversation history for context + tenant_id: Tenant ID for role-based field filtering Returns: Filled template string @@ -52,11 +66,28 @@ class TemplateEngine: if not variables: return template + prompt_var_fields = [] + if tenant_id and self._role_provider: + prompt_var_fields = await self._role_provider.get_prompt_var_field_keys(tenant_id) + logger.info( + f"[AC-MRS-14] Retrieved {len(prompt_var_fields)} prompt_var fields for tenant={tenant_id}: {prompt_var_fields}" + ) + + filtered_context = {} + if context: + if prompt_var_fields: + filtered_context = {k: v for k, v in context.items() if k in prompt_var_fields} + logger.info( + f"[AC-MRS-14] Applied prompt_var context: {list(filtered_context.keys())}" + ) + else: + filtered_context = context + variable_values = {} for var in variables: value = await self._generate_variable_value( variable_name=var, - context=context, + context=filtered_context, history=history, ) variable_values[var] = value