feat: refactor memory_recall_tool to only consume slot role fields [AC-MRS-12]

This commit is contained in:
MerCry 2026-03-05 17:18:37 +08:00
parent 4bd2b76d1c
commit f9fe6ec615
1 changed files with 582 additions and 0 deletions

View File

@ -0,0 +1,582 @@
"""
Memory Recall Tool for Mid Platform.
[AC-IDMP-13] 记忆召回工具 - 短期可用记忆注入
[AC-MRS-12] 只消费 field_roles 包含 slot 的字段
定位短期可用记忆注入不是完整中长期记忆系统
功能读取可用记忆包profile/facts/preferences/last_summary/slots
关键特性
1. 优先读取已有结构化记忆
2. 若缺失使用最近窗口历史做最小回填
3. 槽位冲突优先级user_confirmed > rule_extracted > llm_inferred > default
4. 超时 <= 1000ms失败不抛硬异常
5. 多租户隔离正确
6. 只消费 slot 角色的字段
"""
import asyncio
import logging
import time
from dataclasses import dataclass, field
from typing import Any
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models.mid.schemas import (
MemoryRecallResult,
MemorySlot,
SlotSource,
ToolCallStatus,
ToolCallTrace,
ToolType,
)
from app.models.entities import FieldRole
from app.services.mid.role_based_field_provider import RoleBasedFieldProvider
from app.services.mid.timeout_governor import TimeoutGovernor
logger = logging.getLogger(__name__)
DEFAULT_RECALL_TIMEOUT_MS = 1000
DEFAULT_MAX_RECENT_MESSAGES = 8
@dataclass
class MemoryRecallConfig:
"""记忆召回工具配置。"""
enabled: bool = True
timeout_ms: int = DEFAULT_RECALL_TIMEOUT_MS
max_recent_messages: int = DEFAULT_MAX_RECENT_MESSAGES
default_recall_scope: list[str] = field(
default_factory=lambda: ["profile", "facts", "preferences", "summary", "slots"]
)
class MemoryRecallTool:
"""
[AC-IDMP-13] 记忆召回工具
[AC-MRS-12] 只消费 field_roles 包含 slot 的字段
用于在对话前读取用户可用记忆减少重复追问
Features:
- 读取 profile/facts/preferences/last_summary/slots
- 槽位冲突优先级处理
- 超时控制与降级
- 多租户隔离
- 只消费 slot 角色的字段
"""
SLOT_PRIORITY: dict[SlotSource, int] = {
SlotSource.USER_CONFIRMED: 4,
SlotSource.RULE_EXTRACTED: 3,
SlotSource.LLM_INFERRED: 2,
SlotSource.DEFAULT: 1,
}
def __init__(
self,
session: AsyncSession,
timeout_governor: TimeoutGovernor | None = None,
config: MemoryRecallConfig | None = None,
):
self._session = session
self._timeout_governor = timeout_governor or TimeoutGovernor()
self._config = config or MemoryRecallConfig()
self._role_provider = RoleBasedFieldProvider(session)
async def execute(
self,
tenant_id: str,
user_id: str,
session_id: str,
recall_scope: list[str] | None = None,
max_recent_messages: int | None = None,
) -> MemoryRecallResult:
"""
[AC-IDMP-13] 执行记忆召回
Args:
tenant_id: 租户 ID
user_id: 用户 ID
session_id: 会话 ID
recall_scope: 召回范围默认 ["profile","facts","preferences","summary","slots"]
max_recent_messages: 最大最近消息数默认 8
Returns:
MemoryRecallResult: 记忆召回结果
"""
if not self._config.enabled:
logger.info(f"[AC-IDMP-13] Memory recall disabled for tenant={tenant_id}")
return MemoryRecallResult(
fallback_reason_code="MEMORY_RECALL_DISABLED",
)
start_time = time.time()
scope = recall_scope or self._config.default_recall_scope
max_msgs = max_recent_messages or self._config.max_recent_messages
logger.info(
f"[AC-IDMP-13] Starting memory recall: tenant={tenant_id}, "
f"user={user_id}, session={session_id}, scope={scope}"
)
try:
result = await asyncio.wait_for(
self._recall_internal(
tenant_id=tenant_id,
user_id=user_id,
session_id=session_id,
scope=scope,
max_recent_messages=max_msgs,
),
timeout=self._config.timeout_ms / 1000.0,
)
duration_ms = int((time.time() - start_time) * 1000)
result.duration_ms = duration_ms
logger.info(
f"[AC-IDMP-13] Memory recall completed: tenant={tenant_id}, "
f"user={user_id}, duration_ms={duration_ms}, "
f"profile={bool(result.profile)}, facts={len(result.facts)}, "
f"slots={len(result.slots)}, missing_slots={len(result.missing_slots)}"
)
return result
except asyncio.TimeoutError:
duration_ms = int((time.time() - start_time) * 1000)
logger.warning(
f"[AC-IDMP-13] Memory recall timeout: tenant={tenant_id}, "
f"user={user_id}, duration_ms={duration_ms}"
)
return MemoryRecallResult(
fallback_reason_code="MEMORY_RECALL_TIMEOUT",
duration_ms=duration_ms,
)
except Exception as e:
duration_ms = int((time.time() - start_time) * 1000)
logger.error(
f"[AC-IDMP-13] Memory recall failed: tenant={tenant_id}, "
f"user={user_id}, error={e}"
)
return MemoryRecallResult(
fallback_reason_code=f"MEMORY_RECALL_ERROR:{str(e)[:50]}",
duration_ms=duration_ms,
)
async def _recall_internal(
self,
tenant_id: str,
user_id: str,
session_id: str,
scope: list[str],
max_recent_messages: int,
) -> MemoryRecallResult:
"""内部召回实现。"""
profile: dict[str, Any] = {}
facts: list[str] = []
preferences: dict[str, Any] = {}
last_summary: str | None = None
slots: dict[str, MemorySlot] = {}
missing_slots: list[str] = []
if "profile" in scope:
profile = await self._recall_profile(tenant_id, user_id)
if "facts" in scope:
facts = await self._recall_facts(tenant_id, user_id)
if "preferences" in scope:
preferences = await self._recall_preferences(tenant_id, user_id)
if "summary" in scope:
last_summary = await self._recall_last_summary(tenant_id, user_id)
if "slots" in scope:
slots, missing_slots = await self._recall_slots(
tenant_id, user_id, session_id
)
if not profile and not facts and not preferences and not last_summary and not slots:
if "history" in scope or max_recent_messages > 0:
history_facts = await self._recall_from_history(
tenant_id, session_id, max_recent_messages
)
facts.extend(history_facts)
return MemoryRecallResult(
profile=profile,
facts=facts,
preferences=preferences,
last_summary=last_summary,
slots=slots,
missing_slots=missing_slots,
)
async def _recall_profile(
self,
tenant_id: str,
user_id: str,
) -> dict[str, Any]:
"""召回用户基础属性。"""
try:
from app.models.entities import ChatMessage
from sqlmodel import col
stmt = (
select(ChatMessage)
.where(
ChatMessage.tenant_id == tenant_id,
ChatMessage.role == "user",
)
.order_by(col(ChatMessage.created_at).desc())
.limit(5)
)
result = await self._session.execute(stmt)
messages = result.scalars().all()
profile: dict[str, Any] = {}
for msg in messages:
content = msg.content.lower()
if "年级" in content or "" in content or "" in content:
if "grade" not in profile:
profile["grade"] = self._extract_grade(msg.content)
if "北京" in content or "上海" in content or "广州" in content:
if "region" not in profile:
profile["region"] = self._extract_region(msg.content)
return profile
except Exception as e:
logger.warning(f"[AC-IDMP-13] Failed to recall profile: {e}")
return {}
async def _recall_facts(
self,
tenant_id: str,
user_id: str,
) -> list[str]:
"""召回用户事实记忆。"""
try:
from app.models.entities import ChatMessage
from sqlmodel import col
stmt = (
select(ChatMessage)
.where(
ChatMessage.tenant_id == tenant_id,
ChatMessage.role == "assistant",
)
.order_by(col(ChatMessage.created_at).desc())
.limit(10)
)
result = await self._session.execute(stmt)
messages = result.scalars().all()
facts: list[str] = []
for msg in messages:
content = msg.content
if "已购" in content or "购买" in content:
facts.append(self._extract_purchase_info(content))
if "订单" in content:
facts.append(self._extract_order_info(content))
return [f for f in facts if f]
except Exception as e:
logger.warning(f"[AC-IDMP-13] Failed to recall facts: {e}")
return []
async def _recall_preferences(
self,
tenant_id: str,
user_id: str,
) -> dict[str, Any]:
"""召回用户偏好。"""
try:
from app.models.entities import ChatMessage
from sqlmodel import col
stmt = (
select(ChatMessage)
.where(
ChatMessage.tenant_id == tenant_id,
ChatMessage.role == "user",
)
.order_by(col(ChatMessage.created_at).desc())
.limit(10)
)
result = await self._session.execute(stmt)
messages = result.scalars().all()
preferences: dict[str, Any] = {}
for msg in messages:
content = msg.content.lower()
if "详细" in content or "详细解释" in content:
preferences["communication_style"] = "详细解释"
elif "简单" in content or "简洁" in content:
preferences["communication_style"] = "简洁"
return preferences
except Exception as e:
logger.warning(f"[AC-IDMP-13] Failed to recall preferences: {e}")
return {}
async def _recall_last_summary(
self,
tenant_id: str,
user_id: str,
) -> str | None:
"""召回最近会话摘要。"""
try:
from app.models.entities import MidAuditLog
from sqlmodel import col
stmt = (
select(MidAuditLog)
.where(
MidAuditLog.tenant_id == tenant_id,
)
.order_by(col(MidAuditLog.created_at).desc())
.limit(1)
)
result = await self._session.execute(stmt)
audit = result.scalar_one_or_none()
if audit:
return f"上次会话模式: {audit.mode}"
return None
except Exception as e:
logger.warning(f"[AC-IDMP-13] Failed to recall last summary: {e}")
return None
async def _recall_slots(
self,
tenant_id: str,
user_id: str,
session_id: str,
) -> tuple[dict[str, MemorySlot], list[str]]:
"""
[AC-MRS-12] 召回结构化槽位只消费 slot 角色的字段
Returns:
Tuple of (slots_dict, missing_required_slots)
"""
try:
slot_field_keys = await self._role_provider.get_slot_field_keys(tenant_id)
logger.info(
f"[AC-MRS-12] Retrieved {len(slot_field_keys)} slot fields for tenant={tenant_id}: {slot_field_keys}"
)
from app.models.entities import FlowInstance
from sqlalchemy import desc
stmt = (
select(FlowInstance)
.where(
FlowInstance.tenant_id == tenant_id,
)
.order_by(desc(FlowInstance.updated_at))
.limit(1)
)
result = await self._session.execute(stmt)
flow_instance = result.scalar_one_or_none()
slots: dict[str, MemorySlot] = {}
missing_slots: list[str] = []
if flow_instance and flow_instance.context:
context = flow_instance.context
for key, value in context.items():
if key in slot_field_keys and value is not None:
slots[key] = MemorySlot(
key=key,
value=value,
source=SlotSource.USER_CONFIRMED,
confidence=1.0,
updated_at=str(flow_instance.updated_at),
)
return slots, missing_slots
except Exception as e:
logger.warning(f"[AC-IDMP-13] Failed to recall slots: {e}")
return {}, []
async def _recall_from_history(
self,
tenant_id: str,
session_id: str,
max_messages: int,
) -> list[str]:
"""从最近历史中提取最小回填信息。"""
try:
from app.models.entities import ChatMessage
from sqlmodel import col
stmt = (
select(ChatMessage)
.where(
ChatMessage.tenant_id == tenant_id,
ChatMessage.session_id == session_id,
)
.order_by(col(ChatMessage.created_at).desc())
.limit(max_messages)
)
result = await self._session.execute(stmt)
messages = result.scalars().all()
facts: list[str] = []
for msg in messages:
if msg.role == "user":
facts.append(f"用户说过: {msg.content[:50]}")
return facts
except Exception as e:
logger.warning(f"[AC-IDMP-13] Failed to recall from history: {e}")
return []
def _extract_grade(self, content: str) -> str:
"""从内容中提取年级信息。"""
grades = ["初一", "初二", "初三", "高一", "高二", "高三"]
for grade in grades:
if grade in content:
return grade
return "未知年级"
def _extract_region(self, content: str) -> str:
"""从内容中提取地区信息。"""
regions = ["北京", "上海", "广州", "深圳", "杭州", "成都", "武汉", "南京"]
for region in regions:
if region in content:
return region
return "未知地区"
def _extract_purchase_info(self, content: str) -> str:
"""从内容中提取购买信息。"""
return f"购买记录: {content[:30]}..."
def _extract_order_info(self, content: str) -> str:
"""从内容中提取订单信息。"""
return f"订单信息: {content[:30]}..."
def merge_slots(
self,
existing_slots: dict[str, MemorySlot],
new_slots: dict[str, MemorySlot],
) -> dict[str, MemorySlot]:
"""
合并槽位按优先级处理冲突
优先级user_confirmed > rule_extracted > llm_inferred > default
"""
merged = dict(existing_slots)
for key, new_slot in new_slots.items():
if key not in merged:
merged[key] = new_slot
else:
existing_slot = merged[key]
existing_priority = self.SLOT_PRIORITY.get(existing_slot.source, 0)
new_priority = self.SLOT_PRIORITY.get(new_slot.source, 0)
if new_priority > existing_priority:
merged[key] = new_slot
elif new_priority == existing_priority:
if new_slot.confidence > existing_slot.confidence:
merged[key] = new_slot
return merged
def create_trace(
self,
result: MemoryRecallResult,
tenant_id: str,
) -> ToolCallTrace:
"""创建工具调用追踪记录。"""
status = ToolCallStatus.OK
error_code = None
if result.fallback_reason_code:
if "TIMEOUT" in result.fallback_reason_code:
status = ToolCallStatus.TIMEOUT
else:
status = ToolCallStatus.ERROR
error_code = result.fallback_reason_code
return ToolCallTrace(
tool_name="memory_recall",
tool_type=ToolType.INTERNAL,
duration_ms=result.duration_ms,
status=status,
error_code=error_code,
args_digest=f"tenant={tenant_id}",
result_digest=f"profile={len(result.profile)}, facts={len(result.facts)}, slots={len(result.slots)}",
)
def register_memory_recall_tool(
registry: Any,
session: AsyncSession,
timeout_governor: TimeoutGovernor | None = None,
config: MemoryRecallConfig | None = None,
) -> None:
"""
[AC-IDMP-13] 注册 memory_recall 工具到 ToolRegistry
Args:
registry: ToolRegistry 实例
session: 数据库会话
timeout_governor: 超时治理器
config: 工具配置
"""
cfg = config or MemoryRecallConfig()
async def memory_recall_handler(
tenant_id: str,
user_id: str,
session_id: str,
recall_scope: list[str] | None = None,
max_recent_messages: int | None = None,
) -> dict[str, Any]:
"""memory_recall 工具处理器。"""
tool = MemoryRecallTool(
session=session,
timeout_governor=timeout_governor,
config=cfg,
)
result = await tool.execute(
tenant_id=tenant_id,
user_id=user_id,
session_id=session_id,
recall_scope=recall_scope,
max_recent_messages=max_recent_messages,
)
return result.model_dump()
registry.register(
name="memory_recall",
description="[AC-IDMP-13] 记忆召回工具读取用户可用记忆包profile/facts/preferences/summary/slots",
handler=memory_recall_handler,
tool_type=ToolType.INTERNAL,
version="1.0.0",
auth_required=False,
timeout_ms=min(cfg.timeout_ms, 1000),
enabled=True,
metadata={
"ac_ids": ["AC-IDMP-13"],
"recall_scope": cfg.default_recall_scope,
"max_recent_messages": cfg.max_recent_messages,
},
)
logger.info("[AC-IDMP-13] memory_recall tool registered to registry")