ai-robot-core/ai-service/app/services/api_key.py

377 lines
11 KiB
Python

"""
API Key management service.
[AC-AISVC-50] Lightweight authentication with in-memory cache.
"""
from __future__ import annotations

import logging
import secrets
import uuid
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import Optional

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from app.models.entities import ApiKey, ApiKeyCreate
# Module-level logger named after this module (standard logging.getLogger(__name__) convention).
logger = logging.getLogger(__name__)
@dataclass
class CachedApiKeyMeta:
    """Per-key policy snapshot held in memory so request-time checks skip the DB.

    Values are copied from the persisted ApiKey record when the cache is built.
    """

    # Whether the key may currently be used at all.
    is_active: bool
    # Expiry timestamp; ``None`` means the key never expires.
    expires_at: datetime | None
    # Client IPs permitted to use the key; an empty set applies no IP restriction.
    allowed_ips: set[str] = field(default_factory=set)
    # Maximum number of requests accepted per minute for this key.
    rate_limit_qpm: int = field(default=60)
@dataclass
class ValidationResult:
    """Outcome of the combined auth + policy check used by the middleware."""

    # True when the request passed every check.
    ok: bool
    # Machine-readable failure code (e.g. "invalid_key"); None on success.
    reason: str | None = field(default=None)
    # Per-minute request limit that applied to the key.
    rate_limit_qpm: int = field(default=60)
