feat: enhance agent orchestrator with runtime hardening and tool governance [AC-MARH-01~12]
This commit is contained in:
parent
66902cd7c1
commit
d78b72ca93
|
|
@ -43,6 +43,8 @@ class ToolCallTrace:
|
|||
- error_code: 错误码
|
||||
- args_digest: 参数摘要(脱敏)
|
||||
- result_digest: 结果摘要
|
||||
- arguments: 完整参数
|
||||
- result: 完整结果
|
||||
"""
|
||||
tool_name: str
|
||||
duration_ms: int
|
||||
|
|
@ -53,6 +55,8 @@ class ToolCallTrace:
|
|||
error_code: str | None = None
|
||||
args_digest: str | None = None
|
||||
result_digest: str | None = None
|
||||
arguments: dict[str, Any] | None = None
|
||||
result: Any = None
|
||||
started_at: datetime = field(default_factory=datetime.utcnow)
|
||||
completed_at: datetime | None = None
|
||||
|
||||
|
|
@ -74,6 +78,10 @@ class ToolCallTrace:
|
|||
result["args_digest"] = self.args_digest
|
||||
if self.result_digest:
|
||||
result["result_digest"] = self.result_digest
|
||||
if self.arguments:
|
||||
result["arguments"] = self.arguments
|
||||
if self.result is not None:
|
||||
result["result"] = self.result
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -2,6 +2,10 @@
|
|||
Agent Orchestrator for Mid Platform.
|
||||
[AC-MARH-07] ReAct loop with iteration limit (3-5 iterations).
|
||||
|
||||
Supports two execution modes:
|
||||
1. ReAct (Text-based): Traditional Thought/Action/Observation loop
|
||||
2. Function Calling: Uses LLM's native function calling capability
|
||||
|
||||
ReAct Flow:
|
||||
1. Thought: Agent thinks about what to do
|
||||
2. Action: Agent decides to use a tool
|
||||
|
|
@ -16,6 +20,8 @@ import re
|
|||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
|
||||
from app.models.mid.schemas import (
|
||||
|
|
@ -26,7 +32,10 @@ from app.models.mid.schemas import (
|
|||
ToolType,
|
||||
TraceInfo,
|
||||
)
|
||||
from app.services.llm.base import ToolDefinition
|
||||
from app.services.mid.tool_guide_registry import ToolGuideRegistry, get_tool_guide_registry
|
||||
from app.services.mid.timeout_governor import TimeoutGovernor
|
||||
from app.services.mid.tool_converter import convert_tools_to_llm_format, build_tool_result_message
|
||||
from app.services.prompt.template_service import PromptTemplateService
|
||||
from app.services.prompt.variable_resolver import VariableResolver
|
||||
|
||||
|
|
@ -36,11 +45,17 @@ DEFAULT_MAX_ITERATIONS = 5
|
|||
MIN_ITERATIONS = 3
|
||||
|
||||
|
||||
class AgentMode(str, Enum):
    """Execution strategy used by the agent orchestrator.

    Members are ``str`` subclasses, so they compare equal to their raw
    string values.
    """

    # Text-based Thought / Action / Observation loop.
    REACT = "react"
    # Native LLM function calling.
    FUNCTION_CALLING = "function_calling"
|
||||
|
||||
|
||||
@dataclass
class ToolResult:
    """Outcome of a single tool execution.

    ``output`` is typed ``Any`` because the orchestrator stores dicts,
    strings, or wrapped values in it (see the usage-guide handling in the
    tool execution methods).
    """

    # True when the tool reported success.
    success: bool
    # Tool payload; declared exactly once (the earlier ``str | None``
    # declaration was dead, being overridden by this one).
    output: Any = None
    # Error text when the call failed.
    error: str | None = None
    # Wall-clock duration of the call in milliseconds.
    duration_ms: int = 0
|
||||
|
||||
|
|
@ -59,6 +74,7 @@ class AgentOrchestrator:
|
|||
|
||||
Features:
|
||||
- ReAct loop with max 5 iterations (min 3)
|
||||
- Function Calling mode for supported LLMs (OpenAI, DeepSeek, etc.)
|
||||
- Per-tool timeout (2s) and end-to-end timeout (8s)
|
||||
- Automatic fallback on iteration limit or timeout
|
||||
- Template-based prompt with variable injection
|
||||
|
|
@ -70,17 +86,74 @@ class AgentOrchestrator:
|
|||
timeout_governor: TimeoutGovernor | None = None,
|
||||
llm_client: Any = None,
|
||||
tool_registry: Any = None,
|
||||
guide_registry: ToolGuideRegistry | None = None,
|
||||
template_service: PromptTemplateService | None = None,
|
||||
variable_resolver: VariableResolver | None = None,
|
||||
tenant_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
scene: str | None = None,
|
||||
mode: AgentMode = AgentMode.FUNCTION_CALLING,
|
||||
):
|
||||
self._max_iterations = max(min(max_iterations, 5), MIN_ITERATIONS)
|
||||
self._timeout_governor = timeout_governor or TimeoutGovernor()
|
||||
self._llm_client = llm_client
|
||||
self._tool_registry = tool_registry
|
||||
self._guide_registry = guide_registry
|
||||
self._template_service = template_service
|
||||
self._variable_resolver = variable_resolver or VariableResolver()
|
||||
self._tenant_id = tenant_id
|
||||
self._user_id = user_id
|
||||
self._session_id = session_id
|
||||
self._scene = scene
|
||||
self._mode = mode
|
||||
self._tools_cache: list[ToolDefinition] | None = None
|
||||
|
||||
def _get_tools_definition(self) -> list[ToolDefinition]:
    """Return the Function Calling tool definitions, building them once.

    The converted list is memoised on ``self._tools_cache``. An empty list
    is returned when nothing is cached and no tool registry is configured.
    """
    cache_is_cold = self._tools_cache is None
    if cache_is_cold and self._tool_registry:
        registered = self._tool_registry.get_all_tools()
        self._tools_cache = convert_tools_to_llm_format(registered)

    if self._tools_cache:
        return self._tools_cache
    return []
|
||||
|
||||
async def _get_tools_definition_async(self) -> list[ToolDefinition]:
    """Get tools definition for Function Calling with dynamic schema support.

    Returns the memoised list when one exists, and an empty list when no
    tool registry is configured. Otherwise every registered tool is
    converted; for ``kb_search_dynamic`` a still-fresh per-tenant schema
    from the tool's schema cache is used instead of the static conversion.
    The result is stored on ``self._tools_cache``.
    """
    if self._tools_cache is not None:
        return self._tools_cache

    if not self._tool_registry:
        return []

    # The module-level import only brings in the plural
    # ``convert_tools_to_llm_format``; import the per-tool converter here so
    # the call below cannot fail with NameError.
    from app.services.mid.tool_converter import convert_tool_to_llm_format

    definitions = []

    for tool in self._tool_registry.get_all_tools():
        if tool.name == "kb_search_dynamic" and self._tenant_id:
            from app.services.mid.kb_search_dynamic_tool import (
                _TOOL_SCHEMA_CACHE,
                _TOOL_SCHEMA_CACHE_TTL_SECONDS,
            )

            cache_key = f"tool_schema:{self._tenant_id}"

            if cache_key in _TOOL_SCHEMA_CACHE:
                cached_time, cached_schema = _TOOL_SCHEMA_CACHE[cache_key]
                if time.time() - cached_time < _TOOL_SCHEMA_CACHE_TTL_SECONDS:
                    definitions.append(ToolDefinition(
                        name=cached_schema["name"],
                        description=cached_schema["description"],
                        parameters=cached_schema["parameters"],
                    ))
                    continue

        # Static conversion: any other tool, or a missing/expired cached schema.
        definitions.append(convert_tool_to_llm_format(tool))

    self._tools_cache = definitions
    return definitions
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
|
|
@ -90,7 +163,7 @@ class AgentOrchestrator:
|
|||
on_action: Any = None,
|
||||
) -> tuple[str, ReActContext, TraceInfo]:
|
||||
"""
|
||||
[AC-MARH-07] Execute ReAct loop with iteration control.
|
||||
[AC-MARH-07] Execute agent loop with iteration control.
|
||||
|
||||
Args:
|
||||
user_message: User input message
|
||||
|
|
@ -101,6 +174,416 @@ class AgentOrchestrator:
|
|||
Returns:
|
||||
Tuple of (final_answer, react_context, trace_info)
|
||||
"""
|
||||
if self._mode == AgentMode.FUNCTION_CALLING:
|
||||
return await self._execute_function_calling(user_message, context, on_action)
|
||||
else:
|
||||
return await self._execute_react(user_message, context, on_thought, on_action)
|
||||
|
||||
async def _execute_function_calling(
    self,
    user_message: str,
    context: dict[str, Any] | None = None,
    on_action: Any = None,
) -> tuple[str, ReActContext, TraceInfo]:
    """
    Execute using Function Calling mode.

    Each iteration asks the LLM for a response. If it requests tool calls,
    every call is executed through ``_act_fc`` and its result is appended
    to the conversation; otherwise the response text becomes the final
    answer. The loop stops on a final answer, a tool timeout, the iteration
    limit, or the end-to-end timeout.

    Args:
        user_message: User input message.
        context: Accepted for signature parity with ``execute``; not read
            by this method.
        on_action: Optional async callback, awaited as
            ``on_action(tool_name, tool_result)`` after each tool call.

    Returns:
        Tuple of (final_answer, react_context, trace_info).
    """
    react_ctx = ReActContext(max_iterations=self._max_iterations)
    tool_calls: list[ToolCallTrace] = []
    start_time = time.time()

    logger.info(
        f"[AC-MARH-07] Starting Function Calling loop: max_iterations={self._max_iterations}, "
        f"llm_client={self._llm_client is not None}, tool_registry={self._tool_registry is not None}"
    )

    if not self._llm_client:
        logger.error("[DEBUG-ORCH] LLM client is None, returning error response")
        return "抱歉,服务配置错误,请联系管理员。", react_ctx, TraceInfo(
            mode=ExecutionMode.AGENT,
            request_id=str(uuid.uuid4()),
            generation_id=str(uuid.uuid4()),
        )

    try:
        messages: list[dict[str, Any]] = [
            {"role": "system", "content": await self._build_system_prompt()},
            {"role": "user", "content": user_message},
        ]

        tools = await self._get_tools_definition_async()

        # The end-to-end budget is measured from here, i.e. after prompt and
        # tool-schema preparation; ``start_time`` is only used for reporting.
        overall_start = time.time()
        end_to_end_timeout = self._timeout_governor.end_to_end_timeout_seconds
        llm_timeout = getattr(self._timeout_governor, 'llm_timeout_seconds', 15.0)

        while react_ctx.should_continue and react_ctx.iteration < react_ctx.max_iterations:
            react_ctx.iteration += 1

            elapsed = time.time() - overall_start
            remaining_time = end_to_end_timeout - elapsed
            if remaining_time <= 0:
                logger.warning("[AC-MARH-09] Function Calling loop exceeded end-to-end timeout")
                react_ctx.final_answer = "抱歉,处理超时,请稍后重试或联系人工客服。"
                break

            logger.info(
                f"[AC-MARH-07] Function Calling iteration {react_ctx.iteration}/"
                f"{react_ctx.max_iterations}, remaining_time={remaining_time:.1f}s"
            )

            logger.info(
                f"[DEBUG-ORCH] Calling LLM generate with messages_count={len(messages)}, "
                f"tools_count={len(tools) if tools else 0}"
            )

            # Never wait on the LLM longer than what is left of the budget.
            response = await asyncio.wait_for(
                self._llm_client.generate(
                    messages=messages,
                    tools=tools if tools else None,
                    tool_choice="auto" if tools else None,
                ),
                timeout=min(llm_timeout, remaining_time)
            )

            logger.info(
                f"[DEBUG-ORCH] LLM response received: has_tool_calls={response.has_tool_calls}, "
                f"content_length={len(response.content) if response.content else 0}, "
                f"tool_calls_count={len(response.tool_calls) if response.tool_calls else 0}"
            )

            if response.has_tool_calls:
                for tool_call in response.tool_calls:
                    tool_name = tool_call.name
                    tool_args = tool_call.arguments

                    logger.info(f"[AC-MARH-07] Tool call: {tool_name}, args={tool_args}")

                    # One assistant message per tool call, immediately
                    # followed by that call's tool-result message.
                    messages.append({
                        "role": "assistant",
                        "content": response.content,
                        "tool_calls": [tool_call.to_dict()],
                    })

                    tool_result, tool_trace = await self._act_fc(
                        tool_call.id, tool_name, tool_args, react_ctx
                    )
                    tool_calls.append(tool_trace)
                    react_ctx.tool_calls.append(tool_trace)

                    if on_action:
                        await on_action(tool_name, tool_result)

                    # Extract tool_guide from output if present (added by _act_fc)
                    result_output = tool_result.output if tool_result.success else {"error": tool_result.error}
                    tool_guide = None
                    if isinstance(result_output, dict) and "_tool_guide" in result_output:
                        # Copy before popping so the traced result keeps its guide.
                        result_output = dict(result_output)
                        tool_guide = result_output.pop("_tool_guide")

                    messages.append(build_tool_result_message(
                        tool_call_id=tool_call.id,
                        tool_name=tool_name,
                        result=result_output,
                        tool_guide=tool_guide,
                    ))

                    if not tool_result.success:
                        # Only a timeout aborts the run; other failures are
                        # left in the conversation for the next LLM turn.
                        if tool_trace.status == ToolCallStatus.TIMEOUT:
                            react_ctx.final_answer = "抱歉,操作超时,请稍后重试或联系人工客服。"
                            react_ctx.should_continue = False
                            break
            else:
                # No tool requested: the response text is the final answer.
                react_ctx.final_answer = response.content or "抱歉,我无法处理您的请求。"
                react_ctx.should_continue = False
                break

        if react_ctx.should_continue and not react_ctx.final_answer:
            logger.warning(f"[AC-MARH-07] Function Calling reached max iterations: {react_ctx.iteration}")
            # Bound the forced final answer by the remaining end-to-end
            # budget; a non-positive budget raises TimeoutError, which the
            # handler below converts into the timeout reply.
            remaining_time = end_to_end_timeout - (time.time() - overall_start)
            react_ctx.final_answer = await asyncio.wait_for(
                self._force_final_answer_fc(messages),
                timeout=remaining_time,
            )

    except asyncio.TimeoutError:
        logger.error("[AC-MARH-09] Function Calling loop timed out (end-to-end)")
        react_ctx.final_answer = "抱歉,处理超时,请稍后重试或联系人工客服。"
        tool_calls.append(ToolCallTrace(
            tool_name="fc_loop",
            tool_type=ToolType.INTERNAL,
            duration_ms=int((time.time() - start_time) * 1000),
            status=ToolCallStatus.TIMEOUT,
            error_code="E2E_TIMEOUT",
        ))
    except Exception as e:
        logger.error(f"[AC-MARH-07] Function Calling error: {e}", exc_info=True)
        # NOTE(review): this puts the raw exception text in the user-facing
        # reply — confirm that is intended.
        react_ctx.final_answer = f"抱歉,处理过程中发生错误:{str(e)}"

    total_duration_ms = int((time.time() - start_time) * 1000)
    trace = TraceInfo(
        mode=ExecutionMode.AGENT,
        request_id=str(uuid.uuid4()),
        generation_id=str(uuid.uuid4()),
        react_iterations=react_ctx.iteration,
        # Synthetic loop-level traces are not reported as tools.
        tools_used=[tc.tool_name for tc in tool_calls if tc.tool_name not in ("fc_loop", "react_loop")],
        tool_calls=tool_calls if tool_calls else None,
    )

    logger.info(
        f"[AC-MARH-07] Function Calling completed: iterations={react_ctx.iteration}, "
        f"duration_ms={total_duration_ms}"
    )

    return react_ctx.final_answer or "抱歉,我暂时无法处理您的请求。", react_ctx, trace
|
||||
|
||||
async def _act_fc(
    self,
    tool_call_id: str,
    tool_name: str,
    tool_args: dict[str, Any],
    react_ctx: ReActContext,
) -> tuple[ToolResult, ToolCallTrace]:
    """Run one tool requested through Function Calling.

    The model-supplied arguments are copied, the orchestrator's tenant,
    user and session identifiers are written over them when set, and the
    tool runs through the registry under the per-tool timeout. On the first
    successful call of a tool within ``react_ctx`` its usage guide is
    attached to the output.

    Args:
        tool_call_id: Identifier of the LLM tool call (not read here).
        tool_name: Name of the tool to execute.
        tool_args: Arguments produced by the model.
        react_ctx: Loop context; its ``tool_calls`` decide whether this is
            the first call of ``tool_name``.

    Returns:
        Tuple of (tool_result, tool_call_trace). Failures are reported via
        the result and trace rather than raised.
    """
    began = time.time()

    def _elapsed_ms() -> int:
        return int((time.time() - began) * 1000)

    def _failed(
        error: str,
        status: Any,
        error_code: str,
        elapsed_ms: int,
        *,
        record_args: bool = True,
    ) -> tuple[ToolResult, ToolCallTrace]:
        # Build the (result, trace) pair shared by every failure path.
        extra = {"arguments": tool_args} if record_args else {}
        return ToolResult(
            success=False,
            error=error,
            duration_ms=elapsed_ms,
        ), ToolCallTrace(
            tool_name=tool_name,
            tool_type=ToolType.INTERNAL,
            duration_ms=elapsed_ms,
            status=status,
            error_code=error_code,
            **extra,
        )

    if not self._tool_registry:
        return _failed(
            "Tool registry not configured",
            ToolCallStatus.ERROR,
            "NO_REGISTRY",
            _elapsed_ms(),
            record_args=False,
        )

    try:
        final_args = dict(tool_args)
        injected = (
            ("tenant_id", self._tenant_id),
            ("user_id", self._user_id),
            ("session_id", self._session_id),
        )
        for key, value in injected:
            if value:
                final_args[key] = value

        if tool_name == "kb_search_dynamic":
            # Make sure ``context`` exists so the model can pass dynamic
            # filter conditions. ``scene`` is chosen by the model from the
            # metadata; it is not overridden here.
            final_args.setdefault("context", {})

        logger.info(
            f"[AC-MARH-07] FC Tool call starting: tool={tool_name}, "
            f"args={tool_args}, final_args={final_args}"
        )

        # NOTE(review): the ReAct path calls execute(tool_name=...) while
        # this one passes name=... — confirm against the registry signature.
        result = await asyncio.wait_for(
            self._tool_registry.execute(
                name=tool_name,
                **final_args,
            ),
            timeout=self._timeout_governor.per_tool_timeout_seconds
        )

        duration_ms = _elapsed_ms()

        first_use = all(tc.tool_name != tool_name for tc in react_ctx.tool_calls)

        output = result.output
        if first_use and result.success:
            usage_guide = self._build_tool_usage_guide(tool_name)
            if usage_guide:
                if isinstance(output, dict):
                    # Copy so the tool's own result object is not mutated.
                    output = dict(output)
                    output["_tool_guide"] = usage_guide
                elif isinstance(output, str):
                    output = f"{output}\n\n---\n{usage_guide}"
                else:
                    output = {"result": output, "_tool_guide": usage_guide}

        outcome = ToolResult(
            success=result.success,
            output=output,
            error=result.error,
            duration_ms=duration_ms,
        )
        trace = ToolCallTrace(
            tool_name=tool_name,
            tool_type=ToolType.INTERNAL,
            duration_ms=duration_ms,
            status=ToolCallStatus.OK if result.success else ToolCallStatus.ERROR,
            args_digest=str(tool_args)[:100] if tool_args else None,
            result_digest=str(result.output)[:100] if result.output else None,
            arguments=tool_args,
            result=output,
        )
        return outcome, trace

    except asyncio.TimeoutError:
        duration_ms = _elapsed_ms()
        logger.warning(f"[AC-MARH-08] FC Tool timeout: {tool_name}, duration={duration_ms}ms")
        return _failed("Tool timeout", ToolCallStatus.TIMEOUT, "TOOL_TIMEOUT", duration_ms)

    except Exception as e:
        duration_ms = _elapsed_ms()
        logger.error(f"[AC-MARH-07] FC Tool error: {tool_name}, error={e}")
        return _failed(str(e), ToolCallStatus.ERROR, "TOOL_ERROR", duration_ms)
|
||||
|
||||
async def _build_system_prompt(self) -> str:
    """Assemble the system prompt for Function Calling mode.

    Without a template service or tenant id the built-in default prompt is
    returned. Otherwise the tenant's published template is looked up for
    scene ``agent_fc`` and then ``default``; when one is found (and differs
    from ``SYSTEM_PROMPT``) the agent protocol section is appended to it.
    Any failure while loading falls back to the default prompt.
    """
    default_prompt = """你是一个智能客服助手,正在处理用户请求。

