""" API Key management service. [AC-AISVC-50] Lightweight authentication with in-memory cache. """ from __future__ import annotations import logging import secrets from collections import deque from dataclasses import dataclass, field from datetime import datetime, timedelta from typing import Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from app.models.entities import ApiKey, ApiKeyCreate logger = logging.getLogger(__name__) @dataclass class CachedApiKeyMeta: """Cached metadata for API key policy checks.""" is_active: bool expires_at: datetime | None allowed_ips: set[str] = field(default_factory=set) rate_limit_qpm: int = 60 @dataclass class ValidationResult: """Validation output for middleware auth + policy checks.""" ok: bool reason: str | None = None rate_limit_qpm: int = 60 class ApiKeyService: """ [AC-AISVC-50] API Key management service. Features: - In-memory cache for fast validation - Database persistence - Hot-reload support """ def __init__(self): self._keys_cache: set[str] = set() self._key_meta: dict[str, CachedApiKeyMeta] = {} self._rate_buckets: dict[str, deque[datetime]] = {} self._initialized: bool = False async def initialize(self, session: AsyncSession) -> None: """ Load all active API keys from database into memory. Should be called on application startup. """ try: result = await session.execute( select(ApiKey).where(ApiKey.is_active == True) ) keys = result.scalars().all() self._keys_cache = {key.key for key in keys} self._key_meta = { key.key: CachedApiKeyMeta( is_active=key.is_active, expires_at=key.expires_at, allowed_ips=set(key.allowed_ips or []), rate_limit_qpm=key.rate_limit_qpm or 60, ) for key in keys } self._initialized = True logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys into memory") return except Exception as e: logger.warning(f"[AC-AISVC-50] Full API key schema load failed, fallback to legacy columns: {e}") await session.rollback() # Backward-compat fallback for environments without new columns try: result = await session.execute( select(ApiKey.key, ApiKey.is_active).where(ApiKey.is_active == True) ) rows = result.all() self._keys_cache = {row[0] for row in rows} self._key_meta = { row[0]: CachedApiKeyMeta( is_active=bool(row[1]), expires_at=None, allowed_ips=set(), rate_limit_qpm=60, ) for row in rows } self._initialized = True logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys in legacy compatibility mode") except Exception as fallback_error: self._initialized = False logger.error(f"[AC-AISVC-50] API key initialization failed in both full/legacy mode: {fallback_error}") def validate_key(self, key: str) -> bool: """ Validate an API key against the in-memory cache. Args: key: The API key to validate Returns: True if the key is valid, False otherwise """ if not self._initialized: logger.warning("[AC-AISVC-50] API key service not initialized") return False return key in self._keys_cache def validate_key_with_context(self, key: str, client_ip: str | None) -> ValidationResult: """Validate key and policy constraints: expiration, IP allowlist, and per-minute rate.""" if not self._initialized: return ValidationResult(ok=False, reason="service_not_initialized") if key not in self._keys_cache: return ValidationResult(ok=False, reason="invalid_key") meta = self._key_meta.get(key) if not meta or not meta.is_active: return ValidationResult(ok=False, reason="inactive_key") now = datetime.utcnow() if meta.expires_at and now > meta.expires_at: return ValidationResult(ok=False, reason="expired_key") if meta.allowed_ips and client_ip and client_ip not in meta.allowed_ips: return ValidationResult(ok=False, reason="ip_not_allowed") self._evict_stale_rate_entries(key, now) bucket = self._rate_buckets.setdefault(key, deque()) limit = meta.rate_limit_qpm or 60 if len(bucket) >= limit: return ValidationResult(ok=False, reason="rate_limited", rate_limit_qpm=limit) bucket.append(now) return ValidationResult(ok=True, rate_limit_qpm=limit) def _evict_stale_rate_entries(self, key: str, now: datetime) -> None: """Keep only requests in the latest 60 seconds for token bucket emulation.""" bucket = self._rate_buckets.setdefault(key, deque()) threshold = now - timedelta(seconds=60) while bucket and bucket[0] < threshold: bucket.popleft() def generate_key(self) -> str: """ Generate a new secure API key. Returns: A URL-safe random string """ return secrets.token_urlsafe(32) async def create_key( self, session: AsyncSession, key_create: ApiKeyCreate ) -> ApiKey: """ Create a new API key. Args: session: Database session key_create: Key creation data Returns: The created ApiKey entity """ api_key = ApiKey( key=key_create.key, name=key_create.name, is_active=key_create.is_active, expires_at=key_create.expires_at, allowed_ips=key_create.allowed_ips, rate_limit_qpm=key_create.rate_limit_qpm or 60, ) session.add(api_key) await session.commit() await session.refresh(api_key) if api_key.is_active: self._keys_cache.add(api_key.key) self._key_meta[api_key.key] = CachedApiKeyMeta( is_active=api_key.is_active, expires_at=api_key.expires_at, allowed_ips=set(api_key.allowed_ips or []), rate_limit_qpm=api_key.rate_limit_qpm or 60, ) logger.info(f"[AC-AISVC-50] Created API key: {api_key.name}") return api_key async def create_default_key(self, session: AsyncSession) -> Optional[ApiKey]: """ Create a default API key if none exists. Returns: The created ApiKey or None if keys already exist """ try: result = await session.execute(select(ApiKey).limit(1)) existing = result.scalar_one_or_none() except Exception as e: logger.warning(f"[AC-AISVC-50] Full schema query failed in create_default_key, using fallback: {e}") await session.rollback() result = await session.execute(select(ApiKey.key).limit(1)) existing = result.scalar_one_or_none() if existing: return None default_key = secrets.token_urlsafe(32) api_key = ApiKey( key=default_key, name="Default API Key", is_active=True, ) session.add(api_key) await session.commit() await session.refresh(api_key) self._keys_cache.add(api_key.key) self._key_meta[api_key.key] = CachedApiKeyMeta( is_active=api_key.is_active, expires_at=getattr(api_key, 'expires_at', None), allowed_ips=set(getattr(api_key, 'allowed_ips', []) or []), rate_limit_qpm=getattr(api_key, 'rate_limit_qpm', 60) or 60, ) logger.info(f"[AC-AISVC-50] Created default API key: {api_key.key}") return api_key async def delete_key( self, session: AsyncSession, key_id: str ) -> bool: """ Delete an API key. Args: session: Database session key_id: The key ID to delete Returns: True if deleted, False if not found """ import uuid try: key_uuid = uuid.UUID(key_id) except ValueError: return False result = await session.execute( select(ApiKey).where(ApiKey.id == key_uuid) ) api_key = result.scalar_one_or_none() if not api_key: return False key_value = api_key.key await session.delete(api_key) await session.commit() self._keys_cache.discard(key_value) self._key_meta.pop(key_value, None) self._rate_buckets.pop(key_value, None) logger.info(f"[AC-AISVC-50] Deleted API key: {api_key.name}") return True async def toggle_key( self, session: AsyncSession, key_id: str, is_active: bool ) -> Optional[ApiKey]: """ Toggle API key active status. Args: session: Database session key_id: The key ID to toggle is_active: New active status Returns: The updated ApiKey or None if not found """ import uuid try: key_uuid = uuid.UUID(key_id) except ValueError: return None result = await session.execute( select(ApiKey).where(ApiKey.id == key_uuid) ) api_key = result.scalar_one_or_none() if not api_key: return None api_key.is_active = is_active api_key.updated_at = datetime.utcnow() session.add(api_key) await session.commit() await session.refresh(api_key) if is_active: self._keys_cache.add(api_key.key) self._key_meta[api_key.key] = CachedApiKeyMeta( is_active=api_key.is_active, expires_at=api_key.expires_at, allowed_ips=set(api_key.allowed_ips or []), rate_limit_qpm=api_key.rate_limit_qpm or 60, ) else: self._keys_cache.discard(api_key.key) self._key_meta.pop(api_key.key, None) self._rate_buckets.pop(api_key.key, None) logger.info(f"[AC-AISVC-50] Toggled API key {api_key.name}: active={is_active}") return api_key async def list_keys(self, session: AsyncSession) -> list[ApiKey]: """ List all API keys. Args: session: Database session Returns: List of all ApiKey entities """ result = await session.execute(select(ApiKey)) return list(result.scalars().all()) async def reload_cache(self, session: AsyncSession) -> None: """ Reload all API keys from database into memory. """ self._keys_cache.clear() self._key_meta.clear() self._rate_buckets.clear() await self.initialize(session) logger.info("[AC-AISVC-50] API key cache reloaded") _api_key_service: ApiKeyService | None = None def get_api_key_service() -> ApiKeyService: """Get the global API key service instance.""" global _api_key_service if _api_key_service is None: _api_key_service = ApiKeyService() return _api_key_service