ai-robot-core/ai-service/app/services/mid/tool_registry.py

292 lines
9.2 KiB
Python
Raw Normal View History

"""
Tool Registry for Mid Platform.
[AC-IDMP-19] Unified tool registration, auth, timeout, version, and enable/disable governance.
"""
import asyncio
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine
from app.models.mid.schemas import (
ToolCallStatus,
ToolCallTrace,
ToolType,
)
from app.services.mid.timeout_governor import TimeoutGovernor
logger = logging.getLogger(__name__)
@dataclass
class ToolDefinition:
    """
    Declarative description of a single tool held by the registry.

    Only ``name`` and ``description`` are mandatory; every other attribute
    carries a default so a tool can be declared tersely.
    """

    # Key under which the registry stores the tool (registering the same name again overwrites it).
    name: str
    # Free-text summary of what the tool does.
    description: str
    tool_type: ToolType = field(default=ToolType.INTERNAL)
    version: str = field(default="1.0.0")
    # When False, ToolRegistry.execute refuses to run the tool.
    enabled: bool = field(default=True)
    auth_required: bool = field(default=False)
    # Execution budget in milliseconds handed to the timeout governor.
    timeout_ms: int = field(default=2000)
    # Async callable invoked with the tool's keyword arguments; None means "not runnable".
    handler: Callable[..., Coroutine[Any, Any, dict[str, Any]]] | None = field(default=None)
    # Arbitrary extra attributes; a fresh dict per instance.
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class ToolExecutionResult:
"""Tool execution result."""
success: bool
output: Any = None
error: str | None = None
duration_ms: int = 0
auth_applied: bool = False
registry_version: str | None = None
class ToolRegistry:
    """
    [AC-IDMP-19] Unified tool registry for governance.

    Features:
    - Tool registration with metadata
    - Auth policy flag per tool (``auth_required``), echoed as ``auth_applied``
      on successful executions. NOTE(review): no auth check is performed in
      this class; confirm enforcement happens in the handler or the caller.
    - Timeout governance (per-tool timeout, capped at ``MAX_TOOL_TIMEOUT_MS``)
    - Version management
    - Enable/disable control
    """

    # Upper bound applied to every per-tool timeout at registration time.
    MAX_TOOL_TIMEOUT_MS: int = 2000

    # Prefixes of the failure messages ``execute`` emits before a handler runs.
    # They embed the tool name, so ``build_trace`` must not scan them for the
    # word "timeout" (a tool named e.g. "timeout_probe" would be misreported).
    _PRE_EXECUTION_ERROR_PREFIXES: tuple[str, ...] = (
        "Tool not found: ",
        "Tool is disabled: ",
        "Tool has no handler: ",
    )

    def __init__(
        self,
        timeout_governor: TimeoutGovernor | None = None,
    ):
        """
        Args:
            timeout_governor: Governor used to bound handler execution time;
                a default ``TimeoutGovernor()`` is created when omitted.
        """
        self._tools: dict[str, ToolDefinition] = {}
        self._timeout_governor = timeout_governor or TimeoutGovernor()
        self._version = "1.0.0"

    @property
    def version(self) -> str:
        """Get registry version."""
        return self._version

    def register(
        self,
        name: str,
        description: str,
        handler: Callable[..., Coroutine[Any, Any, dict[str, Any]]],
        tool_type: ToolType = ToolType.INTERNAL,
        version: str = "1.0.0",
        auth_required: bool = False,
        timeout_ms: int = 2000,
        enabled: bool = True,
        metadata: dict[str, Any] | None = None,
    ) -> ToolDefinition:
        """
        [AC-IDMP-19] Register a tool, replacing any tool with the same name.

        Args:
            name: Tool name (unique identifier)
            description: Tool description
            handler: Async handler function
            tool_type: Tool type (internal/mcp)
            version: Tool version
            auth_required: Whether auth is required
            timeout_ms: Requested tool-specific timeout; values above
                ``MAX_TOOL_TIMEOUT_MS`` are clamped to that cap
            enabled: Whether tool is enabled
            metadata: Additional metadata (a fresh dict is used when None)

        Returns:
            ToolDefinition for the registered tool
        """
        if name in self._tools:
            logger.warning("[AC-IDMP-19] Tool already registered, overwriting: %s", name)

        effective_timeout_ms = min(timeout_ms, self.MAX_TOOL_TIMEOUT_MS)
        if effective_timeout_ms != timeout_ms:
            logger.warning(
                "[AC-IDMP-19] Tool %s requested timeout %sms; clamped to %sms",
                name, timeout_ms, effective_timeout_ms,
            )

        tool = ToolDefinition(
            name=name,
            description=description,
            tool_type=tool_type,
            version=version,
            enabled=enabled,
            auth_required=auth_required,
            timeout_ms=effective_timeout_ms,
            handler=handler,
            metadata=metadata or {},
        )
        self._tools[name] = tool
        # Log the timeout actually stored, not the requested one.
        logger.info(
            "[AC-IDMP-19] Registered tool: %s v%s (type=%s, auth=%s, timeout=%sms)",
            name, version, tool_type.value, auth_required, tool.timeout_ms,
        )
        return tool

    def get_tool(self, name: str) -> ToolDefinition | None:
        """Get tool by name."""
        return self._tools.get(name)

    def list_tools(self) -> list[str]:
        """List all registered tool names."""
        return list(self._tools.keys())

    def get_all_tools(self) -> list[ToolDefinition]:
        """Get all registered tools."""
        return list(self._tools.values())

    def is_enabled(self, name: str) -> bool:
        """Check if tool is enabled (unknown tools count as disabled)."""
        tool = self._tools.get(name)
        return tool.enabled if tool else False

    def set_enabled(self, name: str, enabled: bool) -> bool:
        """Enable or disable a tool; return False when the tool is unknown."""
        tool = self._tools.get(name)
        if tool:
            tool.enabled = enabled
            logger.info(f"[AC-IDMP-19] Tool {name} {'enabled' if enabled else 'disabled'}")
            return True
        return False

    async def execute(
        self,
        name: str,
        **kwargs: Any,
    ) -> ToolExecutionResult:
        """
        Execute a tool with governance.

        Args:
            name: Tool name
            **kwargs: Tool arguments, forwarded to the handler

        Returns:
            ToolExecutionResult with output and metadata. Never raises for
            ordinary handler failures: missing/disabled/handler-less tools,
            timeouts and handler exceptions are all reported as
            ``success=False`` with an ``error`` message.
        """
        tool = self._tools.get(name)
        if not tool:
            return self._failure(f"Tool not found: {name}")
        if not tool.enabled:
            return self._failure(f"Tool is disabled: {name}")
        handler = tool.handler
        if not handler:
            return self._failure(f"Tool has no handler: {name}")

        # perf_counter is monotonic, so durations are immune to wall-clock jumps.
        start = time.perf_counter()
        try:
            result = await self._timeout_governor.execute_with_timeout(
                lambda: handler(**kwargs),
                timeout_ms=tool.timeout_ms,
            )
        # Before Python 3.11 the builtin TimeoutError is distinct from
        # asyncio.TimeoutError; treat both as a timeout.
        except (asyncio.TimeoutError, TimeoutError):
            return self._failure(
                f"Tool execution timeout after {tool.timeout_ms}ms",
                duration_ms=self._elapsed_ms(start),
            )
        except Exception as e:
            # logger.exception keeps the traceback for diagnosis.
            logger.exception("[AC-IDMP-19] Tool execution error: %s - %s", name, e)
            return self._failure(str(e), duration_ms=self._elapsed_ms(start))

        return ToolExecutionResult(
            success=True,
            output=result,
            duration_ms=self._elapsed_ms(start),
            auth_applied=tool.auth_required,
            registry_version=self._version,
        )

    def _failure(self, error: str, duration_ms: int = 0) -> ToolExecutionResult:
        """Build a failed ToolExecutionResult stamped with this registry's version."""
        return ToolExecutionResult(
            success=False,
            error=error,
            duration_ms=duration_ms,
            registry_version=self._version,
        )

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Whole milliseconds elapsed since ``start`` (a ``time.perf_counter()`` reading)."""
        return int((time.perf_counter() - start) * 1000)

    def _trace_status(self, result: ToolExecutionResult) -> ToolCallStatus:
        """Map an execution result to a trace status (OK / TIMEOUT / ERROR)."""
        if result.success:
            return ToolCallStatus.OK
        error = result.error or ""
        # Pre-execution failures can never be timeouts, even if the tool name
        # they embed contains "timeout".
        if error.startswith(self._PRE_EXECUTION_ERROR_PREFIXES):
            return ToolCallStatus.ERROR
        return ToolCallStatus.TIMEOUT if "timeout" in error.lower() else ToolCallStatus.ERROR

    def build_trace(
        self,
        tool_name: str,
        args: dict[str, Any],
        result: ToolExecutionResult,
    ) -> ToolCallTrace:
        """
        Build a tool call trace from execution result.

        Args:
            tool_name: Name the tool was invoked under; unknown names are
                traced as ``ToolType.INTERNAL``.
            args: Arguments passed to the tool.
            result: Result returned by ``execute``.

        Returns:
            ToolCallTrace carrying status, timing, an 8-hex-char args digest
            and the first 100 characters of the stringified output.
        """
        import hashlib

        # Short correlation fingerprint, not a security primitive; flagging it
        # keeps md5 usable on FIPS-restricted builds. Digest value is unchanged.
        args_digest = hashlib.md5(str(args).encode(), usedforsecurity=False).hexdigest()[:8]
        tool = self._tools.get(tool_name)
        return ToolCallTrace(
            tool_name=tool_name,
            tool_type=tool.tool_type if tool else ToolType.INTERNAL,
            registry_version=result.registry_version,
            auth_applied=result.auth_applied,
            duration_ms=result.duration_ms,
            status=self._trace_status(result),
            error_code=result.error if not result.success else None,
            args_digest=args_digest,
            result_digest=str(result.output)[:100] if result.output else None,
            arguments=args,
            result=result.output,
        )

    def get_governance_report(self) -> dict[str, Any]:
        """Get governance report (aggregate counts plus per-tool summary) for all tools."""
        tools = list(self._tools.values())
        return {
            "registry_version": self._version,
            "total_tools": len(tools),
            "enabled_tools": sum(1 for t in tools if t.enabled),
            "disabled_tools": sum(1 for t in tools if not t.enabled),
            "auth_required_tools": sum(1 for t in tools if t.auth_required),
            "mcp_tools": sum(1 for t in tools if t.tool_type == ToolType.MCP),
            "internal_tools": sum(1 for t in tools if t.tool_type == ToolType.INTERNAL),
            "tools": [
                {
                    "name": t.name,
                    "type": t.tool_type.value,
                    "version": t.version,
                    "enabled": t.enabled,
                    "auth_required": t.auth_required,
                    "timeout_ms": t.timeout_ms,
                }
                for t in tools
            ],
        }
# Process-wide registry instance, created lazily by get_tool_registry()
# or replaced wholesale by init_tool_registry().
_registry: ToolRegistry | None = None


def get_tool_registry() -> ToolRegistry:
    """Return the shared ToolRegistry, constructing a default one on first call."""
    global _registry
    if _registry is not None:
        return _registry
    _registry = ToolRegistry()
    return _registry
def init_tool_registry(timeout_governor: TimeoutGovernor | None = None) -> ToolRegistry:
    """
    Create a fresh ToolRegistry, install it as the shared instance, and return it.

    Any previously installed registry (and its registered tools) is discarded.
    """
    global _registry
    fresh = ToolRegistry(timeout_governor=timeout_governor)
    _registry = fresh
    return fresh