## 决策协议

1. 优先使用已有观察信息,避免重复调用同类工具。
2. 当问题需要外部事实或结构化状态时再调用工具;如果可直接回答则不要调用。
3. 缺少关键参数时,优先向用户追问,不要使用空参数调用工具。
4. 工具失败时,先说明已尝试,再给出降级方案或下一步引导。
5. 对用户输出必须拟人、自然、有同理心,不暴露"工具调用/路由/策略"等内部术语。

## 知识库查询强制流程

当用户问题需要进行知识库查询(kb_search_dynamic)时,必须遵循以下步骤:

**步骤1:先调用 list_document_metadata_fields**
- 在任何知识库搜索之前,必须先调用 `list_document_metadata_fields` 工具
- 获取可用的元数据字段(如 grade, subject, kb_scene 等)及其常见取值

**步骤2:分析用户意图,选择合适的过滤条件**
- 根据用户问题和返回的元数据字段,确定合适的过滤条件
- 从元数据字段的 common_values 中选择合适的值

**步骤3:调用 kb_search_dynamic 进行搜索**
- 使用步骤1获取的元数据字段构造 context 参数
- scene 参数必须从元数据字段的 kb_scene 常见值中选择,不要硬编码

**示例流程:**
1. 调用 `list_document_metadata_fields` 获取字段信息
2. 根据返回结果,发现可用字段:grade(年级)、subject(学科)、kb_scene(场景)
3. 分析用户问题"三年级语文怎么学",确定过滤条件:grade="三年级", subject="语文"
4. 从 kb_scene 的常见值中选择合适的 scene(如"学习方案")
5. 调用 `kb_search_dynamic`,传入构造好的 context 和 scene

