# ai-robot-core/ai-service/app/services/mid/tool_registry.py
#
# 338 lines
# 11 KiB
# Python

"""
Tool Registry for Mid Platform.
[AC-IDMP-19] Unified tool registration, auth, timeout, version, and enable/disable governance.
"""
import asyncio
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any, Callable, Coroutine
from app.models.mid.schemas import (
ToolCallStatus,
ToolCallTrace,
ToolType,
)
from app.services.mid.timeout_governor import TimeoutGovernor
# Module-level logger for registry events (registration, enable/disable, execution).
logger = logging.getLogger(__name__)
@dataclass
class ToolDefinition:
    """
    Registry entry describing a single tool.

    ``ToolRegistry.register`` builds instances of this class, and
    ``enable_tool`` / ``disable_tool`` flip ``enabled`` in place.
    """

    # Unique identifier; the registry keys its tool map on this value.
    name: str
    # Human-readable description of what the tool does.
    description: str
    # Tool kind (internal or MCP).
    tool_type: ToolType = ToolType.INTERNAL
    # Tool version string; execution results report it as ``registry_version``.
    version: str = "1.0.0"
    # Disabled tools are rejected by ``ToolRegistry.execute``.
    enabled: bool = True
    # When True, execution is refused unless a non-empty auth context is given.
    auth_required: bool = False
    # Per-call timeout in milliseconds (``ToolRegistry.register`` caps it at 2000).
    timeout_ms: int = 2000
    # Async callable awaited with the tool arguments passed as keyword arguments.
    handler: Callable[..., Coroutine[Any, Any, dict[str, Any]]] | None = None
    # Free-form extra data; ``default_factory`` gives every instance its own dict.
    metadata: dict[str, Any] = field(default_factory=dict)