class ApiKeyService:
    """
    [AC-AISVC-50] API Key management service.

    Features:
    - In-memory cache for fast validation
    - Database persistence
    - Hot-reload support

    Request-time checks (``validate_key`` / ``validate_key_with_context``) read
    only the in-memory cache. The async mutators write to the database first,
    then mirror the result into the cache.
    """

    # Used when a key has no (or a zero) per-minute limit configured.
    DEFAULT_RATE_LIMIT_QPM = 60
    # Length of the sliding window used for per-key rate limiting.
    RATE_WINDOW = timedelta(seconds=60)

    def __init__(self):
        self._keys_cache: set[str] = set()
        self._key_meta: dict[str, CachedApiKeyMeta] = {}
        # Naive-UTC timestamps of accepted requests per key, oldest first.
        self._rate_buckets: dict[str, deque[datetime]] = {}
        self._initialized: bool = False

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _utcnow() -> datetime:
        """Return the current UTC time as a naive datetime (same value as ``datetime.utcnow()``)."""
        return datetime.now(timezone.utc).replace(tzinfo=None)

    @staticmethod
    def _is_expired(expires_at: datetime | None, now: datetime) -> bool:
        """Return True when ``expires_at`` is set and ``now`` is strictly after it.

        ``now`` is naive UTC. A timezone-aware ``expires_at`` is converted to naive
        UTC first, because comparing naive and aware datetimes raises TypeError.
        Naive ``expires_at`` values are treated as UTC.
        """
        if expires_at is None:
            return False
        if expires_at.tzinfo is not None:
            expires_at = expires_at.astimezone(timezone.utc).replace(tzinfo=None)
        return now > expires_at

    @staticmethod
    def _ip_allowed(allowed_ips: set[str], client_ip: str | None) -> bool:
        """Return True if ``client_ip`` passes the key's IP allowlist.

        An empty allowlist allows every client. When an allowlist is configured,
        an unknown client IP (``None`` or empty) is rejected, so a missing address
        cannot bypass the restriction.
        """
        if not allowed_ips:
            return True
        return bool(client_ip) and client_ip in allowed_ips

    @staticmethod
    def _parse_key_id(key_id: object) -> uuid.UUID | None:
        """Return ``key_id`` as a UUID, or None if it is not a valid UUID string/instance."""
        if isinstance(key_id, uuid.UUID):
            return key_id
        if not isinstance(key_id, str):
            return None
        try:
            return uuid.UUID(key_id)
        except ValueError:
            return None

    @classmethod
    def _meta_from_entity(cls, api_key: ApiKey) -> CachedApiKeyMeta:
        """Build the cached policy snapshot for a persisted key.

        ``getattr`` defaults tolerate entities that lack the newer policy attributes.
        """
        return CachedApiKeyMeta(
            is_active=bool(api_key.is_active),
            expires_at=getattr(api_key, "expires_at", None),
            allowed_ips=set(getattr(api_key, "allowed_ips", None) or []),
            rate_limit_qpm=getattr(api_key, "rate_limit_qpm", None) or cls.DEFAULT_RATE_LIMIT_QPM,
        )

    def _evict(self, key_value: str) -> None:
        """Remove a key and all of its cached state (policy snapshot and rate window)."""
        self._keys_cache.discard(key_value)
        self._key_meta.pop(key_value, None)
        self._rate_buckets.pop(key_value, None)

    def _sync_cache_entry(self, api_key: ApiKey) -> None:
        """Mirror a persisted key into the cache: add it if active, evict it otherwise."""
        if api_key.is_active:
            self._keys_cache.add(api_key.key)
            self._key_meta[api_key.key] = self._meta_from_entity(api_key)
        else:
            self._evict(api_key.key)

    # ------------------------------------------------------------------
    # Cache loading
    # ------------------------------------------------------------------

    async def initialize(self, session: AsyncSession) -> None:
        """
        Load all active API keys from database into memory.
        Should be called on application startup.

        The full schema is tried first. If that query fails (for example, the
        newer policy columns are missing), the session is rolled back and only
        the legacy ``key``/``is_active`` columns are loaded. If both attempts
        fail, the service is marked uninitialized and every validation is
        rejected. The cache attributes are replaced only after a load succeeds.
        """
        try:
            result = await session.execute(
                select(ApiKey).where(ApiKey.is_active == True)  # noqa: E712 - SQL expression
            )
            keys = result.scalars().all()
            keys_cache = {key.key for key in keys}
            key_meta = {key.key: self._meta_from_entity(key) for key in keys}
        except Exception as e:
            logger.warning(
                "[AC-AISVC-50] Full API key schema load failed, fallback to legacy columns: %s", e
            )
        else:
            self._keys_cache = keys_cache
            self._key_meta = key_meta
            self._initialized = True
            logger.info("[AC-AISVC-50] Loaded %d API keys into memory", len(keys_cache))
            return

        # Backward-compat fallback for environments without new columns.
        try:
            # The failed statement can leave the transaction aborted (e.g. PostgreSQL
            # refuses further statements until rollback), so reset it before retrying.
            await session.rollback()
            result = await session.execute(
                select(ApiKey.key, ApiKey.is_active).where(ApiKey.is_active == True)  # noqa: E712
            )
            rows = result.all()
        except Exception as fallback_error:
            self._initialized = False
            logger.error(
                "[AC-AISVC-50] API key initialization failed in both full/legacy mode: %s",
                fallback_error,
            )
            return

        self._keys_cache = {row[0] for row in rows}
        self._key_meta = {
            row[0]: CachedApiKeyMeta(
                is_active=bool(row[1]),
                expires_at=None,
                allowed_ips=set(),
                rate_limit_qpm=self.DEFAULT_RATE_LIMIT_QPM,
            )
            for row in rows
        }
        self._initialized = True
        logger.info(
            "[AC-AISVC-50] Loaded %d API keys in legacy compatibility mode", len(self._keys_cache)
        )

    # ------------------------------------------------------------------
    # Validation
    # ------------------------------------------------------------------

    def validate_key(self, key: str) -> bool:
        """
        Validate an API key against the in-memory cache.

        Args:
            key: The API key to validate

        Returns:
            True if the key is valid, False otherwise
        """
        if not self._initialized:
            logger.warning("[AC-AISVC-50] API key service not initialized")
            return False
        return key in self._keys_cache

    def validate_key_with_context(self, key: str, client_ip: str | None) -> ValidationResult:
        """
        Validate key and policy constraints: expiration, IP allowlist, and per-minute rate.

        Checks run in order and the first failure wins. The ``reason`` codes are
        ``service_not_initialized``, ``invalid_key``, ``inactive_key``,
        ``expired_key``, ``ip_not_allowed`` and ``rate_limited``. Only requests
        that pass every check consume a rate-limit slot.
        """
        if not self._initialized:
            return ValidationResult(ok=False, reason="service_not_initialized")
        if key not in self._keys_cache:
            return ValidationResult(ok=False, reason="invalid_key")

        meta = self._key_meta.get(key)
        if meta is None or not meta.is_active:
            return ValidationResult(ok=False, reason="inactive_key")

        now = self._utcnow()
        if self._is_expired(meta.expires_at, now):
            return ValidationResult(ok=False, reason="expired_key")
        if not self._ip_allowed(meta.allowed_ips, client_ip):
            return ValidationResult(ok=False, reason="ip_not_allowed")

        limit = meta.rate_limit_qpm or self.DEFAULT_RATE_LIMIT_QPM
        if not self._try_consume_rate_slot(key, now, limit):
            return ValidationResult(ok=False, reason="rate_limited", rate_limit_qpm=limit)
        return ValidationResult(ok=True, rate_limit_qpm=limit)

    def _evict_stale_rate_entries(self, key: str, now: datetime) -> None:
        """Drop request timestamps older than ``RATE_WINDOW`` (sliding-window log)."""
        bucket = self._rate_buckets.setdefault(key, deque())
        threshold = now - self.RATE_WINDOW
        while bucket and bucket[0] < threshold:
            bucket.popleft()

    def _try_consume_rate_slot(self, key: str, now: datetime, limit: int) -> bool:
        """Record a request at ``now`` if the key is under ``limit`` within the window.

        Returns False, and records nothing, when the key is already at its limit.
        """
        self._evict_stale_rate_entries(key, now)
        bucket = self._rate_buckets.setdefault(key, deque())
        if len(bucket) >= limit:
            return False
        bucket.append(now)
        return True

    # ------------------------------------------------------------------
    # Key management (DB write-through, then cache update)
    # ------------------------------------------------------------------

    def generate_key(self) -> str:
        """
        Generate a new secure API key.

        Returns:
            A URL-safe random string
        """
        return secrets.token_urlsafe(32)

    async def create_key(
        self,
        session: AsyncSession,
        key_create: ApiKeyCreate
    ) -> ApiKey:
        """
        Create a new API key.

        Args:
            session: Database session
            key_create: Key creation data

        Returns:
            The created ApiKey entity
        """
        api_key = ApiKey(
            key=key_create.key,
            name=key_create.name,
            is_active=key_create.is_active,
            expires_at=key_create.expires_at,
            allowed_ips=key_create.allowed_ips,
            rate_limit_qpm=key_create.rate_limit_qpm or self.DEFAULT_RATE_LIMIT_QPM,
        )
        session.add(api_key)
        await session.commit()
        await session.refresh(api_key)

        self._sync_cache_entry(api_key)
        logger.info("[AC-AISVC-50] Created API key: %s", api_key.name)
        return api_key

    async def create_default_key(self, session: AsyncSession) -> Optional[ApiKey]:
        """
        Create a default API key if none exists.

        Returns:
            The created ApiKey or None if keys already exist
        """
        try:
            result = await session.execute(select(ApiKey).limit(1))
            existing = result.scalar_one_or_none()
        except Exception as e:
            logger.warning(
                "[AC-AISVC-50] Full schema query failed in create_default_key, using fallback: %s", e
            )
            await session.rollback()
            result = await session.execute(select(ApiKey.key).limit(1))
            existing = result.scalar_one_or_none()

        # ``is not None`` so an existing row whose key column is an empty string still counts.
        if existing is not None:
            return None

        api_key = ApiKey(
            key=self.generate_key(),
            name="Default API Key",
            is_active=True,
        )
        session.add(api_key)
        await session.commit()
        await session.refresh(api_key)

        self._sync_cache_entry(api_key)
        # NOTE(review): this writes the full secret to the log, presumably so an
        # operator can retrieve it on first boot; consider masking in production.
        logger.info("[AC-AISVC-50] Created default API key: %s", api_key.key)
        return api_key

    async def delete_key(
        self,
        session: AsyncSession,
        key_id: str
    ) -> bool:
        """
        Delete an API key.

        Args:
            session: Database session
            key_id: The key ID to delete (UUID string or uuid.UUID)

        Returns:
            True if deleted, False if not found or the ID is not a valid UUID
        """
        key_uuid = self._parse_key_id(key_id)
        if key_uuid is None:
            return False

        result = await session.execute(
            select(ApiKey).where(ApiKey.id == key_uuid)
        )
        api_key = result.scalar_one_or_none()
        if api_key is None:
            return False

        # Capture the fields we still need before the row is deleted and committed.
        key_value = api_key.key
        key_name = api_key.name
        await session.delete(api_key)
        await session.commit()

        self._evict(key_value)
        logger.info("[AC-AISVC-50] Deleted API key: %s", key_name)
        return True

    async def toggle_key(
        self,
        session: AsyncSession,
        key_id: str,
        is_active: bool
    ) -> Optional[ApiKey]:
        """
        Toggle API key active status.

        Args:
            session: Database session
            key_id: The key ID to toggle (UUID string or uuid.UUID)
            is_active: New active status

        Returns:
            The updated ApiKey or None if not found or the ID is not a valid UUID
        """
        key_uuid = self._parse_key_id(key_id)
        if key_uuid is None:
            return None

        result = await session.execute(
            select(ApiKey).where(ApiKey.id == key_uuid)
        )
        api_key = result.scalar_one_or_none()
        if api_key is None:
            return None

        api_key.is_active = is_active
        api_key.updated_at = self._utcnow()
        session.add(api_key)
        await session.commit()
        await session.refresh(api_key)

        # Deactivation also discards the key's rate-limit window.
        self._sync_cache_entry(api_key)
        logger.info("[AC-AISVC-50] Toggled API key %s: active=%s", api_key.name, is_active)
        return api_key

    async def list_keys(self, session: AsyncSession) -> list[ApiKey]:
        """
        List all API keys.

        Args:
            session: Database session

        Returns:
            List of all ApiKey entities
        """
        result = await session.execute(select(ApiKey))
        return list(result.scalars().all())

    async def reload_cache(self, session: AsyncSession) -> None:
        """
        Reload all API keys from database into memory.

        The previous cache stays in place while the reload awaits the database.
        Requests validated in that window therefore keep working instead of
        seeing an empty cache. Rate-limit windows survive for keys that are
        still present and are dropped for keys that disappeared.
        """
        await self.initialize(session)
        self._rate_buckets = {
            key: bucket
            for key, bucket in self._rate_buckets.items()
            if key in self._keys_cache
        }
        logger.info("[AC-AISVC-50] API key cache reloaded")
_api_key_service: ApiKeyService | None = None


def get_api_key_service() -> ApiKeyService:
    """Return the process-wide ApiKeyService, constructing it on first use.

    The instance is stored in the module-level ``_api_key_service`` slot, so all
    callers share the same in-memory key cache and rate-limit state.
    """
    global _api_key_service
    if _api_key_service is not None:
        return _api_key_service
    _api_key_service = ApiKeyService()
    return _api_key_service