## 注意事项
- **严禁**在调用 kb_search_dynamic 之前不调用 list_document_metadata_fields。
"""

    if not (self._template_service and self._tenant_id):
        return default_prompt

    # Protocol section appended to a tenant's published template.
    agent_protocol = """

## 智能体决策协议

1. 优先使用已有观察信息,避免重复调用同类工具。
2. 当问题需要外部事实或结构化状态时再调用工具;如果可直接回答则不要调用。
3. 缺少关键参数时,优先向用户追问,不要使用空参数调用工具。
4. 工具失败时,先说明已尝试,再给出降级方案或下一步引导。

## 知识库查询强制流程

当用户问题需要进行知识库查询(kb_search_dynamic)时,必须遵循以下步骤:

**步骤1:先调用 list_document_metadata_fields**
- 在任何知识库搜索之前,必须先调用 `list_document_metadata_fields` 工具
- 获取可用的元数据字段(如 grade, subject, kb_scene 等)及其常见取值

**步骤2:分析用户意图,选择合适的过滤条件**
- 根据用户问题和返回的元数据字段,确定合适的过滤条件
- 从元数据字段的 common_values 中选择合适的值

**步骤3:调用 kb_search_dynamic 进行搜索**
- 使用步骤1获取的元数据字段构造 context 参数
- scene 参数必须从元数据字段的 kb_scene 常见值中选择,不要硬编码

