diff --git a/ai-service/app/services/mid/memory_recall_tool.py b/ai-service/app/services/mid/memory_recall_tool.py new file mode 100644 index 0000000..e301c78 --- /dev/null +++ b/ai-service/app/services/mid/memory_recall_tool.py @@ -0,0 +1,582 @@ +""" +Memory Recall Tool for Mid Platform. +[AC-IDMP-13] 记忆召回工具 - 短期可用记忆注入 +[AC-MRS-12] 只消费 field_roles 包含 slot 的字段 + +定位:短期可用记忆注入,不是完整中长期记忆系统。 +功能:读取可用记忆包(profile/facts/preferences/last_summary/slots) + +关键特性: +1. 优先读取已有结构化记忆 +2. 若缺失,使用最近窗口历史做最小回填 +3. 槽位冲突优先级:user_confirmed > rule_extracted > llm_inferred > default +4. 超时 <= 1000ms,失败不抛硬异常 +5. 多租户隔离正确 +6. 只消费 slot 角色的字段 +""" + +import asyncio +import logging +import time +from dataclasses import dataclass, field +from typing import Any + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.models.mid.schemas import ( + MemoryRecallResult, + MemorySlot, + SlotSource, + ToolCallStatus, + ToolCallTrace, + ToolType, +) +from app.models.entities import FieldRole +from app.services.mid.role_based_field_provider import RoleBasedFieldProvider +from app.services.mid.timeout_governor import TimeoutGovernor + +logger = logging.getLogger(__name__) + +DEFAULT_RECALL_TIMEOUT_MS = 1000 +DEFAULT_MAX_RECENT_MESSAGES = 8 + + +@dataclass +class MemoryRecallConfig: + """记忆召回工具配置。""" + enabled: bool = True + timeout_ms: int = DEFAULT_RECALL_TIMEOUT_MS + max_recent_messages: int = DEFAULT_MAX_RECENT_MESSAGES + default_recall_scope: list[str] = field( + default_factory=lambda: ["profile", "facts", "preferences", "summary", "slots"] + ) + + +class MemoryRecallTool: + """ + [AC-IDMP-13] 记忆召回工具。 + [AC-MRS-12] 只消费 field_roles 包含 slot 的字段 + + 用于在对话前读取用户可用记忆,减少重复追问。 + + Features: + - 读取 profile/facts/preferences/last_summary/slots + - 槽位冲突优先级处理 + - 超时控制与降级 + - 多租户隔离 + - 只消费 slot 角色的字段 + """ + + SLOT_PRIORITY: dict[SlotSource, int] = { + SlotSource.USER_CONFIRMED: 4, + SlotSource.RULE_EXTRACTED: 3, + SlotSource.LLM_INFERRED: 2, + SlotSource.DEFAULT: 1, + } + + def __init__( + self, + session: AsyncSession, + timeout_governor: TimeoutGovernor | None = None, + config: MemoryRecallConfig | None = None, + ): + self._session = session + self._timeout_governor = timeout_governor or TimeoutGovernor() + self._config = config or MemoryRecallConfig() + self._role_provider = RoleBasedFieldProvider(session) + + async def execute( + self, + tenant_id: str, + user_id: str, + session_id: str, + recall_scope: list[str] | None = None, + max_recent_messages: int | None = None, + ) -> MemoryRecallResult: + """ + [AC-IDMP-13] 执行记忆召回。 + + Args: + tenant_id: 租户 ID + user_id: 用户 ID + session_id: 会话 ID + recall_scope: 召回范围,默认 ["profile","facts","preferences","summary","slots"] + max_recent_messages: 最大最近消息数,默认 8 + + Returns: + MemoryRecallResult: 记忆召回结果 + """ + if not self._config.enabled: + logger.info(f"[AC-IDMP-13] Memory recall disabled for tenant={tenant_id}") + return MemoryRecallResult( + fallback_reason_code="MEMORY_RECALL_DISABLED", + ) + + start_time = time.time() + scope = recall_scope or self._config.default_recall_scope + max_msgs = max_recent_messages or self._config.max_recent_messages + + logger.info( + f"[AC-IDMP-13] Starting memory recall: tenant={tenant_id}, " + f"user={user_id}, session={session_id}, scope={scope}" + ) + + try: + result = await asyncio.wait_for( + self._recall_internal( + tenant_id=tenant_id, + user_id=user_id, + session_id=session_id, + scope=scope, + max_recent_messages=max_msgs, + ), + timeout=self._config.timeout_ms / 1000.0, + ) + + duration_ms = int((time.time() - start_time) * 1000) + result.duration_ms = duration_ms + + logger.info( + f"[AC-IDMP-13] Memory recall completed: tenant={tenant_id}, " + f"user={user_id}, duration_ms={duration_ms}, " + f"profile={bool(result.profile)}, facts={len(result.facts)}, " + f"slots={len(result.slots)}, missing_slots={len(result.missing_slots)}" + ) + + return result + + except asyncio.TimeoutError: + duration_ms = int((time.time() - start_time) * 1000) + logger.warning( + f"[AC-IDMP-13] Memory recall timeout: tenant={tenant_id}, " + f"user={user_id}, duration_ms={duration_ms}" + ) + return MemoryRecallResult( + fallback_reason_code="MEMORY_RECALL_TIMEOUT", + duration_ms=duration_ms, + ) + + except Exception as e: + duration_ms = int((time.time() - start_time) * 1000) + logger.error( + f"[AC-IDMP-13] Memory recall failed: tenant={tenant_id}, " + f"user={user_id}, error={e}" + ) + return MemoryRecallResult( + fallback_reason_code=f"MEMORY_RECALL_ERROR:{str(e)[:50]}", + duration_ms=duration_ms, + ) + + async def _recall_internal( + self, + tenant_id: str, + user_id: str, + session_id: str, + scope: list[str], + max_recent_messages: int, + ) -> MemoryRecallResult: + """内部召回实现。""" + profile: dict[str, Any] = {} + facts: list[str] = [] + preferences: dict[str, Any] = {} + last_summary: str | None = None + slots: dict[str, MemorySlot] = {} + missing_slots: list[str] = [] + + if "profile" in scope: + profile = await self._recall_profile(tenant_id, user_id) + + if "facts" in scope: + facts = await self._recall_facts(tenant_id, user_id) + + if "preferences" in scope: + preferences = await self._recall_preferences(tenant_id, user_id) + + if "summary" in scope: + last_summary = await self._recall_last_summary(tenant_id, user_id) + + if "slots" in scope: + slots, missing_slots = await self._recall_slots( + tenant_id, user_id, session_id + ) + + if not profile and not facts and not preferences and not last_summary and not slots: + if "history" in scope or max_recent_messages > 0: + history_facts = await self._recall_from_history( + tenant_id, session_id, max_recent_messages + ) + facts.extend(history_facts) + + return MemoryRecallResult( + profile=profile, + facts=facts, + preferences=preferences, + last_summary=last_summary, + slots=slots, + missing_slots=missing_slots, + ) + + async def _recall_profile( + self, + tenant_id: str, + user_id: str, + ) -> dict[str, Any]: + """召回用户基础属性。""" + try: + from app.models.entities import ChatMessage + from sqlmodel import col + + stmt = ( + select(ChatMessage) + .where( + ChatMessage.tenant_id == tenant_id, + ChatMessage.role == "user", + ) + .order_by(col(ChatMessage.created_at).desc()) + .limit(5) + ) + result = await self._session.execute(stmt) + messages = result.scalars().all() + + profile: dict[str, Any] = {} + for msg in messages: + content = msg.content.lower() + if "年级" in content or "初" in content or "高" in content: + if "grade" not in profile: + profile["grade"] = self._extract_grade(msg.content) + if "北京" in content or "上海" in content or "广州" in content: + if "region" not in profile: + profile["region"] = self._extract_region(msg.content) + + return profile + + except Exception as e: + logger.warning(f"[AC-IDMP-13] Failed to recall profile: {e}") + return {} + + async def _recall_facts( + self, + tenant_id: str, + user_id: str, + ) -> list[str]: + """召回用户事实记忆。""" + try: + from app.models.entities import ChatMessage + from sqlmodel import col + + stmt = ( + select(ChatMessage) + .where( + ChatMessage.tenant_id == tenant_id, + ChatMessage.role == "assistant", + ) + .order_by(col(ChatMessage.created_at).desc()) + .limit(10) + ) + result = await self._session.execute(stmt) + messages = result.scalars().all() + + facts: list[str] = [] + for msg in messages: + content = msg.content + if "已购" in content or "购买" in content: + facts.append(self._extract_purchase_info(content)) + if "订单" in content: + facts.append(self._extract_order_info(content)) + + return [f for f in facts if f] + + except Exception as e: + logger.warning(f"[AC-IDMP-13] Failed to recall facts: {e}") + return [] + + async def _recall_preferences( + self, + tenant_id: str, + user_id: str, + ) -> dict[str, Any]: + """召回用户偏好。""" + try: + from app.models.entities import ChatMessage + from sqlmodel import col + + stmt = ( + select(ChatMessage) + .where( + ChatMessage.tenant_id == tenant_id, + ChatMessage.role == "user", + ) + .order_by(col(ChatMessage.created_at).desc()) + .limit(10) + ) + result = await self._session.execute(stmt) + messages = result.scalars().all() + + preferences: dict[str, Any] = {} + for msg in messages: + content = msg.content.lower() + if "详细" in content or "详细解释" in content: + preferences["communication_style"] = "详细解释" + elif "简单" in content or "简洁" in content: + preferences["communication_style"] = "简洁" + + return preferences + + except Exception as e: + logger.warning(f"[AC-IDMP-13] Failed to recall preferences: {e}") + return {} + + async def _recall_last_summary( + self, + tenant_id: str, + user_id: str, + ) -> str | None: + """召回最近会话摘要。""" + try: + from app.models.entities import MidAuditLog + from sqlmodel import col + + stmt = ( + select(MidAuditLog) + .where( + MidAuditLog.tenant_id == tenant_id, + ) + .order_by(col(MidAuditLog.created_at).desc()) + .limit(1) + ) + result = await self._session.execute(stmt) + audit = result.scalar_one_or_none() + + if audit: + return f"上次会话模式: {audit.mode}" + + return None + + except Exception as e: + logger.warning(f"[AC-IDMP-13] Failed to recall last summary: {e}") + return None + + async def _recall_slots( + self, + tenant_id: str, + user_id: str, + session_id: str, + ) -> tuple[dict[str, MemorySlot], list[str]]: + """ + [AC-MRS-12] 召回结构化槽位,只消费 slot 角色的字段。 + + Returns: + Tuple of (slots_dict, missing_required_slots) + """ + try: + slot_field_keys = await self._role_provider.get_slot_field_keys(tenant_id) + + logger.info( + f"[AC-MRS-12] Retrieved {len(slot_field_keys)} slot fields for tenant={tenant_id}: {slot_field_keys}" + ) + + from app.models.entities import FlowInstance + from sqlalchemy import desc + + stmt = ( + select(FlowInstance) + .where( + FlowInstance.tenant_id == tenant_id, + ) + .order_by(desc(FlowInstance.updated_at)) + .limit(1) + ) + result = await self._session.execute(stmt) + flow_instance = result.scalar_one_or_none() + + slots: dict[str, MemorySlot] = {} + missing_slots: list[str] = [] + + if flow_instance and flow_instance.context: + context = flow_instance.context + for key, value in context.items(): + if key in slot_field_keys and value is not None: + slots[key] = MemorySlot( + key=key, + value=value, + source=SlotSource.USER_CONFIRMED, + confidence=1.0, + updated_at=str(flow_instance.updated_at), + ) + + return slots, missing_slots + + except Exception as e: + logger.warning(f"[AC-IDMP-13] Failed to recall slots: {e}") + return {}, [] + + async def _recall_from_history( + self, + tenant_id: str, + session_id: str, + max_messages: int, + ) -> list[str]: + """从最近历史中提取最小回填信息。""" + try: + from app.models.entities import ChatMessage + from sqlmodel import col + + stmt = ( + select(ChatMessage) + .where( + ChatMessage.tenant_id == tenant_id, + ChatMessage.session_id == session_id, + ) + .order_by(col(ChatMessage.created_at).desc()) + .limit(max_messages) + ) + result = await self._session.execute(stmt) + messages = result.scalars().all() + + facts: list[str] = [] + for msg in messages: + if msg.role == "user": + facts.append(f"用户说过: {msg.content[:50]}") + + return facts + + except Exception as e: + logger.warning(f"[AC-IDMP-13] Failed to recall from history: {e}") + return [] + + def _extract_grade(self, content: str) -> str: + """从内容中提取年级信息。""" + grades = ["初一", "初二", "初三", "高一", "高二", "高三"] + for grade in grades: + if grade in content: + return grade + return "未知年级" + + def _extract_region(self, content: str) -> str: + """从内容中提取地区信息。""" + regions = ["北京", "上海", "广州", "深圳", "杭州", "成都", "武汉", "南京"] + for region in regions: + if region in content: + return region + return "未知地区" + + def _extract_purchase_info(self, content: str) -> str: + """从内容中提取购买信息。""" + return f"购买记录: {content[:30]}..." + + def _extract_order_info(self, content: str) -> str: + """从内容中提取订单信息。""" + return f"订单信息: {content[:30]}..." + + def merge_slots( + self, + existing_slots: dict[str, MemorySlot], + new_slots: dict[str, MemorySlot], + ) -> dict[str, MemorySlot]: + """ + 合并槽位,按优先级处理冲突。 + + 优先级:user_confirmed > rule_extracted > llm_inferred > default + """ + merged = dict(existing_slots) + + for key, new_slot in new_slots.items(): + if key not in merged: + merged[key] = new_slot + else: + existing_slot = merged[key] + existing_priority = self.SLOT_PRIORITY.get(existing_slot.source, 0) + new_priority = self.SLOT_PRIORITY.get(new_slot.source, 0) + + if new_priority > existing_priority: + merged[key] = new_slot + elif new_priority == existing_priority: + if new_slot.confidence > existing_slot.confidence: + merged[key] = new_slot + + return merged + + def create_trace( + self, + result: MemoryRecallResult, + tenant_id: str, + ) -> ToolCallTrace: + """创建工具调用追踪记录。""" + status = ToolCallStatus.OK + error_code = None + + if result.fallback_reason_code: + if "TIMEOUT" in result.fallback_reason_code: + status = ToolCallStatus.TIMEOUT + else: + status = ToolCallStatus.ERROR + error_code = result.fallback_reason_code + + return ToolCallTrace( + tool_name="memory_recall", + tool_type=ToolType.INTERNAL, + duration_ms=result.duration_ms, + status=status, + error_code=error_code, + args_digest=f"tenant={tenant_id}", + result_digest=f"profile={len(result.profile)}, facts={len(result.facts)}, slots={len(result.slots)}", + ) + + +def register_memory_recall_tool( + registry: Any, + session: AsyncSession, + timeout_governor: TimeoutGovernor | None = None, + config: MemoryRecallConfig | None = None, +) -> None: + """ + [AC-IDMP-13] 注册 memory_recall 工具到 ToolRegistry。 + + Args: + registry: ToolRegistry 实例 + session: 数据库会话 + timeout_governor: 超时治理器 + config: 工具配置 + """ + cfg = config or MemoryRecallConfig() + + async def memory_recall_handler( + tenant_id: str, + user_id: str, + session_id: str, + recall_scope: list[str] | None = None, + max_recent_messages: int | None = None, + ) -> dict[str, Any]: + """memory_recall 工具处理器。""" + tool = MemoryRecallTool( + session=session, + timeout_governor=timeout_governor, + config=cfg, + ) + result = await tool.execute( + tenant_id=tenant_id, + user_id=user_id, + session_id=session_id, + recall_scope=recall_scope, + max_recent_messages=max_recent_messages, + ) + return result.model_dump() + + registry.register( + name="memory_recall", + description="[AC-IDMP-13] 记忆召回工具,读取用户可用记忆包(profile/facts/preferences/summary/slots)", + handler=memory_recall_handler, + tool_type=ToolType.INTERNAL, + version="1.0.0", + auth_required=False, + timeout_ms=min(cfg.timeout_ms, 1000), + enabled=True, + metadata={ + "ac_ids": ["AC-IDMP-13"], + "recall_scope": cfg.default_recall_scope, + "max_recent_messages": cfg.max_recent_messages, + }, + ) + + logger.info("[AC-IDMP-13] memory_recall tool registered to registry")