ai-robot-core/ai-service/app/services/flow/template_engine.py

202 lines
6.5 KiB
Python
Raw Normal View History

"""
Template Engine for Intent-Driven Script Flow.
[AC-IDS-06] Template mode script generation with variable filling.
[AC-MRS-14] Only consumes fields whose field_roles include prompt_var (只消费 field_roles 包含 prompt_var 的字段).
"""
import asyncio
import logging
import re
from typing import Any
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.entities import FieldRole
from app.services.mid.role_based_field_provider import RoleBasedFieldProvider
# Logger for this module, registered under the module's import path.
logger = logging.getLogger(name=__name__)
class TemplateEngine:
    """
    [AC-IDS-06] Template script engine.
    [AC-MRS-14] Only consumes fields whose field_roles include prompt_var.

    Fills ``{variable}`` placeholders in a script template. Each variable is
    resolved from the (optionally prompt_var-filtered) session context first,
    then from collected ``inputs`` entries, then from an optional LLM client,
    and finally falls back to a ``[variable]`` placeholder.
    """

    # Matches "{name}" placeholders; group 1 is the variable name (word chars only).
    VARIABLE_PATTERN = re.compile(r'\{(\w+)\}')
    # Per-variable LLM generation timeout in seconds (passed to asyncio.wait_for).
    DEFAULT_TIMEOUT = 5.0

    def __init__(self, llm_client: Any = None, session: AsyncSession | None = None):
        """
        Initialize TemplateEngine.

        Args:
            llm_client: LLM client for variable generation (optional). It must
                expose ``generate(messages)`` returning an awaitable.
            session: Database session for the role-based field provider
                (optional). Without it, no prompt_var filtering is applied.
        """
        self._llm_client = llm_client
        self._session = session
        self._role_provider = RoleBasedFieldProvider(session) if session else None

    async def fill_template(
        self,
        template: str,
        context: dict[str, Any] | None,
        history: list[dict[str, str]] | None,
        tenant_id: str | None = None,
    ) -> str:
        """
        [AC-IDS-06] Fill template variables with context or LLM-generated values.
        [AC-MRS-14] Only consumes fields with the prompt_var role.

        Each distinct variable is resolved once, and all placeholders are
        substituted in a single pass. A resolved value that itself contains
        ``{...}`` text is therefore inserted literally and never expanded again.

        Args:
            template: Script template with {variable} placeholders.
            context: Session context with collected inputs.
            history: Conversation history for context.
            tenant_id: Tenant ID for role-based field filtering.

        Returns:
            The filled template string, or the original template unchanged if
            any step raises (the error is logged).
        """
        try:
            # dict.fromkeys de-duplicates while keeping first-occurrence order,
            # so a repeated placeholder is resolved (and LLM-generated) only once.
            variables = list(dict.fromkeys(self.extract_variables(template)))
            if not variables:
                return template

            prompt_var_fields = []
            if tenant_id and self._role_provider:
                prompt_var_fields = await self._role_provider.get_prompt_var_field_keys(tenant_id)
                logger.info(
                    "[AC-MRS-14] Retrieved %d prompt_var fields for tenant=%s: %s",
                    len(prompt_var_fields),
                    tenant_id,
                    prompt_var_fields,
                )

            filtered_context = {}
            if context:
                if prompt_var_fields:
                    # NOTE(review): this also drops context["inputs"] unless
                    # "inputs" is itself a prompt_var field -- confirm intended.
                    filtered_context = {
                        k: v for k, v in context.items() if k in prompt_var_fields
                    }
                    logger.info(
                        "[AC-MRS-14] Applied prompt_var context: %s",
                        list(filtered_context.keys()),
                    )
                else:
                    # No prompt_var list available: use the full context.
                    filtered_context = context

            variable_values: dict[str, str] = {}
            for var in variables:
                variable_values[var] = await self._generate_variable_value(
                    variable_name=var,
                    context=filtered_context,
                    history=history,
                )

            # Single-pass substitution: each match is looked up independently, so
            # text coming from context/LLM is never re-scanned for placeholders.
            # Every match has a key because `variables` came from the same pattern.
            result = self.VARIABLE_PATTERN.sub(
                lambda match: variable_values[match.group(1)], template
            )

            logger.info(
                "[AC-IDS-06] Filled template: variables=%s",
                list(variable_values.keys()),
            )
            return result
        except Exception as e:
            # Deliberate best-effort: never fail the caller, but keep the traceback.
            logger.exception("[AC-IDS-06] Template fill failed: %s, return original", e)
            return template

    async def _generate_variable_value(
        self,
        variable_name: str,
        context: dict[str, Any] | None,
        history: list[dict[str, str]] | None,
    ) -> str:
        """
        Resolve the value for a single template variable.

        Lookup order:
          1. ``context[variable_name]``, converted with ``str``.
          2. The first dict in ``context["inputs"]`` whose ``"variable"`` equals
             ``variable_name``. Its ``"input"`` value is used, defaulting to the
             ``[variable_name]`` placeholder when that key is missing.
          3. LLM generation, if an LLM client was provided, bounded by
             ``DEFAULT_TIMEOUT`` seconds.
          4. The ``[variable_name]`` placeholder.

        Args:
            variable_name: Variable name to resolve.
            context: Session context (possibly prompt_var-filtered).
            history: Conversation history, used to build the LLM prompt.

        Returns:
            The resolved value.
        """
        if context and variable_name in context:
            return str(context[variable_name])

        if context and context.get("inputs"):
            for inp in context["inputs"]:
                if isinstance(inp, dict) and inp.get("variable") == variable_name:
                    return str(inp.get("input", f"[{variable_name}]"))

        if self._llm_client:
            prompt = self._build_variable_prompt(
                variable_name=variable_name,
                history=history,
            )
            try:
                messages = [{"role": "user", "content": prompt}]
                response = await asyncio.wait_for(
                    self._llm_client.generate(messages),
                    timeout=self.DEFAULT_TIMEOUT,
                )
                # Accept either a response object exposing `.content` or a bare value.
                if hasattr(response, 'content'):
                    return response.content.strip()
                return str(response).strip()
            except asyncio.TimeoutError:
                logger.warning(
                    "[AC-IDS-06] Variable generation timeout for %s", variable_name
                )
            except Exception as e:
                logger.warning(
                    "[AC-IDS-06] Variable generation failed for %s: %s",
                    variable_name,
                    e,
                )

        logger.warning(
            "[AC-IDS-06] Failed to generate value for %s, use placeholder",
            variable_name,
        )
        return f"[{variable_name}]"

    def _build_variable_prompt(
        self,
        variable_name: str,
        history: list[dict[str, str]] | None,
    ) -> str:
        """
        Build the (Chinese-language) LLM prompt used to generate a variable value.

        The prompt asks the model to produce a suitable value for the variable
        based on the conversation history. It includes at most the last three
        history messages and ends by asking for the bare value with no
        explanation.
        """
        # "Based on the conversation history, generate a suitable value for variable <name>."
        prompt_parts = [
            f'根据对话历史,为变量 "{variable_name}" 生成合适的值。',
            "",
        ]
        if history:
            # "Conversation history:" -- only the last 3 messages keep the prompt short.
            prompt_parts.append("对话历史:")
            for msg in history[-3:]:
                # "用户" = user; any other role is labelled "客服" (support agent).
                role = "用户" if msg.get("role") == "user" else "客服"
                content = msg.get("content", "")
                prompt_parts.append(f"{role}: {content}")
            prompt_parts.append("")
        # "Only return the variable value, do not explain."
        prompt_parts.append("只返回变量值,不要解释。")
        return "\n".join(prompt_parts)

    def extract_variables(self, template: str) -> list[str]:
        """
        Extract variable names from a template.

        Args:
            template: Template string with {variable} placeholders.

        Returns:
            Variable names in order of appearance. Duplicates are kept.
        """
        return self.VARIABLE_PATTERN.findall(template)