## 注意事项
- **严禁**在调用 kb_search_dynamic 之前不调用 list_document_metadata_fields。
"""

    try:
        from app.core.database import get_session
        from app.core.prompts import SYSTEM_PROMPT

        async with get_session() as session:
            template_service = PromptTemplateService(session)

            # Try the dedicated scene first, then the tenant's default one.
            base_prompt = None
            for scene_name in ("agent_fc", "default"):
                base_prompt = await template_service.get_published_template(
                    tenant_id=self._tenant_id,
                    scene=scene_name,
                    resolver=self._variable_resolver,
                )
                if base_prompt and base_prompt != SYSTEM_PROMPT:
                    break

            if not base_prompt or base_prompt == SYSTEM_PROMPT:
                logger.info("[AC-MARH-07] No published template found for agent_fc or default, using default prompt")
                return default_prompt

            logger.info(f"[AC-MARH-07] Loaded template for tenant={self._tenant_id}")
            return base_prompt + agent_protocol

    except Exception as e:
        logger.warning(f"[AC-MARH-07] Failed to load template, using default: {e}")
        return default_prompt
|
||||
|
||||
async def _force_final_answer_fc(self, messages: list[dict[str, Any]]) -> str:
    """Force final answer when max iterations reached in Function Calling mode.

    Sends the conversation plus one extra user instruction asking for a
    final reply, with tools disabled. The call is bounded by the timeout
    governor's ``llm_timeout_seconds`` (15s when the attribute is absent),
    so it cannot hang indefinitely.

    Args:
        messages: Conversation so far; not mutated.

    Returns:
        The model's reply, or a canned apology when the reply is empty,
        the call times out, or it raises.
    """
    llm_timeout = getattr(self._timeout_governor, 'llm_timeout_seconds', 15.0)
    try:
        response = await asyncio.wait_for(
            self._llm_client.generate(
                messages=messages + [{"role": "user", "content": "请基于以上信息给出最终回答,不要再调用工具。"}],
                tools=None,
            ),
            timeout=llm_timeout,
        )
        return response.content or "抱歉,我已经尽力处理您的请求,但可能需要更多信息。"
    except asyncio.TimeoutError:
        # TimeoutError has an empty str(); log the bound explicitly instead.
        logger.error(f"[AC-MARH-07] Force final answer FC failed: timed out after {llm_timeout}s")
        return "抱歉,我已经尽力处理您的请求,但可能需要更多信息。请稍后重试或联系人工客服。"
    except Exception as e:
        logger.error(f"[AC-MARH-07] Force final answer FC failed: {e}")
        return "抱歉,我已经尽力处理您的请求,但可能需要更多信息。请稍后重试或联系人工客服。"
|
||||
|
||||
async def _execute_react(
|
||||
self,
|
||||
user_message: str,
|
||||
context: dict[str, Any] | None = None,
|
||||
on_thought: Any = None,
|
||||
on_action: Any = None,
|
||||
) -> tuple[str, ReActContext, TraceInfo]:
|
||||
"""
|
||||
Execute using traditional ReAct mode (text-based).
|
||||
|
||||
This is the original implementation for backward compatibility.
|
||||
"""
|
||||
react_ctx = ReActContext(max_iterations=self._max_iterations)
|
||||
tool_calls: list[ToolCallTrace] = []
|
||||
start_time = time.time()
|
||||
|
|
@ -321,58 +804,108 @@ Action Input:
|
|||
return default_template
|
||||
|
||||
def _build_tools_section(self) -> str:
    """
    Build compact tools section for ReAct prompt.

    Only includes tool name and brief description for initial scanning.
    Detailed usage guides are disclosed on-demand when tool is called.

    Returns:
        A markdown bullet list of all registered tools, or a fixed
        "no tools available" sentence when there is no registry or it
        returns no tools. A guide-registry description, when present,
        takes precedence over the tool's own description.
    """
    if not self._tool_registry:
        return "当前没有可用的工具。"

    # Single lookup; the earlier list_tools(enabled_only=True) call was a
    # leftover whose result was immediately overwritten.
    tools = self._tool_registry.get_all_tools()
    if not tools:
        return "当前没有可用的工具。"

    lines = ["## 可用工具列表", "", "以下是你可以使用的工具:", ""]

    for tool in tools:
        tool_guide = self._guide_registry.get_tool_guide(tool.name) if self._guide_registry else None
        description = tool_guide.description if tool_guide else tool.description
        lines.append(f"- **{tool.name}**: {description}")

    lines.append("")
    lines.append("调用工具时,系统会提供该工具的详细使用说明。")

    return "\n".join(lines)
|
||||
|
||||
def _build_tool_usage_guide(self, tool_name: str) -> str:
    """
    Build detailed usage guide for a specific tool.
    Called when the tool is executed to provide on-demand guidance.

    The guide combines the guide-registry entry (purpose, triggers,
    anti-triggers, free-form content) with the tool's own metadata
    (parameters, example input, result interpretation).

    Args:
        tool_name: Name of the tool to describe.

    Returns:
        Markdown text, or an empty string when neither a guide nor a
        registered tool is found for ``tool_name``.
    """
    tool_guide = self._guide_registry.get_tool_guide(tool_name) if self._guide_registry else None
    tool = self._tool_registry.get_tool(tool_name) if self._tool_registry else None

    if not tool_guide and not tool:
        return ""

    # Parameters the orchestrator writes itself before executing a tool,
    # so the model should not be asked to fill them in.
    system_injected = ("tenant_id", "user_id", "session_id")

    lines = [f"## {tool_name} 使用说明", ""]

    if tool_guide:
        lines.append(f"**用途**: {tool_guide.description}")
        lines.append("")

        if tool_guide.triggers:
            lines.append("**适用场景**:")
            lines.extend(f"- {trigger}" for trigger in tool_guide.triggers)
            lines.append("")

        if tool_guide.anti_triggers:
            lines.append("**不适用场景**:")
            lines.extend(f"- {anti}" for anti in tool_guide.anti_triggers)
            lines.append("")

        if tool_guide.content:
            lines.append(tool_guide.content)
            lines.append("")

    if tool:
        meta = tool.metadata or {}

        params = meta.get("parameters")
        if isinstance(params, dict):
            properties = params.get("properties", {})
            required = params.get("required", [])
            if properties:
                lines.append("**参数说明**:")
                for param_name, param_info in properties.items():
                    param_desc = param_info.get("description", "") if isinstance(param_info, dict) else ""
                    if param_name in system_injected:
                        req_mark = " (系统注入)"
                    elif param_name in required:
                        req_mark = " (必填)"
                    else:
                        req_mark = ""
                    lines.append(f"- `{param_name}`: {param_desc}{req_mark}")
                lines.append("")

        if meta.get("example_action_input"):
            lines.append("**调用示例**:")
            try:
                example_text = json.dumps(meta["example_action_input"], ensure_ascii=False, indent=2)
            except Exception:
                # Best effort: fall back to repr-like text for values that
                # json cannot serialise.
                example_text = str(meta["example_action_input"])
            lines.append(f"```json\n{example_text}\n```")
            lines.append("")

        if meta.get("result_interpretation"):
            lines.append(f"**结果说明**: {meta['result_interpretation']}")
            lines.append("")

    return "\n".join(lines)
|
||||
|
||||
def _build_tools_guide_section(self, tool_names: list[str] | None = None) -> str:
    """
    Build detailed tools guide section for ReAct prompt.

    Delegates to the guide registry's ``build_tools_prompt_section``.
    Called separately from _build_tools_section for flexibility.

    Args:
        tool_names: If provided, only include tools for these names.
                   If None, include all tools.

    Returns:
        The registry-built section, or an empty string when no guide
        registry is configured.
    """
    # The guide registry is optional (constructor default is None); the
    # sibling builders guard for that, so do the same here.
    if not self._guide_registry:
        return ""
    return self._guide_registry.build_tools_prompt_section(tool_names)
|
||||
|
||||
def _extract_json_object(self, text: str) -> dict[str, Any] | None:
|
||||
"""Extract the first valid JSON object from free text."""
|
||||
candidates = []
|
||||
|
|
@ -438,6 +971,8 @@ Action Input:
|
|||
) -> tuple[ToolResult, ToolCallTrace]:
|
||||
"""
|
||||
[AC-MARH-07, AC-MARH-08] Execute tool action with timeout.
|
||||
|
||||
On first call to a tool, appends detailed usage guide to observation.
|
||||
"""
|
||||
tool_name = thought.action or "unknown"
|
||||
start_time = time.time()
|
||||
|
|
@ -461,6 +996,23 @@ Action Input:
|
|||
if self._tenant_id:
|
||||
tool_args["tenant_id"] = self._tenant_id
|
||||
|
||||
if self._user_id:
|
||||
tool_args["user_id"] = self._user_id
|
||||
|
||||
if self._session_id:
|
||||
tool_args["session_id"] = self._session_id
|
||||
|
||||
if tool_name == "kb_search_dynamic":
|
||||
# 确保 context 存在,供 AI 传入动态过滤条件
|
||||
if "context" not in tool_args:
|
||||
tool_args["context"] = {}
|
||||
# scene 参数由 AI 从元数据中选择,系统不强制覆盖
|
||||
|
||||
logger.info(
|
||||
f"[AC-MARH-07] Tool call starting: tool={tool_name}, "
|
||||
f"action_input={thought.action_input}, final_args={tool_args}"
|
||||
)
|
||||
|
||||
result = await asyncio.wait_for(
|
||||
self._tool_registry.execute(
|
||||
tool_name=tool_name,
|
||||
|
|
@ -470,9 +1022,25 @@ Action Input:
|
|||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
called_tools = {tc.tool_name for tc in react_ctx.tool_calls}
|
||||
is_first_call = tool_name not in called_tools
|
||||
|
||||
output = result.output
|
||||
if is_first_call and result.success:
|
||||
usage_guide = self._build_tool_usage_guide(tool_name)
|
||||
if usage_guide:
|
||||
if isinstance(output, dict):
|
||||
output = dict(output)
|
||||
output["_tool_guide"] = usage_guide
|
||||
elif isinstance(output, str):
|
||||
output = f"{output}\n\n---\n{usage_guide}"
|
||||
else:
|
||||
output = {"result": output, "_tool_guide": usage_guide}
|
||||
|
||||
return ToolResult(
|
||||
success=result.success,
|
||||
output=result.output,
|
||||
output=output,
|
||||
error=result.error,
|
||||
duration_ms=duration_ms,
|
||||
), ToolCallTrace(
|
||||
|
|
@ -482,6 +1050,8 @@ Action Input:
|
|||
status=ToolCallStatus.OK if result.success else ToolCallStatus.ERROR,
|
||||
args_digest=str(thought.action_input)[:100] if thought.action_input else None,
|
||||
result_digest=str(result.output)[:100] if result.output else None,
|
||||
arguments=thought.action_input,
|
||||
result=output,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
|
|
@ -497,6 +1067,7 @@ Action Input:
|
|||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.TIMEOUT,
|
||||
error_code="TOOL_TIMEOUT",
|
||||
arguments=thought.action_input,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
|
@ -512,6 +1083,7 @@ Action Input:
|
|||
duration_ms=duration_ms,
|
||||
status=ToolCallStatus.ERROR,
|
||||
error_code="TOOL_ERROR",
|
||||
arguments=thought.action_input,
|
||||
)
|
||||
|
||||
async def _force_final_answer(
|
||||
|
|
|
|||
|
|
@ -53,24 +53,35 @@ class RuntimeContext:
|
|||
|
||||
def to_trace_info(self) -> TraceInfo:
    """Convert this runtime context into a ``TraceInfo``.

    Copies the context's fields onto a new ``TraceInfo``; ``tools_used``
    and ``tool_calls`` are ``None`` when no tool call was recorded.

    Raises:
        Exception: Whatever ``TraceInfo(...)`` raises is logged together
            with the mode and request/generation ids, then re-raised.
    """
    # The construction happens only inside the try block; an earlier
    # unconditional return would make the logging path below unreachable.
    try:
        return TraceInfo(
            mode=self.mode,
            intent=self.intent,
            request_id=self.request_id,
            generation_id=self.generation_id,
            guardrail_triggered=self.guardrail_triggered,
            guardrail_rule_id=self.guardrail_rule_id,
            interrupt_consumed=self.interrupt_consumed,
            kb_tool_called=self.kb_tool_called,
            kb_hit=self.kb_hit,
            fallback_reason_code=self.fallback_reason_code,
            react_iterations=self.react_iterations,
            timeout_profile=self.timeout_profile,
            segment_stats=self.segment_stats,
            metrics_snapshot=self.metrics_snapshot,
            tools_used=[tc.tool_name for tc in self.tool_calls] if self.tool_calls else None,
            tool_calls=self.tool_calls if self.tool_calls else None,
        )
    except Exception as e:
        import traceback
        logger.error(
            f"[RuntimeObserver] Failed to create TraceInfo: {e}\n"
            f"Exception type: {type(e).__name__}\n"
            f"Context: mode={self.mode}, request_id={self.request_id}, "
            f"generation_id={self.generation_id}\n"
            f"Traceback:\n{traceback.format_exc()}"
        )
        raise
|
||||
|
||||
|
||||
class RuntimeObserver:
|
||||
|
|
|
|||
|
|
@ -127,7 +127,8 @@ class ToolCallRecorder:
|
|||
logger.info(
|
||||
f"[AC-IDMP-15] Tool call recorded: tool={trace.tool_name}, "
|
||||
f"type={trace.tool_type.value}, duration_ms={trace.duration_ms}, "
|
||||
f"status={trace.status.value}, session={session_id}"
|
||||
f"status={trace.status.value}, session={session_id}, "
|
||||
f"args_digest={trace.args_digest}, result_digest={trace.result_digest}"
|
||||
)
|
||||
|
||||
def record_success(
|
||||
|
|
@ -153,6 +154,8 @@ class ToolCallRecorder:
|
|||
auth_applied=auth_applied,
|
||||
args_digest=ToolCallTrace.compute_digest(args) if args else None,
|
||||
result_digest=ToolCallTrace.compute_digest(result) if result else None,
|
||||
arguments=args if isinstance(args, dict) else None,
|
||||
result=result,
|
||||
)
|
||||
self.record(session_id, trace)
|
||||
return trace
|
||||
|
|
@ -179,6 +182,7 @@ class ToolCallRecorder:
|
|||
registry_version=registry_version,
|
||||
auth_applied=auth_applied,
|
||||
args_digest=ToolCallTrace.compute_digest(args) if args else None,
|
||||
arguments=args if isinstance(args, dict) else None,
|
||||
)
|
||||
self.record(session_id, trace)
|
||||
return trace
|
||||
|
|
@ -207,6 +211,7 @@ class ToolCallRecorder:
|
|||
registry_version=registry_version,
|
||||
auth_applied=auth_applied,
|
||||
args_digest=ToolCallTrace.compute_digest(args) if args else None,
|
||||
arguments=args if isinstance(args, dict) else None,
|
||||
)
|
||||
self.record(session_id, trace)
|
||||
return trace
|
||||
|
|
@ -231,6 +236,7 @@ class ToolCallRecorder:
|
|||
error_code=reason,
|
||||
registry_version=registry_version,
|
||||
args_digest=ToolCallTrace.compute_digest(args) if args else None,
|
||||
arguments=args if isinstance(args, dict) else None,
|
||||
)
|
||||
self.record(session_id, trace)
|
||||
return trace
|
||||
|
|
|
|||
|
|
@ -0,0 +1,111 @@
|
|||
"""
|
||||
Tool definition converter for Function Calling.
|
||||
Converts ToolRegistry definitions to LLM ToolDefinition format.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from app.services.llm.base import ToolDefinition
|
||||
from app.services.mid.tool_registry import ToolDefinition as RegistryToolDefinition
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def convert_tool_to_llm_format(tool: RegistryToolDefinition) -> ToolDefinition:
    """
    Convert a ToolRegistry tool definition to the LLM ToolDefinition format.

    The JSON-schema style ``parameters`` mapping is read from
    ``tool.metadata["parameters"]``. Missing or malformed pieces fall back to
    an empty object schema, and the context fields ``tenant_id``, ``user_id``
    and ``session_id`` are removed from both ``properties`` and ``required``.

    The schema handed to the LLM is built on a copy: the mapping stored in
    ``tool.metadata`` is never modified. (Previously the filtered
    ``properties``/``required`` were written back into the registry's own
    dict, permanently stripping the context fields from the registered tool.)

    Args:
        tool: Tool definition from ToolRegistry

    Returns:
        ToolDefinition for Function Calling
    """
    hidden_fields = ("tenant_id", "user_id", "session_id")

    meta = tool.metadata or {}
    raw_parameters = meta.get("parameters")

    # Shallow-copy so the registry's metadata stays untouched; anything that
    # is not a mapping is treated as "no schema provided".
    parameters: dict[str, Any] = (
        dict(raw_parameters) if isinstance(raw_parameters, dict) else {}
    )
    parameters.setdefault("type", "object")

    # A non-dict "properties" (e.g. None) is replaced by an empty mapping
    # instead of raising on the membership test.
    properties = parameters.get("properties")
    if not isinstance(properties, dict):
        properties = {}
    parameters["properties"] = {
        key: schema for key, schema in properties.items() if key not in hidden_fields
    }

    # Same tolerance for "required": only list/tuple values are honoured.
    required = parameters.get("required")
    if not isinstance(required, (list, tuple)):
        required = []
    parameters["required"] = [
        field_name for field_name in required if field_name not in hidden_fields
    ]

    return ToolDefinition(
        name=tool.name,
        description=tool.description,
        parameters=parameters,
    )
