183 lines
5.3 KiB
Python
183 lines
5.3 KiB
Python
|
|
"""
|
|||
|
|
Memory models for Mid Platform.
|
|||
|
|
[AC-IDMP-13] 记忆召回数据模型
|
|||
|
|
[AC-IDMP-14] 记忆更新数据模型
|
|||
|
|
|
|||
|
|
Reference: spec/intent-driven-mid-platform/openapi.deps.yaml
|
|||
|
|
"""
|
|||
|
|
|
|||
|
|
from dataclasses import dataclass, field
|
|||
|
|
from datetime import datetime
|
|||
|
|
from typing import Any
|
|||
|
|
from uuid import UUID
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class MemoryProfile:
|
|||
|
|
"""
|
|||
|
|
[AC-IDMP-13] 用户基础属性记忆
|
|||
|
|
包含年级、地区、渠道等基础信息
|
|||
|
|
"""
|
|||
|
|
grade: str | None = None
|
|||
|
|
region: str | None = None
|
|||
|
|
channel: str | None = None
|
|||
|
|
vip_level: str | None = None
|
|||
|
|
registration_date: datetime | None = None
|
|||
|
|
extra: dict[str, Any] = field(default_factory=dict)
|
|||
|
|
|
|||
|
|
def to_dict(self) -> dict[str, Any]:
|
|||
|
|
result = {}
|
|||
|
|
if self.grade:
|
|||
|
|
result["grade"] = self.grade
|
|||
|
|
if self.region:
|
|||
|
|
result["region"] = self.region
|
|||
|
|
if self.channel:
|
|||
|
|
result["channel"] = self.channel
|
|||
|
|
if self.vip_level:
|
|||
|
|
result["vip_level"] = self.vip_level
|
|||
|
|
if self.registration_date:
|
|||
|
|
result["registration_date"] = self.registration_date.isoformat()
|
|||
|
|
if self.extra:
|
|||
|
|
result.update(self.extra)
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class MemoryFact:
|
|||
|
|
"""
|
|||
|
|
[AC-IDMP-13] 事实型记忆
|
|||
|
|
包含已购课程、学习结论等客观事实
|
|||
|
|
"""
|
|||
|
|
content: str
|
|||
|
|
source: str | None = None
|
|||
|
|
confidence: float | None = None
|
|||
|
|
created_at: datetime | None = None
|
|||
|
|
expires_at: datetime | None = None
|
|||
|
|
|
|||
|
|
def to_string(self) -> str:
|
|||
|
|
return self.content
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class MemoryPreferences:
|
|||
|
|
"""
|
|||
|
|
[AC-IDMP-13] 偏好记忆
|
|||
|
|
包含语气偏好、关注科目等用户偏好
|
|||
|
|
"""
|
|||
|
|
tone: str | None = None
|
|||
|
|
focus_subjects: list[str] = field(default_factory=list)
|
|||
|
|
communication_style: str | None = None
|
|||
|
|
preferred_time: str | None = None
|
|||
|
|
extra: dict[str, Any] = field(default_factory=dict)
|
|||
|
|
|
|||
|
|
def to_dict(self) -> dict[str, Any]:
|
|||
|
|
result = {}
|
|||
|
|
if self.tone:
|
|||
|
|
result["tone"] = self.tone
|
|||
|
|
if self.focus_subjects:
|
|||
|
|
result["focus_subjects"] = self.focus_subjects
|
|||
|
|
if self.communication_style:
|
|||
|
|
result["communication_style"] = self.communication_style
|
|||
|
|
if self.preferred_time:
|
|||
|
|
result["preferred_time"] = self.preferred_time
|
|||
|
|
if self.extra:
|
|||
|
|
result.update(self.extra)
|
|||
|
|
return result
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class RecallRequest:
    """
    [AC-IDMP-13] Memory recall request.

    Reference: openapi.deps.yaml - RecallRequest
    """

    user_id: str
    session_id: str

    def to_dict(self) -> dict[str, Any]:
        """Return the request as a dict with ``user_id`` and ``session_id``."""
        payload: dict[str, Any] = dict(
            user_id=self.user_id,
            session_id=self.session_id,
        )
        return payload
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
class RecallResponse:
    """
    [AC-IDMP-13] Memory recall response.

    Reference: openapi.deps.yaml - RecallResponse

    Every section is optional; missing or empty sections are left out of both
    the serialized dict and the prompt context.
    """

    profile: MemoryProfile | None = None
    facts: list[MemoryFact] = field(default_factory=list)
    preferences: MemoryPreferences | None = None
    last_summary: str | None = None

    def to_dict(self) -> dict[str, Any]:
        """
        Serialize the response into a plain dict.

        ``profile`` and ``preferences`` are serialized with their own
        ``to_dict()``; each fact is reduced to its ``to_string()`` form.
        Falsy sections are omitted.
        """
        result: dict[str, Any] = {}
        if self.profile:
            result["profile"] = self.profile.to_dict()
        if self.facts:
            result["facts"] = [fact.to_string() for fact in self.facts]
        if self.preferences:
            result["preferences"] = self.preferences.to_dict()
        if self.last_summary:
            result["last_summary"] = self.last_summary
        return result

    def get_context_for_prompt(self, max_facts: int = 5) -> str:
        """
        Build the context string that is injected into a prompt.

        Args:
            max_facts: Maximum number of facts to include, taken from the
                start of ``facts``. Values below 1 omit the facts section.
                Defaults to 5, the limit that used to be hard-coded.

        Returns:
            The non-empty sections joined by newlines, or an empty string
            when there is nothing to report.
        """
        parts: list[str] = []

        if self.profile:
            # Only grade, region and VIP level are surfaced in the prompt.
            profile_parts = [
                f"{label}: {value}"
                for label, value in (
                    ("年级", self.profile.grade),
                    ("地区", self.profile.region),
                    ("会员等级", self.profile.vip_level),
                )
                if value
            ]
            if profile_parts:
                parts.append("【用户属性】" + "、".join(profile_parts))

        # Clamp so a negative limit cannot turn into a "drop the tail" slice.
        selected_facts = self.facts[: max(max_facts, 0)] if self.facts else []
        if selected_facts:
            # Use to_string() so the prompt and to_dict() render facts alike.
            fact_strs = [fact.to_string() for fact in selected_facts]
            parts.append("【已知事实】" + ";".join(fact_strs))

        if self.preferences:
            pref_parts = []
            if self.preferences.tone:
                pref_parts.append(f"语气偏好: {self.preferences.tone}")
            if self.preferences.focus_subjects:
                subjects = ", ".join(self.preferences.focus_subjects)
                pref_parts.append(f"关注科目: {subjects}")
            if pref_parts:
                parts.append("【用户偏好】" + "、".join(pref_parts))

        if self.last_summary:
            parts.append(f"【上次会话摘要】{self.last_summary}")

        # Joining an empty list already yields "", so no special case is needed.
        return "\n".join(parts)
|
|||
|
|
|
|||
|
|
|
|||
|
|
@dataclass
|
|||
|
|
class UpdateRequest:
|
|||
|
|
"""
|
|||
|
|
[AC-IDMP-14] 记忆更新请求
|
|||
|
|
Reference: openapi.deps.yaml - UpdateRequest
|
|||
|
|
"""
|
|||
|
|
user_id: str
|
|||
|
|
session_id: str
|
|||
|
|
messages: list[dict[str, Any]]
|
|||
|
|
summary: str | None = None
|
|||
|
|
|
|||
|
|
def to_dict(self) -> dict[str, Any]:
|
|||
|
|
result = {
|
|||
|
|
"user_id": self.user_id,
|
|||
|
|
"session_id": self.session_id,
|
|||
|
|
"messages": self.messages,
|
|||
|
|
}
|
|||
|
|
if self.summary:
|
|||
|
|
result["summary"] = self.summary
|
|||
|
|
return result
|