[AC-AGENT-ENHANCE] feat(mid): 增强 Agent 编排器和工具
- 优化 agent_orchestrator 的系统提示词指导 - 改进 kb_scene 参数的自动注入逻辑 - 增强 kb_search_dynamic_tool 的元数据处理 - 优化 memory_recall_tool 的记忆召回逻辑 - 更新 memory_adapter 的用户记忆模型
This commit is contained in:
parent
e45396e1e4
commit
6fec2a755a
|
|
@ -35,7 +35,7 @@ from app.models.mid.schemas import (
|
|||
from app.services.llm.base import ToolDefinition
|
||||
from app.services.mid.tool_guide_registry import ToolGuideRegistry, get_tool_guide_registry
|
||||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
from app.services.mid.tool_converter import convert_tools_to_llm_format, build_tool_result_message
|
||||
from app.services.mid.tool_converter import convert_tool_to_llm_format, convert_tools_to_llm_format, build_tool_result_message
|
||||
from app.services.prompt.template_service import PromptTemplateService
|
||||
from app.services.prompt.variable_resolver import VariableResolver
|
||||
|
||||
|
|
@ -482,27 +482,35 @@ class AgentOrchestrator:
|
|||
|
||||
**步骤3:调用 kb_search_dynamic 进行搜索**
|
||||
- 使用步骤1获取的元数据字段构造 context 参数
|
||||
- scene 参数必须从元数据字段的 kb_scene 常见值中选择,不要硬编码
|
||||
- scene 参数会自动注入到 context.kb_scene,无需手动在 context 中设置 kb_scene
|
||||
- scene 参数应从元数据字段的 kb_scene 常见值中选择
|
||||
|
||||
**kb_scene 自动注入说明:**
|
||||
- 系统会自动将 scene 参数值注入到 context.kb_scene 字段
|
||||
- AI 只需在 context 中设置其他过滤字段(如 grade、subject)
|
||||
- 不要在 context 中重复设置 kb_scene,系统会自动处理
|
||||
|
||||
**示例流程:**
|
||||
1. 调用 `list_document_metadata_fields` 获取字段信息
|
||||
2. 根据返回结果,发现可用字段:grade(年级)、subject(学科)、kb_scene(场景)
|
||||
3. 分析用户问题"三年级语文怎么学",确定过滤条件:grade="三年级", subject="语文"
|
||||
4. 从 kb_scene 的常见值中选择合适的 scene(如"学习方案")
|
||||
5. 调用 `kb_search_dynamic`,传入构造好的 context 和 scene
|
||||
5. 调用 `kb_search_dynamic`,传入 scene="学习方案",context={"grade": "三年级", "subject": "语文"}
|
||||
6. 系统自动将 scene 注入到 context.kb_scene
|
||||
|
||||
## 注意事项
|
||||
- **严禁**在调用 kb_search_dynamic 之前不调用 list_document_metadata_fields。
|
||||
- **不要**在 context 中手动设置 kb_scene,系统会自动从 scene 参数注入。
|
||||
"""
|
||||
|
||||
if not self._template_service or not self._tenant_id:
|
||||
return default_prompt
|
||||
|
||||
try:
|
||||
from app.core.database import get_session
|
||||
from app.core.database import async_session_maker
|
||||
from app.core.prompts import SYSTEM_PROMPT
|
||||
|
||||
async with get_session() as session:
|
||||
async with async_session_maker() as session:
|
||||
template_service = PromptTemplateService(session)
|
||||
|
||||
base_prompt = await template_service.get_published_template(
|
||||
|
|
@ -511,6 +519,15 @@ class AgentOrchestrator:
|
|||
resolver=self._variable_resolver,
|
||||
)
|
||||
|
||||
if not base_prompt or base_prompt == SYSTEM_PROMPT:
|
||||
base_prompt = await template_service.get_published_template(
|
||||
tenant_id=self._tenant_id,
|
||||
scene="agent_react",
|
||||
resolver=self._variable_resolver,
|
||||
)
|
||||
if base_prompt and base_prompt != SYSTEM_PROMPT:
|
||||
logger.info("[AC-MARH-07] Using agent_react template for Function Calling mode")
|
||||
|
||||
if not base_prompt or base_prompt == SYSTEM_PROMPT:
|
||||
base_prompt = await template_service.get_published_template(
|
||||
tenant_id=self._tenant_id,
|
||||
|
|
@ -519,7 +536,7 @@ class AgentOrchestrator:
|
|||
)
|
||||
|
||||
if not base_prompt or base_prompt == SYSTEM_PROMPT:
|
||||
logger.info("[AC-MARH-07] No published template found for agent_fc or default, using default prompt")
|
||||
logger.info("[AC-MARH-07] No published template found for agent_fc/agent_react/default, using default prompt")
|
||||
return default_prompt
|
||||
|
||||
agent_protocol = """
|
||||
|
|
@ -545,10 +562,15 @@ class AgentOrchestrator:
|
|||
|
||||
**步骤3:调用 kb_search_dynamic 进行搜索**
|
||||
- 使用步骤1获取的元数据字段构造 context 参数
|
||||
- scene 参数必须从元数据字段的 kb_scene 常见值中选择,不要硬编码
|
||||
- scene 参数会自动注入到 context.kb_scene,无需手动在 context 中设置 kb_scene
|
||||
|
||||
**kb_scene 自动注入说明:**
|
||||
- 系统会自动将 scene 参数值注入到 context.kb_scene 字段
|
||||
- AI 只需在 context 中设置其他过滤字段(如 grade、subject)
|
||||
|
||||
## 注意事项
|
||||
- **严禁**在调用 kb_search_dynamic 之前不调用 list_document_metadata_fields。
|
||||
- **不要**在 context 中手动设置 kb_scene,系统会自动从 scene 参数注入。
|
||||
"""
|
||||
|
||||
final_prompt = base_prompt + agent_protocol
|
||||
|
|
|
|||
|
|
@ -127,6 +127,8 @@ class KbSearchDynamicTool:
|
|||
"知识库动态检索工具。"
|
||||
"根据租户配置的元数据字段定义,动态构建检索过滤器。"
|
||||
"支持必填字段检测和可观测降级。"
|
||||
"重要:context 参数中应包含 kb_scene 字段用于场景过滤,"
|
||||
"系统会自动从外部请求的 scene 参数注入到 context.kb_scene。"
|
||||
)
|
||||
|
||||
def get_tool_schema(self) -> dict[str, Any]:
|
||||
|
|
@ -146,7 +148,7 @@ class KbSearchDynamicTool:
|
|||
},
|
||||
"scene": {
|
||||
"type": "string",
|
||||
"description": "场景标识,如 'open_consult', 'intent_match'",
|
||||
"description": "场景标识(如 'open_consult', 'intent_match'),系统会自动将其注入到 context.kb_scene 作为过滤条件",
|
||||
},
|
||||
"top_k": {
|
||||
"type": "integer",
|
||||
|
|
@ -155,7 +157,7 @@ class KbSearchDynamicTool:
|
|||
},
|
||||
"context": {
|
||||
"type": "object",
|
||||
"description": "上下文信息,包含动态过滤字段值",
|
||||
"description": "上下文信息,包含动态过滤字段值。重要字段:kb_scene(场景过滤,由系统自动从 scene 参数注入)、grade(年级)、subject(学科)等",
|
||||
},
|
||||
},
|
||||
"required": ["query"],
|
||||
|
|
@ -299,13 +301,14 @@ class KbSearchDynamicTool:
|
|||
[AC-MARH-05] 执行 KB 动态检索。
|
||||
[AC-MRS-SLOT-META-02] 支持槽位状态聚合和过滤构建优先级
|
||||
[Step-KB-Binding] 支持步骤级别的知识库约束
|
||||
[KB-SCENE-INJECT] 自动将 scene 参数注入到 context.kb_scene
|
||||
|
||||
Args:
|
||||
query: 检索查询
|
||||
tenant_id: 租户 ID
|
||||
scene: 场景标识(默认值,会被 context 中的 scene 覆盖)
|
||||
scene: 场景标识(会自动注入到 context.kb_scene)
|
||||
top_k: 返回数量
|
||||
context: 上下文(包含动态过滤值,包括 scene)
|
||||
context: 上下文(包含动态过滤值)
|
||||
slot_state: 预聚合的槽位状态(可选,优先使用)
|
||||
step_kb_config: 步骤级别的知识库配置(可选)
|
||||
slot_policy: 槽位策略(flow_strict=流程严格模式,agent_relaxed=通用问答宽松模式)
|
||||
|
|
@ -326,6 +329,25 @@ class KbSearchDynamicTool:
|
|||
effective_context = dict(context) if context else {}
|
||||
effective_scene = effective_context.get("scene", scene)
|
||||
|
||||
logger.info(
|
||||
f"[KB-DEBUG] execute() called with: scene='{scene}', context={context}, "
|
||||
f"effective_context_keys={list(effective_context.keys())}"
|
||||
)
|
||||
|
||||
# [KB-SCENE-INJECT] 自动将 scene 参数注入到 context.kb_scene
|
||||
# 优先级:context.kb_scene > context.scene > scene 参数
|
||||
if "kb_scene" not in effective_context and scene:
|
||||
effective_context["kb_scene"] = scene
|
||||
logger.info(
|
||||
f"[KB-SCENE-INJECT] Injected scene='{scene}' into context.kb_scene, "
|
||||
f"effective_context now={effective_context}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"[KB-SCENE-INJECT] Skipped injection: kb_scene in context={('kb_scene' in effective_context)}, "
|
||||
f"scene is empty={not scene}"
|
||||
)
|
||||
|
||||
# [Step-KB-Binding] 记录步骤知识库约束
|
||||
step_kb_binding_info: dict[str, Any] = {}
|
||||
if step_kb_config:
|
||||
|
|
@ -445,8 +467,8 @@ class KbSearchDynamicTool:
|
|||
status=ToolCallStatus.OK,
|
||||
args_digest=f"query={query[:50]}, scene={effective_scene}",
|
||||
result_digest=f"hits={len(hits)}",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
result={"hits_count": len(hits), "kb_hit": kb_hit},
|
||||
arguments={"query": query, "scene": effective_scene, "context": effective_context},
|
||||
result={"hits_count": len(hits), "kb_hit": kb_hit, "applied_filter": metadata_filter},
|
||||
)
|
||||
|
||||
logger.info(
|
||||
|
|
@ -482,7 +504,7 @@ class KbSearchDynamicTool:
|
|||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.TIMEOUT,
|
||||
error_code="KB_TIMEOUT",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
arguments={"query": query, "scene": effective_scene, "context": effective_context},
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
|
|
@ -509,7 +531,7 @@ class KbSearchDynamicTool:
|
|||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="KB_ERROR",
|
||||
arguments={"query": query, "scene": effective_scene, "context": context},
|
||||
arguments={"query": query, "scene": effective_scene, "context": effective_context},
|
||||
)
|
||||
|
||||
return KbSearchDynamicResult(
|
||||
|
|
@ -905,7 +927,7 @@ def register_kb_search_dynamic_tool(
|
|||
|
||||
registry.register(
|
||||
name=KB_SEARCH_DYNAMIC_TOOL_NAME,
|
||||
description="知识库动态检索工具,支持元数据驱动过滤",
|
||||
description="知识库动态检索工具,支持元数据驱动过滤。系统会自动将 scene 参数注入到 context.kb_scene 进行场景过滤。",
|
||||
handler=handler,
|
||||
tool_type=RegistryToolType.INTERNAL,
|
||||
version="1.0.0",
|
||||
|
|
@ -922,9 +944,12 @@ def register_kb_search_dynamic_tool(
|
|||
"properties": {
|
||||
"query": {"type": "string", "description": "检索查询文本"},
|
||||
"tenant_id": {"type": "string", "description": "租户 ID"},
|
||||
"scene": {"type": "string", "description": "场景标识,如 open_consult"},
|
||||
"scene": {"type": "string", "description": "场景标识,系统自动注入到 context.kb_scene"},
|
||||
"top_k": {"type": "integer", "description": "返回条数"},
|
||||
"context": {"type": "object", "description": "上下文,用于动态过滤字段"}
|
||||
"context": {
|
||||
"type": "object",
|
||||
"description": "过滤条件上下文。kb_scene 由系统自动注入,其他字段如 grade、subject 根据用户意图填写"
|
||||
}
|
||||
},
|
||||
"required": ["query", "tenant_id"]
|
||||
},
|
||||
|
|
@ -933,9 +958,10 @@ def register_kb_search_dynamic_tool(
|
|||
"tenant_id": "default",
|
||||
"scene": "open_consult",
|
||||
"top_k": 5,
|
||||
"context": {"product_line": "vip_course", "region": "beijing"}
|
||||
"context": {"grade": "初二", "subject": "数学"}
|
||||
},
|
||||
"result_interpretation": "success=true 且 hits 非空表示命中知识;missing_required_slots 非空时应先向用户补采信息。"
|
||||
"result_interpretation": "success=true 且 hits 非空表示命中知识;missing_required_slots 非空时应先向用户补采信息。",
|
||||
"kb_scene_injection": "系统会自动将 scene 参数值注入到 context.kb_scene 字段,用于知识库场景过滤。AI 无需手动在 context 中设置 kb_scene。"
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ Reference:
|
|||
"""
|
||||
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
|
|
@ -17,6 +19,7 @@ from typing import Any, Callable
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models.entities import UserMemory as UserMemoryEntity
|
||||
from app.models.mid.memory import (
|
||||
MemoryFact,
|
||||
MemoryProfile,
|
||||
|
|
@ -93,14 +96,6 @@ class MemoryAdapter:
|
|||
|
||||
在响应前执行,注入基础属性、事实记忆与偏好记忆。
|
||||
失败时返回空记忆,不阻断主链路。
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
tenant_id: 租户ID(可选)
|
||||
|
||||
Returns:
|
||||
RecallResponse: 包含 profile/facts/preferences 的响应
|
||||
"""
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
|
|
@ -126,9 +121,6 @@ class MemoryAdapter:
|
|||
session_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> RecallResponse:
|
||||
"""
|
||||
内部召回实现
|
||||
"""
|
||||
profile = await self._recall_profile(user_id, tenant_id)
|
||||
facts = await self._recall_facts(user_id, tenant_id)
|
||||
preferences = await self._recall_preferences(user_id, tenant_id)
|
||||
|
|
@ -152,7 +144,6 @@ class MemoryAdapter:
|
|||
user_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> MemoryProfile | None:
|
||||
"""召回用户基础属性"""
|
||||
return MemoryProfile(
|
||||
grade="初一",
|
||||
region="北京",
|
||||
|
|
@ -165,7 +156,6 @@ class MemoryAdapter:
|
|||
user_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> list[MemoryFact]:
|
||||
"""召回用户事实记忆"""
|
||||
return [
|
||||
MemoryFact(content="已购课程:数学思维训练营", source="order", confidence=1.0),
|
||||
MemoryFact(content="学习目标:提高数学成绩", source="profile", confidence=0.9),
|
||||
|
|
@ -177,7 +167,6 @@ class MemoryAdapter:
|
|||
user_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> MemoryPreferences | None:
|
||||
"""召回用户偏好"""
|
||||
return MemoryPreferences(
|
||||
tone="friendly",
|
||||
focus_subjects=["数学", "物理"],
|
||||
|
|
@ -189,8 +178,16 @@ class MemoryAdapter:
|
|||
user_id: str,
|
||||
tenant_id: str | None,
|
||||
) -> str | None:
|
||||
"""召回最近会话摘要"""
|
||||
return "上次讨论了数学学习计划,用户对课程安排比较满意"
|
||||
if not tenant_id:
|
||||
return None
|
||||
|
||||
stmt = select(UserMemoryEntity).where(
|
||||
UserMemoryEntity.tenant_id == tenant_id,
|
||||
UserMemoryEntity.user_id == user_id,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
return record.summary if record else None
|
||||
|
||||
async def update(
|
||||
self,
|
||||
|
|
@ -200,22 +197,6 @@ class MemoryAdapter:
|
|||
summary: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
[AC-IDMP-14] 异步更新用户记忆
|
||||
|
||||
在对话完成后异步执行,不阻塞主响应。
|
||||
包含会话摘要的回写。
|
||||
|
||||
Args:
|
||||
user_id: 用户ID
|
||||
session_id: 会话ID
|
||||
messages: 本轮对话消息
|
||||
summary: 会话摘要(可选)
|
||||
tenant_id: 租户ID
|
||||
|
||||
Returns:
|
||||
bool: 是否成功提交更新任务
|
||||
"""
|
||||
request = UpdateRequest(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
|
|
@ -242,9 +223,6 @@ class MemoryAdapter:
|
|||
request: UpdateRequest,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
内部更新实现
|
||||
"""
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._do_update(request, tenant_id),
|
||||
|
|
@ -270,10 +248,18 @@ class MemoryAdapter:
|
|||
request: UpdateRequest,
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""
|
||||
执行实际的记忆更新
|
||||
"""
|
||||
if request.summary:
|
||||
summary_payload = self._parse_summary_payload(request.summary)
|
||||
if summary_payload:
|
||||
await self._save_summary(
|
||||
request.user_id,
|
||||
summary_payload.get("summary", ""),
|
||||
tenant_id,
|
||||
facts=summary_payload.get("facts"),
|
||||
preferences=summary_payload.get("preferences"),
|
||||
open_issues=summary_payload.get("open_issues"),
|
||||
)
|
||||
else:
|
||||
await self._save_summary(request.user_id, request.summary, tenant_id)
|
||||
|
||||
await self._extract_and_save_facts(
|
||||
|
|
@ -285,9 +271,41 @@ class MemoryAdapter:
|
|||
user_id: str,
|
||||
summary: str,
|
||||
tenant_id: str | None,
|
||||
facts: list[str] | None = None,
|
||||
preferences: dict[str, Any] | None = None,
|
||||
open_issues: list[str] | None = None,
|
||||
) -> None:
|
||||
"""保存会话摘要"""
|
||||
pass
|
||||
if not tenant_id:
|
||||
logger.warning("[AC-IDMP-14] Missing tenant_id when saving summary")
|
||||
return
|
||||
|
||||
stmt = select(UserMemoryEntity).where(
|
||||
UserMemoryEntity.tenant_id == tenant_id,
|
||||
UserMemoryEntity.user_id == user_id,
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
record = result.scalar_one_or_none()
|
||||
|
||||
if record:
|
||||
record.summary = summary
|
||||
record.facts = facts or record.facts
|
||||
record.preferences = preferences or record.preferences
|
||||
record.open_issues = open_issues or record.open_issues
|
||||
record.summary_version = (record.summary_version or 0) + 1
|
||||
record.updated_at = datetime.utcnow()
|
||||
else:
|
||||
record = UserMemoryEntity(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
summary=summary,
|
||||
facts=facts,
|
||||
preferences=preferences,
|
||||
open_issues=open_issues,
|
||||
summary_version=1,
|
||||
)
|
||||
self._session.add(record)
|
||||
|
||||
await self._session.flush()
|
||||
|
||||
async def _extract_and_save_facts(
|
||||
self,
|
||||
|
|
@ -295,8 +313,25 @@ class MemoryAdapter:
|
|||
messages: list[dict[str, Any]],
|
||||
tenant_id: str | None,
|
||||
) -> None:
|
||||
"""从消息中提取并保存事实"""
|
||||
pass
|
||||
if not tenant_id:
|
||||
return
|
||||
|
||||
for msg in messages:
|
||||
payload = msg.get("memory_payload") or msg.get("summary_payload")
|
||||
if not payload:
|
||||
continue
|
||||
parsed = self._parse_summary_payload(payload)
|
||||
if not parsed:
|
||||
continue
|
||||
await self._save_summary(
|
||||
user_id=user_id,
|
||||
summary=parsed.get("summary", ""),
|
||||
tenant_id=tenant_id,
|
||||
facts=parsed.get("facts"),
|
||||
preferences=parsed.get("preferences"),
|
||||
open_issues=parsed.get("open_issues"),
|
||||
)
|
||||
break
|
||||
|
||||
async def update_with_summary_generation(
|
||||
self,
|
||||
|
|
@ -305,41 +340,92 @@ class MemoryAdapter:
|
|||
messages: list[dict[str, Any]],
|
||||
tenant_id: str | None = None,
|
||||
summary_generator: Callable | None = None,
|
||||
recent_turns: int = 8,
|
||||
) -> bool:
|
||||
"""
|
||||
[AC-IDMP-14] 带摘要生成的记忆更新
|
||||
request = UpdateRequest(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
messages=messages,
|
||||
summary=None,
|
||||
)
|
||||
|
||||
如果未提供摘要,会尝试生成摘要后回写
|
||||
"""
|
||||
task = asyncio.create_task(
|
||||
self._update_with_generation_internal(
|
||||
request,
|
||||
tenant_id,
|
||||
summary_generator,
|
||||
recent_turns,
|
||||
),
|
||||
name=f"memory_update_gen_{user_id}_{session_id}",
|
||||
)
|
||||
self._pending_updates.append(task)
|
||||
task.add_done_callback(lambda t: self._pending_updates.remove(t))
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-14] Memory update (with summary) scheduled for user={user_id}, "
|
||||
f"session={session_id}, messages_count={len(messages)}"
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
async def _update_with_generation_internal(
    self,
    request: UpdateRequest,
    tenant_id: str | None,
    summary_generator: Callable | None,
    recent_turns: int,
) -> None:
    """Run the summary-generating memory update under the update timeout.

    Awaits ``_do_update_with_generation`` through ``asyncio.wait_for`` with a
    limit of ``self._update_timeout_ms / 1000`` seconds. The outcome is only
    logged: info on success, warning on timeout, error for any other
    ``Exception``. Nothing derived from ``Exception`` is re-raised.

    Args:
        request: Update request carrying user id, session id and messages.
        tenant_id: Tenant identifier, forwarded unchanged.
        summary_generator: Optional summary callable, forwarded unchanged.
        recent_turns: Trim size, forwarded unchanged.
    """
    try:
        # Build the coroutine before reading the timeout, inside the ``try``,
        # so any failure here is reported by the generic handler below.
        pending = self._do_update_with_generation(
            request,
            tenant_id,
            summary_generator,
            recent_turns,
        )
        limit_seconds = self._update_timeout_ms / 1000
        await asyncio.wait_for(pending, timeout=limit_seconds)
        logger.info(
            f"[AC-IDMP-14] Memory updated (with summary) for user={request.user_id}, "
            f"session={request.session_id}"
        )
    except asyncio.TimeoutError:
        logger.warning(
            f"[AC-IDMP-14] Memory update (with summary) timeout for user={request.user_id}, "
            f"session={request.session_id}"
        )
    except Exception as e:
        logger.error(
            f"[AC-IDMP-14] Memory update (with summary) failed for user={request.user_id}, "
            f"session={request.session_id}, error={e}"
        )
|
||||
|
||||
async def _do_update_with_generation(
|
||||
self,
|
||||
request: UpdateRequest,
|
||||
tenant_id: str | None,
|
||||
summary_generator: Callable | None,
|
||||
recent_turns: int,
|
||||
) -> None:
|
||||
summary = None
|
||||
if summary_generator:
|
||||
try:
|
||||
summary = await summary_generator(messages)
|
||||
old_summary = await self._load_latest_summary(request.user_id, tenant_id)
|
||||
recent_messages = self._trim_recent_messages(request.messages, recent_turns)
|
||||
summary = await self._call_summary_generator(
|
||||
summary_generator,
|
||||
recent_messages,
|
||||
old_summary,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"[AC-IDMP-14] Summary generation failed: {e}"
|
||||
)
|
||||
|
||||
return await self.update(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
messages=messages,
|
||||
summary=summary,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
request.summary = summary
|
||||
await self._do_update(request, tenant_id)
|
||||
|
||||
async def wait_pending_updates(self, timeout: float = 5.0) -> int:
|
||||
"""
|
||||
等待所有待处理的更新任务完成
|
||||
|
||||
用于优雅关闭时确保所有更新完成
|
||||
|
||||
Args:
|
||||
timeout: 最大等待时间(秒)
|
||||
|
||||
Returns:
|
||||
int: 完成的任务数
|
||||
"""
|
||||
if not self._pending_updates:
|
||||
return 0
|
||||
|
||||
|
|
@ -353,3 +439,62 @@ class MemoryAdapter:
|
|||
except Exception as e:
|
||||
logger.error(f"[AC-IDMP-14] Error waiting for pending updates: {e}")
|
||||
return 0
|
||||
|
||||
async def _load_latest_summary(
    self,
    user_id: str,
    tenant_id: str | None,
) -> str | None:
    """Return the stored summary for a tenant/user pair, if any.

    Args:
        user_id: User whose memory row is looked up.
        tenant_id: Tenant scope; when falsy no query is issued.

    Returns:
        The ``summary`` attribute of the matching ``UserMemoryEntity`` row,
        or ``None`` when ``tenant_id`` is missing or no row matches.
    """
    if tenant_id:
        query = select(UserMemoryEntity).where(
            UserMemoryEntity.tenant_id == tenant_id,
            UserMemoryEntity.user_id == user_id,
        )
        rows = await self._session.execute(query)
        # NOTE(review): scalar_one_or_none() raises when several rows match;
        # presumably (tenant_id, user_id) is unique - confirm against the model.
        existing = rows.scalar_one_or_none()
        if existing:
            return existing.summary
    return None
|
||||
|
||||
def _trim_recent_messages(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
recent_turns: int,
|
||||
) -> list[dict[str, Any]]:
|
||||
if recent_turns <= 0:
|
||||
return []
|
||||
return messages[-recent_turns:]
|
||||
|
||||
async def _call_summary_generator(
|
||||
self,
|
||||
summary_generator: Callable,
|
||||
recent_messages: list[dict[str, Any]],
|
||||
old_summary: str | None,
|
||||
) -> str | None:
|
||||
try:
|
||||
if len(inspect.signature(summary_generator).parameters) >= 2:
|
||||
return await summary_generator(recent_messages, old_summary)
|
||||
except Exception:
|
||||
return await summary_generator(recent_messages)
|
||||
|
||||
return await summary_generator(recent_messages)
|
||||
|
||||
def _parse_summary_payload(
|
||||
self,
|
||||
payload: Any,
|
||||
) -> dict[str, Any] | None:
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
if isinstance(payload, dict):
|
||||
return payload
|
||||
|
||||
if isinstance(payload, str):
|
||||
try:
|
||||
parsed = json.loads(payload)
|
||||
if isinstance(parsed, dict):
|
||||
return parsed
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return None
|
||||
|
|
|
|||
|
|
@ -334,24 +334,20 @@ class MemoryRecallTool:
|
|||
) -> str | None:
|
||||
"""召回最近会话摘要。"""
|
||||
try:
|
||||
from app.models.entities import MidAuditLog
|
||||
from sqlmodel import col
|
||||
from app.models.entities import UserMemory
|
||||
|
||||
stmt = (
|
||||
select(MidAuditLog)
|
||||
select(UserMemory)
|
||||
.where(
|
||||
MidAuditLog.tenant_id == tenant_id,
|
||||
UserMemory.tenant_id == tenant_id,
|
||||
UserMemory.user_id == user_id,
|
||||
)
|
||||
.order_by(col(MidAuditLog.created_at).desc())
|
||||
.limit(1)
|
||||
)
|
||||
result = await self._session.execute(stmt)
|
||||
audit = result.scalar_one_or_none()
|
||||
memory = result.scalar_one_or_none()
|
||||
|
||||
if audit:
|
||||
return f"上次会话模式: {audit.mode}"
|
||||
|
||||
return None
|
||||
return memory.summary if memory else None
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(f"[AC-IDMP-13] Failed to recall last summary: {e}")
|
||||
|
|
|
|||
Loading…
Reference in New Issue