|
||||
|
||||
|
||||
def convert_tools_to_llm_format(tools: list[RegistryToolDefinition]) -> list[ToolDefinition]:
    """
    Convert a batch of registry tool definitions to the LLM format.

    Each element is passed through ``convert_tool_to_llm_format``; the output
    list has the same length and order as the input.

    Args:
        tools: List of tool definitions from ToolRegistry

    Returns:
        List of ToolDefinition for Function Calling
    """
    return list(map(convert_tool_to_llm_format, tools))
|
||||
|
||||
|
||||
def build_tool_result_message(
|
||||
tool_call_id: str,
|
||||
tool_name: str,
|
||||
result: dict[str, Any],
|
||||
tool_guide: str | None = None,
|
||||
) -> dict[str, str]:
|
||||
"""
|
||||
Build a tool result message for the conversation.
|
||||
|
||||
Args:
|
||||
tool_call_id: ID of the tool call
|
||||
tool_name: Name of the tool
|
||||
result: Tool execution result
|
||||
tool_guide: Optional tool usage guide to append
|
||||
|
||||
Returns:
|
||||
Message dict with role='tool'
|
||||
"""
|
||||
if isinstance(result, dict):
|
||||
result_copy = {k: v for k, v in result.items() if k != "_tool_guide"}
|
||||
content = str(result_copy)
|
||||
else:
|
||||
content = str(result)
|
||||
|
||||
if tool_guide:
|
||||
content = f"{content}\n\n---\n{tool_guide}"
|
||||
|
||||
return {
|
||||
"role": "tool",
|
||||
"tool_call_id": tool_call_id,
|
||||
"content": content,
|
||||
}
|
||||
|
|
@ -0,0 +1,313 @@
|
|||
"""
|
||||
Tool Guide Registry for Mid Platform.
|
||||
Provides tool-based usage guidance with caching support.
|
||||
|
||||
Tool guides are usage manuals for tools, loaded on-demand with metadata scanning.
|
||||
This separates tool definitions (Function Calling) from usage guides (Tool Guides).
|
||||
"""
|
||||
|
||||
import logging
|
||||
import re
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TOOLS_DIR = Path(__file__).parent.parent.parent.parent / "tools"
|
||||
|
||||
|
||||
@dataclass
class ToolGuideMetadata:
    """
    Lightweight tool guide metadata for quick scanning (~100 tokens).

    Carries the front-matter fields of a guide without its body text.

    Attributes:
        name: Guide name, used as the registry key.
        description: Short description of the guide.
        triggers: Conditions under which the guide's tools should be used.
        anti_triggers: Conditions under which they should not be used.
        tools: Names of the tools this guide is associated with.
    """

    name: str
    description: str
    triggers: list[str] = field(default_factory=list)
    anti_triggers: list[str] = field(default_factory=list)
    tools: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