@dataclass
class ToolExecutionResult:
"""Tool execution result."""
success: bool
output: Any = None
error: str | None = None
duration_ms: int = 0
auth_applied: bool = False
registry_version: str | None = None
class ToolRegistry:
    """
    [AC-IDMP-19] Unified tool registry for governance.

    Features:
    - Tool registration with metadata
    - Auth policy enforcement (presence check of an auth context)
    - Timeout governance (per-tool timeout, capped at ``MAX_TIMEOUT_MS``)
    - Version management
    - Enable/disable control
    """

    # Hard cap applied to every tool's timeout at registration time.
    MAX_TIMEOUT_MS: int = 2000

    # Prefix of the error message ``execute`` produces when a tool times out.
    # ``create_trace`` matches this exact prefix; searching for "timeout"
    # anywhere in the message misclassified errors such as
    # "Tool not found: timeout_checker" as TIMEOUT.
    _TIMEOUT_ERROR_PREFIX = "Tool timeout after "

    def __init__(
        self,
        timeout_governor: TimeoutGovernor | None = None,
    ):
        """
        Create an empty registry.

        Args:
            timeout_governor: Optional governor; a default ``TimeoutGovernor()``
                is created when omitted.
        """
        self._tools: dict[str, ToolDefinition] = {}
        # NOTE(review): stored but not consulted anywhere in this class; per-call
        # timeouts come from ToolDefinition.timeout_ms. Confirm intended use.
        self._timeout_governor = timeout_governor or TimeoutGovernor()
        self._version = "1.0.0"

    @property
    def version(self) -> str:
        """Get registry version (set to "1.0.0" by the constructor)."""
        return self._version

    @staticmethod
    def _elapsed_ms(start: float) -> int:
        """Return whole milliseconds elapsed since a ``time.perf_counter()`` reading."""
        # perf_counter is monotonic, so durations are not distorted by wall-clock
        # adjustments the way time.time() deltas are.
        return int((time.perf_counter() - start) * 1000)

    def register(
        self,
        name: str,
        description: str,
        handler: Callable[..., Coroutine[Any, Any, dict[str, Any]]],
        tool_type: ToolType = ToolType.INTERNAL,
        version: str = "1.0.0",
        auth_required: bool = False,
        timeout_ms: int = 2000,
        enabled: bool = True,
        metadata: dict[str, Any] | None = None,
    ) -> ToolDefinition:
        """
        [AC-IDMP-19] Register a tool.

        Re-registering an existing name overwrites the previous definition
        (a warning is logged).

        Args:
            name: Tool name (unique identifier)
            description: Tool description
            handler: Async handler function
            tool_type: Tool type (internal/mcp)
            version: Tool version
            auth_required: Whether auth is required
            timeout_ms: Tool-specific timeout; values above ``MAX_TIMEOUT_MS``
                are capped (and a warning is logged)
            enabled: Whether tool is enabled
            metadata: Additional metadata

        Returns:
            ToolDefinition for the registered tool
        """
        if name in self._tools:
            logger.warning(f"[AC-IDMP-19] Tool already registered, overwriting: {name}")

        applied_timeout_ms = min(timeout_ms, self.MAX_TIMEOUT_MS)
        if applied_timeout_ms != timeout_ms:
            # The cap used to be applied silently; surface it so operators can
            # see why a tool times out earlier than configured.
            logger.warning(
                f"[AC-IDMP-19] Tool timeout capped: name={name}, "
                f"requested_ms={timeout_ms}, applied_ms={applied_timeout_ms}"
            )

        tool = ToolDefinition(
            name=name,
            description=description,
            tool_type=tool_type,
            version=version,
            enabled=enabled,
            auth_required=auth_required,
            timeout_ms=applied_timeout_ms,
            handler=handler,
            metadata=metadata or {},
        )
        self._tools[name] = tool
        logger.info(
            f"[AC-IDMP-19] Tool registered: name={name}, type={tool_type.value}, "
            f"version={version}, auth_required={auth_required}"
        )
        return tool

    def unregister(self, name: str) -> bool:
        """Unregister a tool. Returns True if it existed, False otherwise."""
        if name not in self._tools:
            return False
        del self._tools[name]
        logger.info(f"[AC-IDMP-19] Tool unregistered: {name}")
        return True

    def get_tool(self, name: str) -> ToolDefinition | None:
        """Get tool definition by name, or None if not registered."""
        return self._tools.get(name)

    def list_tools(
        self,
        tool_type: ToolType | None = None,
        enabled_only: bool = True,
    ) -> list[ToolDefinition]:
        """
        List registered tools, optionally filtered.

        Args:
            tool_type: Only include tools of this type when given.
            enabled_only: Exclude disabled tools (default True).

        Returns:
            Matching tool definitions in registration order.
        """
        return [
            t
            for t in self._tools.values()
            if (tool_type is None or t.tool_type == tool_type)
            and (not enabled_only or t.enabled)
        ]

    def enable_tool(self, name: str) -> bool:
        """Enable a tool. Returns False if no tool with that name exists."""
        tool = self._tools.get(name)
        if tool is None:
            return False
        tool.enabled = True
        logger.info(f"[AC-IDMP-19] Tool enabled: {name}")
        return True

    def disable_tool(self, name: str) -> bool:
        """Disable a tool. Returns False if no tool with that name exists."""
        tool = self._tools.get(name)
        if tool is None:
            return False
        tool.enabled = False
        logger.info(f"[AC-IDMP-19] Tool disabled: {name}")
        return True

    async def execute(
        self,
        tool_name: str,
        args: dict[str, Any],
        auth_context: dict[str, Any] | None = None,
    ) -> ToolExecutionResult:
        """
        [AC-IDMP-19] Execute a tool with governance.

        Checks, in order: the tool exists, it is enabled, and (when
        ``auth_required``) a non-empty ``auth_context`` was supplied. The
        handler is then awaited under the tool's timeout. Failures are
        reported in the returned result rather than raised.

        Args:
            tool_name: Tool name to execute
            args: Tool arguments, passed to the handler as keyword arguments
            auth_context: Authentication context; only its presence is checked
                here, its contents are not validated

        Returns:
            ToolExecutionResult with output and metadata. ``registry_version``
            carries the executed tool's version (None for unknown tools).
        """
        start_time = time.perf_counter()

        tool = self._tools.get(tool_name)
        if tool is None:
            logger.warning(f"[AC-IDMP-19] Tool not found: {tool_name}")
            return ToolExecutionResult(
                success=False,
                error=f"Tool not found: {tool_name}",
                duration_ms=0,
            )

        if not tool.enabled:
            logger.warning(f"[AC-IDMP-19] Tool disabled: {tool_name}")
            return ToolExecutionResult(
                success=False,
                error=f"Tool disabled: {tool_name}",
                duration_ms=0,
                registry_version=tool.version,
            )

        auth_applied = False
        if tool.auth_required:
            if not auth_context:
                logger.warning(f"[AC-IDMP-19] Auth required but no context: {tool_name}")
                return ToolExecutionResult(
                    success=False,
                    error="Authentication required",
                    duration_ms=self._elapsed_ms(start_time),
                    auth_applied=False,
                    registry_version=tool.version,
                )
            auth_applied = True

        try:
            # A tool without a handler is a no-op: asyncio.sleep(0) resolves to
            # None, so the call succeeds with output=None.
            result = await asyncio.wait_for(
                tool.handler(**args) if tool.handler else asyncio.sleep(0),
                timeout=tool.timeout_ms / 1000.0,
            )
        except asyncio.TimeoutError:
            duration_ms = self._elapsed_ms(start_time)
            logger.warning(
                f"[AC-IDMP-19] Tool timeout: name={tool_name}, "
                f"duration_ms={duration_ms}"
            )
            return ToolExecutionResult(
                success=False,
                error=f"{self._TIMEOUT_ERROR_PREFIX}{tool.timeout_ms}ms",
                duration_ms=duration_ms,
                auth_applied=auth_applied,
                registry_version=tool.version,
            )
        except Exception as e:
            # Deliberate boundary: any handler failure (including a TypeError
            # from unexpected kwargs) becomes an error result, not an exception.
            duration_ms = self._elapsed_ms(start_time)
            logger.error(
                f"[AC-IDMP-19] Tool error: name={tool_name}, error={e}"
            )
            return ToolExecutionResult(
                success=False,
                error=str(e),
                duration_ms=duration_ms,
                auth_applied=auth_applied,
                registry_version=tool.version,
            )

        duration_ms = self._elapsed_ms(start_time)
        logger.info(
            f"[AC-IDMP-19] Tool executed: name={tool_name}, "
            f"duration_ms={duration_ms}, success=True"
        )
        return ToolExecutionResult(
            success=True,
            output=result,
            duration_ms=duration_ms,
            auth_applied=auth_applied,
            registry_version=tool.version,
        )

    def create_trace(
        self,
        tool_name: str,
        result: ToolExecutionResult,
        args_digest: str | None = None,
    ) -> ToolCallTrace:
        """
        [AC-IDMP-19] Create ToolCallTrace from execution result.

        Status is OK on success, TIMEOUT when the error was produced by the
        registry's own timeout handling, and ERROR otherwise. The result
        digest is the first 100 characters of ``str(output)``.
        """
        tool = self._tools.get(tool_name)

        if result.success:
            status = ToolCallStatus.OK
        elif (result.error or "").startswith(self._TIMEOUT_ERROR_PREFIX):
            status = ToolCallStatus.TIMEOUT
        else:
            status = ToolCallStatus.ERROR

        return ToolCallTrace(
            tool_name=tool_name,
            tool_type=tool.tool_type if tool is not None else ToolType.INTERNAL,
            registry_version=result.registry_version,
            auth_applied=result.auth_applied,
            duration_ms=result.duration_ms,
            status=status,
            error_code=None if result.success else result.error,
            args_digest=args_digest,
            # `is not None` so falsy outputs (0, False, {}) still get a digest.
            result_digest=str(result.output)[:100] if result.output is not None else None,
        )

    def get_governance_report(self) -> dict[str, Any]:
        """
        Get governance report for all tools.

        Returns:
            Dict with the registry version, aggregate counts (total, enabled,
            disabled, auth-required, MCP, internal) and a per-tool summary list.
        """
        tools = list(self._tools.values())
        return {
            "registry_version": self._version,
            "total_tools": len(tools),
            "enabled_tools": sum(1 for t in tools if t.enabled),
            "disabled_tools": sum(1 for t in tools if not t.enabled),
            "auth_required_tools": sum(1 for t in tools if t.auth_required),
            "mcp_tools": sum(1 for t in tools if t.tool_type == ToolType.MCP),
            "internal_tools": sum(1 for t in tools if t.tool_type == ToolType.INTERNAL),
            "tools": [
                {
                    "name": t.name,
                    "type": t.tool_type.value,
                    "version": t.version,
                    "enabled": t.enabled,
                    "auth_required": t.auth_required,
                    "timeout_ms": t.timeout_ms,
                }
                for t in tools
            ],
        }
# Process-wide registry instance; populated lazily by get_tool_registry()
# or replaced explicitly by init_tool_registry().
_registry: ToolRegistry | None = None


def get_tool_registry() -> ToolRegistry:
    """Return the global ToolRegistry, creating a default one on first use."""
    global _registry
    if _registry is not None:
        return _registry
    _registry = ToolRegistry()
    return _registry
def init_tool_registry(timeout_governor: TimeoutGovernor | None = None) -> ToolRegistry:
    """
    Build a fresh ToolRegistry and install it as the global instance.

    Any previously installed registry (and its registered tools) is replaced.

    Args:
        timeout_governor: Optional governor forwarded to the new registry.

    Returns:
        The newly created registry.
    """
    global _registry
    fresh = ToolRegistry(timeout_governor=timeout_governor)
    _registry = fresh
    return fresh