292 lines
9.2 KiB
Python
292 lines
9.2 KiB
Python
"""
Tool Registry for Mid Platform.

[AC-IDMP-19] Unified tool registration, auth, timeout, version, and enable/disable governance.
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import time
|
|
import uuid
|
|
from dataclasses import dataclass, field
|
|
from typing import Any, Callable, Coroutine
|
|
|
|
from app.models.mid.schemas import (
|
|
ToolCallStatus,
|
|
ToolCallTrace,
|
|
ToolType,
|
|
)
|
|
from app.services.mid.timeout_governor import TimeoutGovernor
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
@dataclass
class ToolDefinition:
    """
    Registry entry describing one tool and its governance settings.

    Plain mutable dataclass: the registry stores these by name and may flip
    ``enabled`` on an existing instance.
    """

    # Unique identifier; used as the registry lookup key.
    name: str
    # Human-readable summary of what the tool does.
    description: str
    # Kind of tool (the registry distinguishes INTERNAL and MCP).
    tool_type: ToolType = ToolType.INTERNAL
    # Tool version string, surfaced in governance reports.
    version: str = "1.0.0"
    # When False the registry refuses to execute the tool.
    enabled: bool = True
    # Whether the tool is flagged as requiring auth.
    auth_required: bool = False
    # Per-call execution timeout, in milliseconds.
    timeout_ms: int = 2000
    # Async callable invoked with the tool's keyword arguments; None = not executable.
    handler: Callable[..., Coroutine[Any, Any, dict[str, Any]]] | None = None
    # Arbitrary extra attributes attached at registration time.
    metadata: dict[str, Any] = field(default_factory=dict)
|
|
|
|
|
|
@dataclass
|
|
class ToolExecutionResult:
|
|
"""Tool execution result."""
|
|
success: bool
|
|
output: Any = None
|
|
error: str | None = None
|
|
duration_ms: int = 0
|
|
auth_applied: bool = False
|
|
registry_version: str | None = None
|
|
|
|
|
|
class ToolRegistry:
    """
    [AC-IDMP-19] Unified tool registry for governance.

    Features:
    - Tool registration with metadata
    - Auth flag tracking: ``auth_required`` is recorded per tool and reported
      as ``auth_applied`` on successful executions. NOTE(review): ``execute``
      performs no auth check itself; confirm enforcement happens in the
      handler or the caller.
    - Timeout governance (per-tool timeouts are capped at ``MAX_TIMEOUT_MS``
      and enforced through a ``TimeoutGovernor``)
    - Version management
    - Enable/disable control
    """

    # Upper bound applied to every tool's timeout at registration time.
    # Subclasses may override it to change the governance ceiling.
    MAX_TIMEOUT_MS: int = 2000

    # Prefixes of registry-generated failure messages that embed the tool
    # name. A tool *named* e.g. "timeout_probe" must not make these look like
    # timeouts when build_trace() classifies the failure.
    _LOOKUP_ERROR_PREFIXES: tuple[str, ...] = (
        "Tool not found: ",
        "Tool is disabled: ",
        "Tool has no handler: ",
    )

    def __init__(
        self,
        timeout_governor: TimeoutGovernor | None = None,
    ):
        """
        Args:
            timeout_governor: Governor used to bound handler execution time;
                a default ``TimeoutGovernor()`` is created when omitted.
        """
        self._tools: dict[str, ToolDefinition] = {}
        self._timeout_governor = timeout_governor or TimeoutGovernor()
        self._version = "1.0.0"

    @property
    def version(self) -> str:
        """Get registry version."""
        return self._version

    def register(
        self,
        name: str,
        description: str,
        handler: Callable[..., Coroutine[Any, Any, dict[str, Any]]],
        tool_type: ToolType = ToolType.INTERNAL,
        version: str = "1.0.0",
        auth_required: bool = False,
        timeout_ms: int = 2000,
        enabled: bool = True,
        metadata: dict[str, Any] | None = None,
    ) -> ToolDefinition:
        """
        [AC-IDMP-19] Register a tool.

        Re-registering an existing name overwrites the previous definition
        (a warning is logged).

        Args:
            name: Tool name (unique identifier)
            description: Tool description
            handler: Async handler function
            tool_type: Tool type (internal/mcp)
            version: Tool version
            auth_required: Whether auth is required
            timeout_ms: Requested tool-specific timeout; values above
                ``MAX_TIMEOUT_MS`` are capped to it
            enabled: Whether tool is enabled
            metadata: Additional metadata (a fresh dict is used when None)

        Returns:
            ToolDefinition for the registered tool
        """
        if name in self._tools:
            logger.warning(f"[AC-IDMP-19] Tool already registered, overwriting: {name}")

        effective_timeout_ms = min(timeout_ms, self.MAX_TIMEOUT_MS)
        if effective_timeout_ms != timeout_ms:
            logger.warning(
                f"[AC-IDMP-19] Tool {name} requested timeout {timeout_ms}ms, "
                f"capped to {effective_timeout_ms}ms"
            )

        tool = ToolDefinition(
            name=name,
            description=description,
            tool_type=tool_type,
            version=version,
            enabled=enabled,
            auth_required=auth_required,
            timeout_ms=effective_timeout_ms,
            handler=handler,
            metadata=metadata or {},
        )

        self._tools[name] = tool

        # Log the timeout actually stored, not the requested one.
        logger.info(
            f"[AC-IDMP-19] Registered tool: {name} v{version} "
            f"(type={tool_type.value}, auth={auth_required}, timeout={effective_timeout_ms}ms)"
        )

        return tool

    def get_tool(self, name: str) -> ToolDefinition | None:
        """Get tool by name, or None if it is not registered."""
        return self._tools.get(name)

    def list_tools(self) -> list[str]:
        """List all registered tool names."""
        return list(self._tools.keys())

    def get_all_tools(self) -> list[ToolDefinition]:
        """Get all registered tools."""
        return list(self._tools.values())

    def is_enabled(self, name: str) -> bool:
        """Check if tool is enabled; unknown tools count as disabled."""
        tool = self._tools.get(name)
        return tool.enabled if tool else False

    def set_enabled(self, name: str, enabled: bool) -> bool:
        """Enable or disable a tool. Returns False if the tool is unknown."""
        tool = self._tools.get(name)
        if tool:
            tool.enabled = enabled
            logger.info(f"[AC-IDMP-19] Tool {name} {'enabled' if enabled else 'disabled'}")
            return True
        return False

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Whole milliseconds elapsed since a ``time.perf_counter()`` reading."""
        return int((time.perf_counter() - start) * 1000)

    async def execute(
        self,
        name: str,
        **kwargs: Any,
    ) -> ToolExecutionResult:
        """
        Execute a tool with governance.

        Unknown, disabled, or handler-less tools fail immediately. Otherwise
        the handler runs under the timeout governor. Timeouts and handler
        exceptions are converted into failed results rather than propagated.

        Args:
            name: Tool name
            **kwargs: Tool arguments, forwarded to the handler

        Returns:
            ToolExecutionResult with output and metadata
        """
        tool = self._tools.get(name)
        if not tool:
            return ToolExecutionResult(
                success=False,
                error=f"Tool not found: {name}",
                registry_version=self._version,
            )

        if not tool.enabled:
            return ToolExecutionResult(
                success=False,
                error=f"Tool is disabled: {name}",
                registry_version=self._version,
            )

        if not tool.handler:
            return ToolExecutionResult(
                success=False,
                error=f"Tool has no handler: {name}",
                registry_version=self._version,
            )

        handler = tool.handler
        # perf_counter is monotonic, so durations are immune to wall-clock
        # adjustments (NTP sync, manual clock changes) unlike time.time().
        start = time.perf_counter()

        try:
            result = await self._timeout_governor.execute_with_timeout(
                lambda: handler(**kwargs),
                timeout_ms=tool.timeout_ms,
            )
        except asyncio.TimeoutError:
            return ToolExecutionResult(
                success=False,
                error=f"Tool execution timeout after {tool.timeout_ms}ms",
                duration_ms=self._elapsed_ms(start),
                registry_version=self._version,
            )
        except Exception as e:
            duration_ms = self._elapsed_ms(start)
            # logger.exception keeps the traceback, which logger.error dropped.
            logger.exception(f"[AC-IDMP-19] Tool execution error: {name} - {e}")
            return ToolExecutionResult(
                success=False,
                error=str(e),
                duration_ms=duration_ms,
                registry_version=self._version,
            )

        return ToolExecutionResult(
            success=True,
            output=result,
            duration_ms=self._elapsed_ms(start),
            auth_applied=tool.auth_required,
            registry_version=self._version,
        )

    @classmethod
    def _is_timeout_error(cls, error: str | None) -> bool:
        """
        Heuristically decide whether a failure message describes a timeout.

        Registry lookup failures are never timeouts, even when the embedded
        tool name contains "timeout". Any other message mentioning "timeout"
        (case-insensitive) is treated as one.
        """
        if not error:
            return False
        if error.startswith(cls._LOOKUP_ERROR_PREFIXES):
            return False
        return "timeout" in error.lower()

    def build_trace(
        self,
        tool_name: str,
        args: dict[str, Any],
        result: ToolExecutionResult,
    ) -> ToolCallTrace:
        """
        Build a tool call trace from an execution result.

        Args:
            tool_name: Name of the invoked tool; unknown names are traced as
                INTERNAL.
            args: Arguments the tool was called with.
            result: Result returned by ``execute``.

        Returns:
            ToolCallTrace carrying status, timing, and short digests of the
            arguments and output.
        """
        import hashlib

        # MD5 is only a short, non-cryptographic fingerprint here. Flagging it
        # keeps it usable on FIPS-restricted builds; the digest is unchanged.
        args_digest = hashlib.md5(
            str(args).encode(), usedforsecurity=False
        ).hexdigest()[:8]

        tool = self._tools.get(tool_name)

        if result.success:
            status = ToolCallStatus.OK
        elif self._is_timeout_error(result.error):
            status = ToolCallStatus.TIMEOUT
        else:
            status = ToolCallStatus.ERROR

        return ToolCallTrace(
            tool_name=tool_name,
            tool_type=tool.tool_type if tool else ToolType.INTERNAL,
            registry_version=result.registry_version,
            auth_applied=result.auth_applied,
            duration_ms=result.duration_ms,
            status=status,
            error_code=result.error if not result.success else None,
            args_digest=args_digest,
            # `is not None` so valid-but-falsy outputs ({}, 0, "") still get a digest.
            result_digest=str(result.output)[:100] if result.output is not None else None,
            arguments=args,
            result=result.output,
        )

    def get_governance_report(self) -> dict[str, Any]:
        """Summarize counts by state/type plus per-tool governance settings."""
        tools = list(self._tools.values())
        return {
            "registry_version": self._version,
            "total_tools": len(tools),
            "enabled_tools": sum(1 for t in tools if t.enabled),
            "disabled_tools": sum(1 for t in tools if not t.enabled),
            "auth_required_tools": sum(1 for t in tools if t.auth_required),
            "mcp_tools": sum(1 for t in tools if t.tool_type == ToolType.MCP),
            "internal_tools": sum(1 for t in tools if t.tool_type == ToolType.INTERNAL),
            "tools": [
                {
                    "name": t.name,
                    "type": t.tool_type.value,
                    "version": t.version,
                    "enabled": t.enabled,
                    "auth_required": t.auth_required,
                    "timeout_ms": t.timeout_ms,
                }
                for t in tools
            ],
        }
|
|
|
|
|
|
# Module-level singleton: created lazily by get_tool_registry() and
# replaced unconditionally by init_tool_registry().
_registry: ToolRegistry | None = None
|
|
|
|
|
|
def get_tool_registry() -> ToolRegistry:
    """
    Return the module-wide ToolRegistry.

    The first call builds a registry with the default timeout governor.
    Later calls return that same instance, unless init_tool_registry()
    has replaced it.
    """
    global _registry
    if _registry is not None:
        return _registry
    _registry = ToolRegistry()
    return _registry
|
|
|
|
|
|
def init_tool_registry(timeout_governor: TimeoutGovernor | None = None) -> ToolRegistry:
    """
    Build a fresh ToolRegistry and install it as the module-wide instance.

    Any previously installed registry, and every tool registered on it, is
    discarded.

    Args:
        timeout_governor: Optional governor passed through to the registry.

    Returns:
        The newly installed registry.
    """
    global _registry
    fresh = ToolRegistry(timeout_governor=timeout_governor)
    _registry = fresh
    return fresh
|