class ToolGuideDefinition:
    """
    Full tool guide definition with complete content.

    Attributes:
        name: Guide name, used as the registry key.
        description: Short description of the guide.
        triggers: Conditions under which the guide's tools should be used.
        anti_triggers: Conditions under which they should not be used.
        tools: Names of the tools this guide is associated with.
        content: Guide body (the markdown after the front matter).
        raw_markdown: Complete text of the source file, front matter included.
    """

    name: str
    description: str
    triggers: list[str]
    anti_triggers: list[str]
    tools: list[str]
    content: str
    raw_markdown: str
|
||||
|
||||
|
||||
class ToolGuideRegistry:
    """
    Tool guide registry with caching support.

    The class is a singleton: every construction returns the same instance.

    Features:
    - Load tool guides from .md files (YAML-like front matter + markdown body)
    - Cache tool guides in memory for high-frequency access
    - Provide lightweight metadata for quick scanning
    - Provide full content on demand
    """

    _instance = None
    _initialized = False

    # Front matter delimited by '---' lines; the body after the closing
    # delimiter is optional so a file ending right after it is still valid.
    _FRONTMATTER_RE = re.compile(r"^---\s*\n(.*?)\n---\s*(?:\n(.*))?$", re.DOTALL)
    # "key: value" or "key:" at the start of a line (no leading indentation).
    _KEY_RE = re.compile(r"^(\w+):\s*(.*)$")

    def __new__(cls, *args, **kwargs):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self, tools_dir: Path | None = None):
        """
        Set up the registry state on first construction.

        Args:
            tools_dir: Directory containing the ``*.md`` guides; defaults to
                ``TOOLS_DIR``. When the singleton already exists and a
                different directory is passed explicitly, the registry is
                re-targeted to it and its caches are invalidated (previously
                the argument was silently ignored).
        """
        if ToolGuideRegistry._initialized:
            if tools_dir is not None and tools_dir != self._tools_dir:
                self._tools_dir = tools_dir
                # Cached guides belong to the old directory; drop them so the
                # next access reloads from the new one.
                self._tool_guides.clear()
                self._metadata_cache.clear()
                self._tool_to_guide.clear()
                self._loaded = False
                logger.info(f"[ToolGuideRegistry] Re-targeted to tools_dir={self._tools_dir}")
            return

        self._tools_dir = tools_dir or TOOLS_DIR
        self._tool_guides: dict[str, ToolGuideDefinition] = {}
        self._metadata_cache: dict[str, ToolGuideMetadata] = {}
        self._tool_to_guide: dict[str, str] = {}
        self._loaded = False

        ToolGuideRegistry._initialized = True
        logger.info(f"[ToolGuideRegistry] Initialized with tools_dir={self._tools_dir}")

    def load_tools(self, force_reload: bool = False) -> None:
        """
        Load all tool guides from tools directory into cache.

        Files that fail to parse are logged and skipped; they do not abort
        the load. If the directory does not exist nothing is loaded and the
        registry stays in the "not loaded" state.

        Args:
            force_reload: Force reload even if already loaded
        """
        if self._loaded and not force_reload:
            logger.debug("[ToolGuideRegistry] Tool guides already loaded, skipping")
            return

        if not self._tools_dir.exists():
            logger.warning(f"[ToolGuideRegistry] Tools directory not found: {self._tools_dir}")
            return

        self._tool_guides.clear()
        self._metadata_cache.clear()
        self._tool_to_guide.clear()

        # Sorted so that, when two guides claim the same tool or name, the
        # winner does not depend on filesystem enumeration order.
        for tool_file in sorted(self._tools_dir.glob("*.md")):
            try:
                tool_guide = self._parse_tool_file(tool_file)
                if tool_guide:
                    self._tool_guides[tool_guide.name] = tool_guide
                    self._metadata_cache[tool_guide.name] = ToolGuideMetadata(
                        name=tool_guide.name,
                        description=tool_guide.description,
                        triggers=tool_guide.triggers,
                        anti_triggers=tool_guide.anti_triggers,
                        tools=tool_guide.tools,
                    )
                    for tool_name in tool_guide.tools:
                        self._tool_to_guide[tool_name] = tool_guide.name
                    logger.info(f"[ToolGuideRegistry] Loaded tool guide: {tool_guide.name} (tools: {tool_guide.tools})")
            except Exception as e:
                # Best-effort load: one broken file must not block the rest.
                logger.error(
                    f"[ToolGuideRegistry] Failed to load tool guide from {tool_file}: {e}",
                    exc_info=True,
                )

        self._loaded = True
        logger.info(f"[ToolGuideRegistry] Loaded {len(self._tool_guides)} tool guides")

    @staticmethod
    def _as_str_list(value: Any) -> list[str]:
        """Normalize a front-matter value (scalar, list or missing) to a list of strings."""
        if isinstance(value, str):
            return [value]
        if isinstance(value, list):
            return value
        return []

    def _parse_tool_file(self, file_path: Path) -> ToolGuideDefinition | None:
        """
        Parse a tool guide markdown file.

        Expected format:
        ---
        name: tool_name
        description: Tool description
        triggers:
          - trigger 1
          - trigger 2
        anti_triggers:
          - anti trigger 1
        tools:
          - tool_name
        ---

        ## Usage Guide
        ...

        Only this flat subset of YAML is understood: ``key: value`` scalars
        and ``key:`` followed by ``- item`` lines. List items are accepted at
        any indentation level.

        Args:
            file_path: Path of the markdown file to parse.

        Returns:
            The parsed definition, or None when the file has no front matter.
        """
        content = file_path.read_text(encoding="utf-8")

        frontmatter_match = self._FRONTMATTER_RE.match(content)
        if not frontmatter_match:
            logger.warning(f"[ToolGuideRegistry] No frontmatter found in {file_path}")
            return None

        frontmatter_text = frontmatter_match.group(1)
        body = frontmatter_match.group(2) or ""

        metadata: dict[str, Any] = {}
        current_list: list[str] | None = None

        for line in frontmatter_text.split("\n"):
            stripped = line.strip()
            if not stripped:
                continue

            key_match = self._KEY_RE.match(line)
            if key_match:
                current_key = key_match.group(1)
                value = key_match.group(2).strip()

                if value:
                    metadata[current_key] = value
                    current_list = None
                else:
                    # "key:" with no inline value opens a list.
                    current_list = []
                    metadata[current_key] = current_list
            elif current_list is not None and (stripped == "-" or stripped.startswith("- ")):
                # Work on the stripped line so the item text is taken right
                # after the dash regardless of indentation. The old code
                # tested a 3-character prefix but sliced 4 characters, which
                # cut off the first character of each item.
                item = stripped[1:].strip()
                if item:
                    current_list.append(item)

        # "name:" / "description:" without an inline value end up as lists;
        # coerce them so the name stays hashable and the description a string.
        name = metadata.get("name")
        if not isinstance(name, str) or not name:
            name = file_path.stem

        description = metadata.get("description", "")
        if isinstance(description, list):
            description = " ".join(description)

        return ToolGuideDefinition(
            name=name,
            description=description,
            triggers=self._as_str_list(metadata.get("triggers")),
            anti_triggers=self._as_str_list(metadata.get("anti_triggers")),
            tools=self._as_str_list(metadata.get("tools")),
            content=body.strip(),
            raw_markdown=content,
        )

    def get_tool_guide(self, name: str) -> ToolGuideDefinition | None:
        """Get full tool guide definition by name (loads guides on first use)."""
        if not self._loaded:
            self.load_tools()
        return self._tool_guides.get(name)

    def get_tool_metadata(self, name: str) -> ToolGuideMetadata | None:
        """Get lightweight tool guide metadata by name (loads guides on first use)."""
        if not self._loaded:
            self.load_tools()
        return self._metadata_cache.get(name)

    def get_guide_for_tool(self, tool_name: str) -> ToolGuideDefinition | None:
        """Get the tool guide associated with a tool name, or None if no guide lists it."""
        if not self._loaded:
            self.load_tools()
        guide_name = self._tool_to_guide.get(tool_name)
        if guide_name:
            return self._tool_guides.get(guide_name)
        return None

    def list_tools(self) -> list[str]:
        """List all tool guide names."""
        if not self._loaded:
            self.load_tools()
        return list(self._tool_guides.keys())

    def list_tool_metadata(self) -> list[ToolGuideMetadata]:
        """List all tool guide metadata (lightweight)."""
        if not self._loaded:
            self.load_tools()
        return list(self._metadata_cache.values())

    def build_tools_prompt_section(self, tool_names: list[str] | None = None) -> str:
        """
        Build tools section for ReAct prompt.

        Args:
            tool_names: If provided, only include guides for these tool names
                       (an empty list therefore yields an empty section).
                       If None, include all guides.

        Returns:
            Formatted tools section string, or "" when there is nothing to show
        """
        if not self._loaded:
            self.load_tools()

        if not self._tool_guides:
            return ""

        tools_to_include: list[ToolGuideDefinition] = []

        # "is not None" rather than truthiness: an explicitly empty selection
        # must not fall through to "include everything".
        if tool_names is not None:
            seen_guides: set[str] = set()
            for tool_name in tool_names:
                tool_guide = self.get_guide_for_tool(tool_name)
                # Several tools may share one guide; emit it only once.
                if tool_guide and tool_guide.name not in seen_guides:
                    seen_guides.add(tool_guide.name)
                    tools_to_include.append(tool_guide)
        else:
            tools_to_include = list(self._tool_guides.values())

        if not tools_to_include:
            return ""

        lines = ["## 工具使用指南", ""]
        lines.append("以下是每个工具的详细使用说明:")
        lines.append("")

        for tool_guide in tools_to_include:
            lines.append(f"### {tool_guide.name}")
            lines.append(f"描述: {tool_guide.description}")

            if tool_guide.triggers:
                lines.append("触发条件:")
                for trigger in tool_guide.triggers:
                    lines.append(f" - {trigger}")

            if tool_guide.anti_triggers:
                lines.append("不应触发:")
                for anti in tool_guide.anti_triggers:
                    lines.append(f" - {anti}")

            lines.append("")
            lines.append(tool_guide.content)
            lines.append("")

        return "\n".join(lines)

    def build_compact_tools_section(self) -> str:
        """
        Build compact tools section with only name and description.

        This is used for the initial tool list, with full guidance loaded separately.

        Returns:
            Markdown bullet list of guides, or a fixed notice when none are cached
        """
        if not self._loaded:
            self.load_tools()

        if not self._metadata_cache:
            return "当前没有可用的工具使用指南。"

        lines = ["## 可用工具列表", "", "以下是你可以使用的工具:", ""]

        for meta in self._metadata_cache.values():
            lines.append(f"- **{meta.name}**: {meta.description}")
            if meta.tools:
                lines.append(f" 关联工具: {', '.join(meta.tools)}")

        return "\n".join(lines)
