diff --git a/ai-service/app/api/admin/api_key.py b/ai-service/app/api/admin/api_key.py index 121dc97..a0ef487 100644 --- a/ai-service/app/api/admin/api_key.py +++ b/ai-service/app/api/admin/api_key.py @@ -4,6 +4,7 @@ API Key management endpoints. """ import logging +from datetime import datetime from typing import Annotated from fastapi import APIRouter, Depends, HTTPException, status @@ -26,6 +27,9 @@ class ApiKeyResponse(BaseModel): key: str = Field(..., description="API key value") name: str = Field(..., description="API key name") is_active: bool = Field(..., description="Whether the key is active") + expires_at: str | None = Field(default=None, description="Expiration time") + allowed_ips: list[str] | None = Field(default=None, description="Optional client IP allowlist") + rate_limit_qpm: int | None = Field(default=60, description="Per-minute quota") created_at: str = Field(..., description="Creation time") updated_at: str = Field(..., description="Last update time") @@ -42,6 +46,9 @@ class CreateApiKeyRequest(BaseModel): name: str = Field(..., description="API key name/description") key: str | None = Field(default=None, description="Custom API key (auto-generated if not provided)") + expires_at: datetime | None = Field(default=None, description="Expiration time; null means never expires") + allowed_ips: list[str] | None = Field(default=None, description="Optional client IP allowlist") + rate_limit_qpm: int | None = Field(default=60, ge=1, le=60000, description="Per-minute quota") class ToggleApiKeyRequest(BaseModel): @@ -57,6 +64,9 @@ def api_key_to_response(api_key: ApiKey) -> ApiKeyResponse: key=api_key.key, name=api_key.name, is_active=api_key.is_active, + expires_at=api_key.expires_at.isoformat() if api_key.expires_at else None, + allowed_ips=api_key.allowed_ips, + rate_limit_qpm=api_key.rate_limit_qpm, created_at=api_key.created_at.isoformat(), updated_at=api_key.updated_at.isoformat(), ) @@ -94,6 +104,9 @@ async def create_api_key( key=key_value, name=request.name, is_active=True, + expires_at=request.expires_at, + allowed_ips=request.allowed_ips, + rate_limit_qpm=request.rate_limit_qpm, ) api_key = await service.create_key(session, key_create) diff --git a/ai-service/app/models/entities.py b/ai-service/app/models/entities.py index b88c214..8e3edbf 100644 --- a/ai-service/app/models/entities.py +++ b/ai-service/app/models/entities.py @@ -294,6 +294,13 @@ class ApiKey(SQLModel, table=True): key: str = Field(..., description="API Key (unique)", unique=True, index=True) name: str = Field(..., description="Key name/description for identification") is_active: bool = Field(default=True, description="Whether the key is active") + expires_at: datetime | None = Field(default=None, description="Expiration time; null means never expires") + allowed_ips: list[str] | None = Field( + default=None, + sa_column=Column("allowed_ips", JSON, nullable=True), + description="Optional IP allowlist for this key", + ) + rate_limit_qpm: int | None = Field(default=60, description="Per-minute quota for this key") created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time") updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time") @@ -304,6 +311,9 @@ class ApiKeyCreate(SQLModel): key: str name: str is_active: bool = True + expires_at: datetime | None = None + allowed_ips: list[str] | None = None + rate_limit_qpm: int | None = 60 class TemplateVersionStatus(str, Enum): diff --git a/ai-service/app/services/api_key.py b/ai-service/app/services/api_key.py index 51edcdd..fb0439a 100644 --- a/ai-service/app/services/api_key.py +++ b/ai-service/app/services/api_key.py @@ -3,9 +3,13 @@ API Key management service. [AC-AISVC-50] Lightweight authentication with in-memory cache. """ +from __future__ import annotations + import logging import secrets -from datetime import datetime +from collections import deque +from dataclasses import dataclass, field +from datetime import datetime, timedelta from typing import Optional from sqlalchemy import select @@ -16,6 +20,25 @@ from app.models.entities import ApiKey, ApiKeyCreate logger = logging.getLogger(__name__) +@dataclass +class CachedApiKeyMeta: + """Cached metadata for API key policy checks.""" + + is_active: bool + expires_at: datetime | None + allowed_ips: set[str] = field(default_factory=set) + rate_limit_qpm: int = 60 + + +@dataclass +class ValidationResult: + """Validation output for middleware auth + policy checks.""" + + ok: bool + reason: str | None = None + rate_limit_qpm: int = 60 + + class ApiKeyService: """ [AC-AISVC-50] API Key management service. @@ -28,6 +51,8 @@ class ApiKeyService: def __init__(self): self._keys_cache: set[str] = set() + self._key_meta: dict[str, CachedApiKeyMeta] = {} + self._rate_buckets: dict[str, deque[datetime]] = {} self._initialized: bool = False async def initialize(self, session: AsyncSession) -> None: @@ -35,15 +60,50 @@ class ApiKeyService: Load all active API keys from database into memory. Should be called on application startup. """ - result = await session.execute( - select(ApiKey).where(ApiKey.is_active == True) - ) - keys = result.scalars().all() + try: + result = await session.execute( + select(ApiKey).where(ApiKey.is_active == True) + ) + keys = result.scalars().all() - self._keys_cache = {key.key for key in keys} - self._initialized = True + self._keys_cache = {key.key for key in keys} + self._key_meta = { + key.key: CachedApiKeyMeta( + is_active=key.is_active, + expires_at=key.expires_at, + allowed_ips=set(key.allowed_ips or []), + rate_limit_qpm=key.rate_limit_qpm or 60, + ) + for key in keys + } + self._initialized = True + logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys into memory") + return + except Exception as e: + logger.warning(f"[AC-AISVC-50] Full API key schema load failed, fallback to legacy columns: {e}") + await session.rollback() - logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys into memory") + # Backward-compat fallback for environments without new columns + try: + result = await session.execute( + select(ApiKey.key, ApiKey.is_active).where(ApiKey.is_active == True) + ) + rows = result.all() + self._keys_cache = {row[0] for row in rows} + self._key_meta = { + row[0]: CachedApiKeyMeta( + is_active=bool(row[1]), + expires_at=None, + allowed_ips=set(), + rate_limit_qpm=60, + ) + for row in rows + } + self._initialized = True + logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys in legacy compatibility mode") + except Exception as fallback_error: + self._initialized = False + logger.error(f"[AC-AISVC-50] API key initialization failed in both full/legacy mode: {fallback_error}") def validate_key(self, key: str) -> bool: """ @@ -61,6 +121,41 @@ class ApiKeyService: return key in self._keys_cache + def validate_key_with_context(self, key: str, client_ip: str | None) -> ValidationResult: + """Validate key and policy constraints: expiration, IP allowlist, and per-minute rate.""" + if not self._initialized: + return ValidationResult(ok=False, reason="service_not_initialized") + + if key not in self._keys_cache: + return ValidationResult(ok=False, reason="invalid_key") + + meta = self._key_meta.get(key) + if not meta or not meta.is_active: + return ValidationResult(ok=False, reason="inactive_key") + + now = datetime.utcnow() + if meta.expires_at and now > meta.expires_at: + return ValidationResult(ok=False, reason="expired_key") + + if meta.allowed_ips and client_ip and client_ip not in meta.allowed_ips: + return ValidationResult(ok=False, reason="ip_not_allowed") + + self._evict_stale_rate_entries(key, now) + bucket = self._rate_buckets.setdefault(key, deque()) + limit = meta.rate_limit_qpm or 60 + if len(bucket) >= limit: + return ValidationResult(ok=False, reason="rate_limited", rate_limit_qpm=limit) + + bucket.append(now) + return ValidationResult(ok=True, rate_limit_qpm=limit) + + def _evict_stale_rate_entries(self, key: str, now: datetime) -> None: + """Keep only requests in the latest 60 seconds for token bucket emulation.""" + bucket = self._rate_buckets.setdefault(key, deque()) + threshold = now - timedelta(seconds=60) + while bucket and bucket[0] < threshold: + bucket.popleft() + def generate_key(self) -> str: """ Generate a new secure API key. @@ -89,6 +184,9 @@ class ApiKeyService: key=key_create.key, name=key_create.name, is_active=key_create.is_active, + expires_at=key_create.expires_at, + allowed_ips=key_create.allowed_ips, + rate_limit_qpm=key_create.rate_limit_qpm or 60, ) session.add(api_key) @@ -97,6 +195,12 @@ class ApiKeyService: if api_key.is_active: self._keys_cache.add(api_key.key) + self._key_meta[api_key.key] = CachedApiKeyMeta( + is_active=api_key.is_active, + expires_at=api_key.expires_at, + allowed_ips=set(api_key.allowed_ips or []), + rate_limit_qpm=api_key.rate_limit_qpm or 60, + ) logger.info(f"[AC-AISVC-50] Created API key: {api_key.name}") return api_key @@ -108,8 +212,14 @@ class ApiKeyService: Returns: The created ApiKey or None if keys already exist """ - result = await session.execute(select(ApiKey).limit(1)) - existing = result.scalar_one_or_none() + try: + result = await session.execute(select(ApiKey).limit(1)) + existing = result.scalar_one_or_none() + except Exception as e: + logger.warning(f"[AC-AISVC-50] Full schema query failed in create_default_key, using fallback: {e}") + await session.rollback() + result = await session.execute(select(ApiKey.key).limit(1)) + existing = result.scalar_one_or_none() if existing: return None @@ -126,6 +236,12 @@ class ApiKeyService: await session.refresh(api_key) self._keys_cache.add(api_key.key) + self._key_meta[api_key.key] = CachedApiKeyMeta( + is_active=api_key.is_active, + expires_at=getattr(api_key, 'expires_at', None), + allowed_ips=set(getattr(api_key, 'allowed_ips', []) or []), + rate_limit_qpm=getattr(api_key, 'rate_limit_qpm', 60) or 60, + ) logger.info(f"[AC-AISVC-50] Created default API key: {api_key.key}") return api_key @@ -165,6 +281,8 @@ class ApiKeyService: await session.commit() self._keys_cache.discard(key_value) + self._key_meta.pop(key_value, None) + self._rate_buckets.pop(key_value, None) logger.info(f"[AC-AISVC-50] Deleted API key: {api_key.name}") return True @@ -210,8 +328,16 @@ class ApiKeyService: if is_active: self._keys_cache.add(api_key.key) + self._key_meta[api_key.key] = CachedApiKeyMeta( + is_active=api_key.is_active, + expires_at=api_key.expires_at, + allowed_ips=set(api_key.allowed_ips or []), + rate_limit_qpm=api_key.rate_limit_qpm or 60, + ) else: self._keys_cache.discard(api_key.key) + self._key_meta.pop(api_key.key, None) + self._rate_buckets.pop(api_key.key, None) logger.info(f"[AC-AISVC-50] Toggled API key {api_key.name}: active={is_active}") return api_key @@ -234,6 +360,8 @@ class ApiKeyService: Reload all API keys from database into memory. """ self._keys_cache.clear() + self._key_meta.clear() + self._rate_buckets.clear() await self.initialize(session) logger.info("[AC-AISVC-50] API key cache reloaded")