feat: refactor memory_recall_tool to only consume slot role fields [AC-MRS-12]
This commit is contained in:
parent
4bd2b76d1c
commit
f9fe6ec615
|
|
@ -0,0 +1,582 @@
|
||||||
|
"""
|
||||||
|
Memory Recall Tool for Mid Platform.
|
||||||
|
[AC-IDMP-13] 记忆召回工具 - 短期可用记忆注入
|
||||||
|
[AC-MRS-12] 只消费 field_roles 包含 slot 的字段
|
||||||
|
|
||||||
|
定位:短期可用记忆注入,不是完整中长期记忆系统。
|
||||||
|
功能:读取可用记忆包(profile/facts/preferences/last_summary/slots)
|
||||||
|
|
||||||
|
关键特性:
|
||||||
|
1. 优先读取已有结构化记忆
|
||||||
|
2. 若缺失,使用最近窗口历史做最小回填
|
||||||
|
3. 槽位冲突优先级:user_confirmed > rule_extracted > llm_inferred > default
|
||||||
|
4. 超时 <= 1000ms,失败不抛硬异常
|
||||||
|
5. 多租户隔离正确
|
||||||
|
6. 只消费 slot 角色的字段
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models.mid.schemas import (
|
||||||
|
MemoryRecallResult,
|
||||||
|
MemorySlot,
|
||||||
|
SlotSource,
|
||||||
|
ToolCallStatus,
|
||||||
|
ToolCallTrace,
|
||||||
|
ToolType,
|
||||||
|
)
|
||||||
|
from app.models.entities import FieldRole
|
||||||
|
from app.services.mid.role_based_field_provider import RoleBasedFieldProvider
|
||||||
|
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
DEFAULT_RECALL_TIMEOUT_MS = 1000
|
||||||
|
DEFAULT_MAX_RECENT_MESSAGES = 8
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MemoryRecallConfig:
|
||||||
|
"""记忆召回工具配置。"""
|
||||||
|
enabled: bool = True
|
||||||
|
timeout_ms: int = DEFAULT_RECALL_TIMEOUT_MS
|
||||||
|
max_recent_messages: int = DEFAULT_MAX_RECENT_MESSAGES
|
||||||
|
default_recall_scope: list[str] = field(
|
||||||
|
default_factory=lambda: ["profile", "facts", "preferences", "summary", "slots"]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryRecallTool:
|
||||||
|
"""
|
||||||
|
[AC-IDMP-13] 记忆召回工具。
|
||||||
|
[AC-MRS-12] 只消费 field_roles 包含 slot 的字段
|
||||||
|
|
||||||
|
用于在对话前读取用户可用记忆,减少重复追问。
|
||||||
|
|
||||||
|
Features:
|
||||||
|
- 读取 profile/facts/preferences/last_summary/slots
|
||||||
|
- 槽位冲突优先级处理
|
||||||
|
- 超时控制与降级
|
||||||
|
- 多租户隔离
|
||||||
|
- 只消费 slot 角色的字段
|
||||||
|
"""
|
||||||
|
|
||||||
|
SLOT_PRIORITY: dict[SlotSource, int] = {
|
||||||
|
SlotSource.USER_CONFIRMED: 4,
|
||||||
|
SlotSource.RULE_EXTRACTED: 3,
|
||||||
|
SlotSource.LLM_INFERRED: 2,
|
||||||
|
SlotSource.DEFAULT: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
session: AsyncSession,
|
||||||
|
timeout_governor: TimeoutGovernor | None = None,
|
||||||
|
config: MemoryRecallConfig | None = None,
|
||||||
|
):
|
||||||
|
self._session = session
|
||||||
|
self._timeout_governor = timeout_governor or TimeoutGovernor()
|
||||||
|
self._config = config or MemoryRecallConfig()
|
||||||
|
self._role_provider = RoleBasedFieldProvider(session)
|
||||||
|
|
||||||
|
async def execute(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
recall_scope: list[str] | None = None,
|
||||||
|
max_recent_messages: int | None = None,
|
||||||
|
) -> MemoryRecallResult:
|
||||||
|
"""
|
||||||
|
[AC-IDMP-13] 执行记忆召回。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tenant_id: 租户 ID
|
||||||
|
user_id: 用户 ID
|
||||||
|
session_id: 会话 ID
|
||||||
|
recall_scope: 召回范围,默认 ["profile","facts","preferences","summary","slots"]
|
||||||
|
max_recent_messages: 最大最近消息数,默认 8
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
MemoryRecallResult: 记忆召回结果
|
||||||
|
"""
|
||||||
|
if not self._config.enabled:
|
||||||
|
logger.info(f"[AC-IDMP-13] Memory recall disabled for tenant={tenant_id}")
|
||||||
|
return MemoryRecallResult(
|
||||||
|
fallback_reason_code="MEMORY_RECALL_DISABLED",
|
||||||
|
)
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
scope = recall_scope or self._config.default_recall_scope
|
||||||
|
max_msgs = max_recent_messages or self._config.max_recent_messages
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-IDMP-13] Starting memory recall: tenant={tenant_id}, "
|
||||||
|
f"user={user_id}, session={session_id}, scope={scope}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(
|
||||||
|
self._recall_internal(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
scope=scope,
|
||||||
|
max_recent_messages=max_msgs,
|
||||||
|
),
|
||||||
|
timeout=self._config.timeout_ms / 1000.0,
|
||||||
|
)
|
||||||
|
|
||||||
|
duration_ms = int((time.time() - start_time) * 1000)
|
||||||
|
result.duration_ms = duration_ms
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-IDMP-13] Memory recall completed: tenant={tenant_id}, "
|
||||||
|
f"user={user_id}, duration_ms={duration_ms}, "
|
||||||
|
f"profile={bool(result.profile)}, facts={len(result.facts)}, "
|
||||||
|
f"slots={len(result.slots)}, missing_slots={len(result.missing_slots)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
duration_ms = int((time.time() - start_time) * 1000)
|
||||||
|
logger.warning(
|
||||||
|
f"[AC-IDMP-13] Memory recall timeout: tenant={tenant_id}, "
|
||||||
|
f"user={user_id}, duration_ms={duration_ms}"
|
||||||
|
)
|
||||||
|
return MemoryRecallResult(
|
||||||
|
fallback_reason_code="MEMORY_RECALL_TIMEOUT",
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
duration_ms = int((time.time() - start_time) * 1000)
|
||||||
|
logger.error(
|
||||||
|
f"[AC-IDMP-13] Memory recall failed: tenant={tenant_id}, "
|
||||||
|
f"user={user_id}, error={e}"
|
||||||
|
)
|
||||||
|
return MemoryRecallResult(
|
||||||
|
fallback_reason_code=f"MEMORY_RECALL_ERROR:{str(e)[:50]}",
|
||||||
|
duration_ms=duration_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _recall_internal(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
scope: list[str],
|
||||||
|
max_recent_messages: int,
|
||||||
|
) -> MemoryRecallResult:
|
||||||
|
"""内部召回实现。"""
|
||||||
|
profile: dict[str, Any] = {}
|
||||||
|
facts: list[str] = []
|
||||||
|
preferences: dict[str, Any] = {}
|
||||||
|
last_summary: str | None = None
|
||||||
|
slots: dict[str, MemorySlot] = {}
|
||||||
|
missing_slots: list[str] = []
|
||||||
|
|
||||||
|
if "profile" in scope:
|
||||||
|
profile = await self._recall_profile(tenant_id, user_id)
|
||||||
|
|
||||||
|
if "facts" in scope:
|
||||||
|
facts = await self._recall_facts(tenant_id, user_id)
|
||||||
|
|
||||||
|
if "preferences" in scope:
|
||||||
|
preferences = await self._recall_preferences(tenant_id, user_id)
|
||||||
|
|
||||||
|
if "summary" in scope:
|
||||||
|
last_summary = await self._recall_last_summary(tenant_id, user_id)
|
||||||
|
|
||||||
|
if "slots" in scope:
|
||||||
|
slots, missing_slots = await self._recall_slots(
|
||||||
|
tenant_id, user_id, session_id
|
||||||
|
)
|
||||||
|
|
||||||
|
if not profile and not facts and not preferences and not last_summary and not slots:
|
||||||
|
if "history" in scope or max_recent_messages > 0:
|
||||||
|
history_facts = await self._recall_from_history(
|
||||||
|
tenant_id, session_id, max_recent_messages
|
||||||
|
)
|
||||||
|
facts.extend(history_facts)
|
||||||
|
|
||||||
|
return MemoryRecallResult(
|
||||||
|
profile=profile,
|
||||||
|
facts=facts,
|
||||||
|
preferences=preferences,
|
||||||
|
last_summary=last_summary,
|
||||||
|
slots=slots,
|
||||||
|
missing_slots=missing_slots,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _recall_profile(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""召回用户基础属性。"""
|
||||||
|
try:
|
||||||
|
from app.models.entities import ChatMessage
|
||||||
|
from sqlmodel import col
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(ChatMessage)
|
||||||
|
.where(
|
||||||
|
ChatMessage.tenant_id == tenant_id,
|
||||||
|
ChatMessage.role == "user",
|
||||||
|
)
|
||||||
|
.order_by(col(ChatMessage.created_at).desc())
|
||||||
|
.limit(5)
|
||||||
|
)
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
messages = result.scalars().all()
|
||||||
|
|
||||||
|
profile: dict[str, Any] = {}
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.content.lower()
|
||||||
|
if "年级" in content or "初" in content or "高" in content:
|
||||||
|
if "grade" not in profile:
|
||||||
|
profile["grade"] = self._extract_grade(msg.content)
|
||||||
|
if "北京" in content or "上海" in content or "广州" in content:
|
||||||
|
if "region" not in profile:
|
||||||
|
profile["region"] = self._extract_region(msg.content)
|
||||||
|
|
||||||
|
return profile
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AC-IDMP-13] Failed to recall profile: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def _recall_facts(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
) -> list[str]:
|
||||||
|
"""召回用户事实记忆。"""
|
||||||
|
try:
|
||||||
|
from app.models.entities import ChatMessage
|
||||||
|
from sqlmodel import col
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(ChatMessage)
|
||||||
|
.where(
|
||||||
|
ChatMessage.tenant_id == tenant_id,
|
||||||
|
ChatMessage.role == "assistant",
|
||||||
|
)
|
||||||
|
.order_by(col(ChatMessage.created_at).desc())
|
||||||
|
.limit(10)
|
||||||
|
)
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
messages = result.scalars().all()
|
||||||
|
|
||||||
|
facts: list[str] = []
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.content
|
||||||
|
if "已购" in content or "购买" in content:
|
||||||
|
facts.append(self._extract_purchase_info(content))
|
||||||
|
if "订单" in content:
|
||||||
|
facts.append(self._extract_order_info(content))
|
||||||
|
|
||||||
|
return [f for f in facts if f]
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AC-IDMP-13] Failed to recall facts: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def _recall_preferences(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""召回用户偏好。"""
|
||||||
|
try:
|
||||||
|
from app.models.entities import ChatMessage
|
||||||
|
from sqlmodel import col
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(ChatMessage)
|
||||||
|
.where(
|
||||||
|
ChatMessage.tenant_id == tenant_id,
|
||||||
|
ChatMessage.role == "user",
|
||||||
|
)
|
||||||
|
.order_by(col(ChatMessage.created_at).desc())
|
||||||
|
.limit(10)
|
||||||
|
)
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
messages = result.scalars().all()
|
||||||
|
|
||||||
|
preferences: dict[str, Any] = {}
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.content.lower()
|
||||||
|
if "详细" in content or "详细解释" in content:
|
||||||
|
preferences["communication_style"] = "详细解释"
|
||||||
|
elif "简单" in content or "简洁" in content:
|
||||||
|
preferences["communication_style"] = "简洁"
|
||||||
|
|
||||||
|
return preferences
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AC-IDMP-13] Failed to recall preferences: {e}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
async def _recall_last_summary(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
) -> str | None:
|
||||||
|
"""召回最近会话摘要。"""
|
||||||
|
try:
|
||||||
|
from app.models.entities import MidAuditLog
|
||||||
|
from sqlmodel import col
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(MidAuditLog)
|
||||||
|
.where(
|
||||||
|
MidAuditLog.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
|
.order_by(col(MidAuditLog.created_at).desc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
audit = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if audit:
|
||||||
|
return f"上次会话模式: {audit.mode}"
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AC-IDMP-13] Failed to recall last summary: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _recall_slots(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
) -> tuple[dict[str, MemorySlot], list[str]]:
|
||||||
|
"""
|
||||||
|
[AC-MRS-12] 召回结构化槽位,只消费 slot 角色的字段。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple of (slots_dict, missing_required_slots)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
slot_field_keys = await self._role_provider.get_slot_field_keys(tenant_id)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"[AC-MRS-12] Retrieved {len(slot_field_keys)} slot fields for tenant={tenant_id}: {slot_field_keys}"
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.models.entities import FlowInstance
|
||||||
|
from sqlalchemy import desc
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(FlowInstance)
|
||||||
|
.where(
|
||||||
|
FlowInstance.tenant_id == tenant_id,
|
||||||
|
)
|
||||||
|
.order_by(desc(FlowInstance.updated_at))
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
flow_instance = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
slots: dict[str, MemorySlot] = {}
|
||||||
|
missing_slots: list[str] = []
|
||||||
|
|
||||||
|
if flow_instance and flow_instance.context:
|
||||||
|
context = flow_instance.context
|
||||||
|
for key, value in context.items():
|
||||||
|
if key in slot_field_keys and value is not None:
|
||||||
|
slots[key] = MemorySlot(
|
||||||
|
key=key,
|
||||||
|
value=value,
|
||||||
|
source=SlotSource.USER_CONFIRMED,
|
||||||
|
confidence=1.0,
|
||||||
|
updated_at=str(flow_instance.updated_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
return slots, missing_slots
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AC-IDMP-13] Failed to recall slots: {e}")
|
||||||
|
return {}, []
|
||||||
|
|
||||||
|
async def _recall_from_history(
|
||||||
|
self,
|
||||||
|
tenant_id: str,
|
||||||
|
session_id: str,
|
||||||
|
max_messages: int,
|
||||||
|
) -> list[str]:
|
||||||
|
"""从最近历史中提取最小回填信息。"""
|
||||||
|
try:
|
||||||
|
from app.models.entities import ChatMessage
|
||||||
|
from sqlmodel import col
|
||||||
|
|
||||||
|
stmt = (
|
||||||
|
select(ChatMessage)
|
||||||
|
.where(
|
||||||
|
ChatMessage.tenant_id == tenant_id,
|
||||||
|
ChatMessage.session_id == session_id,
|
||||||
|
)
|
||||||
|
.order_by(col(ChatMessage.created_at).desc())
|
||||||
|
.limit(max_messages)
|
||||||
|
)
|
||||||
|
result = await self._session.execute(stmt)
|
||||||
|
messages = result.scalars().all()
|
||||||
|
|
||||||
|
facts: list[str] = []
|
||||||
|
for msg in messages:
|
||||||
|
if msg.role == "user":
|
||||||
|
facts.append(f"用户说过: {msg.content[:50]}")
|
||||||
|
|
||||||
|
return facts
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AC-IDMP-13] Failed to recall from history: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def _extract_grade(self, content: str) -> str:
|
||||||
|
"""从内容中提取年级信息。"""
|
||||||
|
grades = ["初一", "初二", "初三", "高一", "高二", "高三"]
|
||||||
|
for grade in grades:
|
||||||
|
if grade in content:
|
||||||
|
return grade
|
||||||
|
return "未知年级"
|
||||||
|
|
||||||
|
def _extract_region(self, content: str) -> str:
|
||||||
|
"""从内容中提取地区信息。"""
|
||||||
|
regions = ["北京", "上海", "广州", "深圳", "杭州", "成都", "武汉", "南京"]
|
||||||
|
for region in regions:
|
||||||
|
if region in content:
|
||||||
|
return region
|
||||||
|
return "未知地区"
|
||||||
|
|
||||||
|
def _extract_purchase_info(self, content: str) -> str:
|
||||||
|
"""从内容中提取购买信息。"""
|
||||||
|
return f"购买记录: {content[:30]}..."
|
||||||
|
|
||||||
|
def _extract_order_info(self, content: str) -> str:
|
||||||
|
"""从内容中提取订单信息。"""
|
||||||
|
return f"订单信息: {content[:30]}..."
|
||||||
|
|
||||||
|
def merge_slots(
|
||||||
|
self,
|
||||||
|
existing_slots: dict[str, MemorySlot],
|
||||||
|
new_slots: dict[str, MemorySlot],
|
||||||
|
) -> dict[str, MemorySlot]:
|
||||||
|
"""
|
||||||
|
合并槽位,按优先级处理冲突。
|
||||||
|
|
||||||
|
优先级:user_confirmed > rule_extracted > llm_inferred > default
|
||||||
|
"""
|
||||||
|
merged = dict(existing_slots)
|
||||||
|
|
||||||
|
for key, new_slot in new_slots.items():
|
||||||
|
if key not in merged:
|
||||||
|
merged[key] = new_slot
|
||||||
|
else:
|
||||||
|
existing_slot = merged[key]
|
||||||
|
existing_priority = self.SLOT_PRIORITY.get(existing_slot.source, 0)
|
||||||
|
new_priority = self.SLOT_PRIORITY.get(new_slot.source, 0)
|
||||||
|
|
||||||
|
if new_priority > existing_priority:
|
||||||
|
merged[key] = new_slot
|
||||||
|
elif new_priority == existing_priority:
|
||||||
|
if new_slot.confidence > existing_slot.confidence:
|
||||||
|
merged[key] = new_slot
|
||||||
|
|
||||||
|
return merged
|
||||||
|
|
||||||
|
def create_trace(
|
||||||
|
self,
|
||||||
|
result: MemoryRecallResult,
|
||||||
|
tenant_id: str,
|
||||||
|
) -> ToolCallTrace:
|
||||||
|
"""创建工具调用追踪记录。"""
|
||||||
|
status = ToolCallStatus.OK
|
||||||
|
error_code = None
|
||||||
|
|
||||||
|
if result.fallback_reason_code:
|
||||||
|
if "TIMEOUT" in result.fallback_reason_code:
|
||||||
|
status = ToolCallStatus.TIMEOUT
|
||||||
|
else:
|
||||||
|
status = ToolCallStatus.ERROR
|
||||||
|
error_code = result.fallback_reason_code
|
||||||
|
|
||||||
|
return ToolCallTrace(
|
||||||
|
tool_name="memory_recall",
|
||||||
|
tool_type=ToolType.INTERNAL,
|
||||||
|
duration_ms=result.duration_ms,
|
||||||
|
status=status,
|
||||||
|
error_code=error_code,
|
||||||
|
args_digest=f"tenant={tenant_id}",
|
||||||
|
result_digest=f"profile={len(result.profile)}, facts={len(result.facts)}, slots={len(result.slots)}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_memory_recall_tool(
|
||||||
|
registry: Any,
|
||||||
|
session: AsyncSession,
|
||||||
|
timeout_governor: TimeoutGovernor | None = None,
|
||||||
|
config: MemoryRecallConfig | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""
|
||||||
|
[AC-IDMP-13] 注册 memory_recall 工具到 ToolRegistry。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
registry: ToolRegistry 实例
|
||||||
|
session: 数据库会话
|
||||||
|
timeout_governor: 超时治理器
|
||||||
|
config: 工具配置
|
||||||
|
"""
|
||||||
|
cfg = config or MemoryRecallConfig()
|
||||||
|
|
||||||
|
async def memory_recall_handler(
|
||||||
|
tenant_id: str,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
recall_scope: list[str] | None = None,
|
||||||
|
max_recent_messages: int | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""memory_recall 工具处理器。"""
|
||||||
|
tool = MemoryRecallTool(
|
||||||
|
session=session,
|
||||||
|
timeout_governor=timeout_governor,
|
||||||
|
config=cfg,
|
||||||
|
)
|
||||||
|
result = await tool.execute(
|
||||||
|
tenant_id=tenant_id,
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
recall_scope=recall_scope,
|
||||||
|
max_recent_messages=max_recent_messages,
|
||||||
|
)
|
||||||
|
return result.model_dump()
|
||||||
|
|
||||||
|
registry.register(
|
||||||
|
name="memory_recall",
|
||||||
|
description="[AC-IDMP-13] 记忆召回工具,读取用户可用记忆包(profile/facts/preferences/summary/slots)",
|
||||||
|
handler=memory_recall_handler,
|
||||||
|
tool_type=ToolType.INTERNAL,
|
||||||
|
version="1.0.0",
|
||||||
|
auth_required=False,
|
||||||
|
timeout_ms=min(cfg.timeout_ms, 1000),
|
||||||
|
enabled=True,
|
||||||
|
metadata={
|
||||||
|
"ac_ids": ["AC-IDMP-13"],
|
||||||
|
"recall_scope": cfg.default_recall_scope,
|
||||||
|
"max_recent_messages": cfg.max_recent_messages,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info("[AC-IDMP-13] memory_recall tool registered to registry")
|
||||||
Loading…
Reference in New Issue