|
||||
|
||||
|
||||
_tool_guide_registry: ToolGuideRegistry | None = None
|
||||
|
||||
|
||||
def get_tool_guide_registry() -> ToolGuideRegistry:
    """
    Return the module-level tool guide registry, creating it on first call.

    The instance is created with default arguments and stored in the
    module global so later calls reuse it.
    """
    global _tool_guide_registry
    registry = _tool_guide_registry
    if registry is None:
        registry = ToolGuideRegistry()
        _tool_guide_registry = registry
    return registry
|
||||
|
||||
|
||||
def init_tool_guide_registry(tools_dir: Path | None = None) -> ToolGuideRegistry:
    """
    Initialize the global tool guide registry, load its guides and return it.

    ``ToolGuideRegistry`` is a singleton, so constructing it again may hand
    back an instance that is already bound to another directory and already
    loaded. In that case an explicitly passed ``tools_dir`` would previously
    have had no effect; it is now applied and the guides are reloaded.

    Args:
        tools_dir: Directory containing the guide ``*.md`` files. None keeps
            the registry's current (or default) directory.

    Returns:
        The initialized registry instance.
    """
    global _tool_guide_registry

    registry = ToolGuideRegistry(tools_dir=tools_dir)

    # The constructor may have returned a pre-existing instance that still
    # points at an older directory; re-point it and force a fresh load.
    retargeted = tools_dir is not None and registry._tools_dir != tools_dir
    if retargeted:
        registry._tools_dir = tools_dir

    registry.load_tools(force_reload=retargeted)

    _tool_guide_registry = registry
    return registry
|
||||
|
|
@ -117,172 +117,124 @@ class ToolRegistry:
|
|||
self._tools[name] = tool
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-19] Tool registered: name={name}, type={tool_type.value}, "
|
||||
f"version={version}, auth_required={auth_required}"
|
||||
f"[AC-IDMP-19] Registered tool: {name} v{version} "
|
||||
f"(type={tool_type.value}, auth={auth_required}, timeout={timeout_ms}ms)"
|
||||
)
|
||||
|
||||
return tool
|
||||
|
||||
def unregister(self, name: str) -> bool:
|
||||
"""Unregister a tool."""
|
||||
if name in self._tools:
|
||||
del self._tools[name]
|
||||
logger.info(f"[AC-IDMP-19] Tool unregistered: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def get_tool(self, name: str) -> ToolDefinition | None:
|
||||
"""Get tool definition by name."""
|
||||
"""Get tool by name."""
|
||||
return self._tools.get(name)
|
||||
|
||||
def list_tools(
|
||||
self,
|
||||
tool_type: ToolType | None = None,
|
||||
enabled_only: bool = True,
|
||||
) -> list[ToolDefinition]:
|
||||
"""List registered tools, optionally filtered."""
|
||||
tools = list(self._tools.values())
|
||||
def list_tools(self) -> list[str]:
|
||||
"""List all registered tool names."""
|
||||
return list(self._tools.keys())
|
||||
|
||||
if tool_type:
|
||||
tools = [t for t in tools if t.tool_type == tool_type]
|
||||
def get_all_tools(self) -> list[ToolDefinition]:
|
||||
"""Get all registered tools."""
|
||||
return list(self._tools.values())
|
||||
|
||||
if enabled_only:
|
||||
tools = [t for t in tools if t.enabled]
|
||||
def is_enabled(self, name: str) -> bool:
|
||||
"""Check if tool is enabled."""
|
||||
tool = self._tools.get(name)
|
||||
return tool.enabled if tool else False
|
||||
|
||||
return tools
|
||||
|
||||
def enable_tool(self, name: str) -> bool:
|
||||
"""Enable a tool."""
|
||||
def set_enabled(self, name: str, enabled: bool) -> bool:
|
||||
"""Enable or disable a tool."""
|
||||
tool = self._tools.get(name)
|
||||
if tool:
|
||||
tool.enabled = True
|
||||
logger.info(f"[AC-IDMP-19] Tool enabled: {name}")
|
||||
return True
|
||||
return False
|
||||
|
||||
def disable_tool(self, name: str) -> bool:
|
||||
"""Disable a tool."""
|
||||
tool = self._tools.get(name)
|
||||
if tool:
|
||||
tool.enabled = False
|
||||
logger.info(f"[AC-IDMP-19] Tool disabled: {name}")
|
||||
tool.enabled = enabled
|
||||
logger.info(f"[AC-IDMP-19] Tool {name} {'enabled' if enabled else 'disabled'}")
|
||||
return True
|
||||
return False
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
tool_name: str,
|
||||
args: dict[str, Any],
|
||||
auth_context: dict[str, Any] | None = None,
|
||||
name: str,
|
||||
**kwargs: Any,
|
||||
) -> ToolExecutionResult:
|
||||
"""
|
||||
[AC-IDMP-19] Execute a tool with governance.
|
||||
Execute a tool with governance.
|
||||
|
||||
Args:
|
||||
tool_name: Tool name to execute
|
||||
args: Tool arguments
|
||||
auth_context: Authentication context
|
||||
name: Tool name
|
||||
**kwargs: Tool arguments
|
||||
|
||||
Returns:
|
||||
ToolExecutionResult with output and metadata
|
||||
"""
|
||||
start_time = time.time()
|
||||
|
||||
tool = self._tools.get(tool_name)
|
||||
tool = self._tools.get(name)
|
||||
if not tool:
|
||||
logger.warning(f"[AC-IDMP-19] Tool not found: {tool_name}")
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool not found: {tool_name}",
|
||||
duration_ms=0,
|
||||
error=f"Tool not found: {name}",
|
||||
registry_version=self._version,
|
||||
)
|
||||
|
||||
if not tool.enabled:
|
||||
logger.warning(f"[AC-IDMP-19] Tool disabled: {tool_name}")
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool disabled: {tool_name}",
|
||||
duration_ms=0,
|
||||
registry_version=tool.version,
|
||||
error=f"Tool is disabled: {name}",
|
||||
registry_version=self._version,
|
||||
)
|
||||
|
||||
auth_applied = False
|
||||
if tool.auth_required:
|
||||
if not auth_context:
|
||||
logger.warning(f"[AC-IDMP-19] Auth required but no context: {tool_name}")
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error="Authentication required",
|
||||
duration_ms=int((time.time() - start_time) * 1000),
|
||||
auth_applied=False,
|
||||
registry_version=tool.version,
|
||||
)
|
||||
auth_applied = True
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
timeout_seconds = tool.timeout_ms / 1000.0
|
||||
if not tool.handler:
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool has no handler: {name}",
|
||||
registry_version=self._version,
|
||||
)
|
||||
|
||||
result = await asyncio.wait_for(
|
||||
tool.handler(**args) if tool.handler else asyncio.sleep(0),
|
||||
timeout=timeout_seconds,
|
||||
result = await self._timeout_governor.execute_with_timeout(
|
||||
lambda: tool.handler(**kwargs),
|
||||
timeout_ms=tool.timeout_ms,
|
||||
)
|
||||
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
|
||||
logger.info(
|
||||
f"[AC-IDMP-19] Tool executed: name={tool_name}, "
|
||||
f"duration_ms={duration_ms}, success=True"
|
||||
)
|
||||
|
||||
return ToolExecutionResult(
|
||||
success=True,
|
||||
output=result,
|
||||
duration_ms=duration_ms,
|
||||
auth_applied=auth_applied,
|
||||
registry_version=tool.version,
|
||||
auth_applied=tool.auth_required,
|
||||
registry_version=self._version,
|
||||
)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.warning(
|
||||
f"[AC-IDMP-19] Tool timeout: name={tool_name}, "
|
||||
f"duration_ms={duration_ms}"
|
||||
)
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error=f"Tool timeout after {tool.timeout_ms}ms",
|
||||
error=f"Tool execution timeout after {tool.timeout_ms}ms",
|
||||
duration_ms=duration_ms,
|
||||
auth_applied=auth_applied,
|
||||
registry_version=tool.version,
|
||||
registry_version=self._version,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
duration_ms = int((time.time() - start_time) * 1000)
|
||||
logger.error(
|
||||
f"[AC-IDMP-19] Tool error: name={tool_name}, error={e}"
|
||||
)
|
||||
logger.error(f"[AC-IDMP-19] Tool execution error: {name} - {e}")
|
||||
return ToolExecutionResult(
|
||||
success=False,
|
||||
error=str(e),
|
||||
duration_ms=duration_ms,
|
||||
auth_applied=auth_applied,
|
||||
registry_version=tool.version,
|
||||
registry_version=self._version,
|
||||
)
|
||||
|
||||
def create_trace(
|
||||
def build_trace(
|
||||
self,
|
||||
tool_name: str,
|
||||
args: dict[str, Any],
|
||||
result: ToolExecutionResult,
|
||||
args_digest: str | None = None,
|
||||
) -> ToolCallTrace:
|
||||
"""
|
||||
[AC-IDMP-19] Create ToolCallTrace from execution result.
|
||||
"""
|
||||
tool = self._tools.get(tool_name)
|
||||
"""Build a tool call trace from execution result."""
|
||||
import hashlib
|
||||
args_digest = hashlib.md5(str(args).encode()).hexdigest()[:8]
|
||||
|
||||
return ToolCallTrace(
|
||||
tool_name=tool_name,
|
||||
tool_type=tool.tool_type if tool else ToolType.INTERNAL,
|
||||
tool_type=tool.tool_type if (tool := self._tools.get(tool_name)) else ToolType.INTERNAL,
|
||||
registry_version=result.registry_version,
|
||||
auth_applied=result.auth_applied,
|
||||
duration_ms=result.duration_ms,
|
||||
|
|
@ -293,6 +245,8 @@ class ToolRegistry:
|
|||
error_code=result.error if not result.success else None,
|
||||
args_digest=args_digest,
|
||||
result_digest=str(result.output)[:100] if result.output else None,
|
||||
arguments=args,
|
||||
result=result.output,
|
||||
)
|
||||
|
||||
def get_governance_report(self) -> dict[str, Any]:
|
||||
|
|
|
|||
Loading…
Reference in New Issue