feat: add API key management with entity model and service layer [AC-AISVC-APIKEY]
This commit is contained in:
parent
5f4bde8752
commit
f823e8fb86
|
|
@ -4,6 +4,7 @@ API Key management endpoints.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from datetime import datetime
|
||||||
from typing import Annotated
|
from typing import Annotated
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
|
@ -26,6 +27,9 @@ class ApiKeyResponse(BaseModel):
|
||||||
key: str = Field(..., description="API key value")
|
key: str = Field(..., description="API key value")
|
||||||
name: str = Field(..., description="API key name")
|
name: str = Field(..., description="API key name")
|
||||||
is_active: bool = Field(..., description="Whether the key is active")
|
is_active: bool = Field(..., description="Whether the key is active")
|
||||||
|
expires_at: str | None = Field(default=None, description="Expiration time")
|
||||||
|
allowed_ips: list[str] | None = Field(default=None, description="Optional client IP allowlist")
|
||||||
|
rate_limit_qpm: int | None = Field(default=60, description="Per-minute quota")
|
||||||
created_at: str = Field(..., description="Creation time")
|
created_at: str = Field(..., description="Creation time")
|
||||||
updated_at: str = Field(..., description="Last update time")
|
updated_at: str = Field(..., description="Last update time")
|
||||||
|
|
||||||
|
|
@ -42,6 +46,9 @@ class CreateApiKeyRequest(BaseModel):
|
||||||
|
|
||||||
name: str = Field(..., description="API key name/description")
|
name: str = Field(..., description="API key name/description")
|
||||||
key: str | None = Field(default=None, description="Custom API key (auto-generated if not provided)")
|
key: str | None = Field(default=None, description="Custom API key (auto-generated if not provided)")
|
||||||
|
expires_at: datetime | None = Field(default=None, description="Expiration time; null means never expires")
|
||||||
|
allowed_ips: list[str] | None = Field(default=None, description="Optional client IP allowlist")
|
||||||
|
rate_limit_qpm: int | None = Field(default=60, ge=1, le=60000, description="Per-minute quota")
|
||||||
|
|
||||||
|
|
||||||
class ToggleApiKeyRequest(BaseModel):
|
class ToggleApiKeyRequest(BaseModel):
|
||||||
|
|
@ -57,6 +64,9 @@ def api_key_to_response(api_key: ApiKey) -> ApiKeyResponse:
|
||||||
key=api_key.key,
|
key=api_key.key,
|
||||||
name=api_key.name,
|
name=api_key.name,
|
||||||
is_active=api_key.is_active,
|
is_active=api_key.is_active,
|
||||||
|
expires_at=api_key.expires_at.isoformat() if api_key.expires_at else None,
|
||||||
|
allowed_ips=api_key.allowed_ips,
|
||||||
|
rate_limit_qpm=api_key.rate_limit_qpm,
|
||||||
created_at=api_key.created_at.isoformat(),
|
created_at=api_key.created_at.isoformat(),
|
||||||
updated_at=api_key.updated_at.isoformat(),
|
updated_at=api_key.updated_at.isoformat(),
|
||||||
)
|
)
|
||||||
|
|
@ -94,6 +104,9 @@ async def create_api_key(
|
||||||
key=key_value,
|
key=key_value,
|
||||||
name=request.name,
|
name=request.name,
|
||||||
is_active=True,
|
is_active=True,
|
||||||
|
expires_at=request.expires_at,
|
||||||
|
allowed_ips=request.allowed_ips,
|
||||||
|
rate_limit_qpm=request.rate_limit_qpm,
|
||||||
)
|
)
|
||||||
|
|
||||||
api_key = await service.create_key(session, key_create)
|
api_key = await service.create_key(session, key_create)
|
||||||
|
|
|
||||||
|
|
@ -294,6 +294,13 @@ class ApiKey(SQLModel, table=True):
|
||||||
key: str = Field(..., description="API Key (unique)", unique=True, index=True)
|
key: str = Field(..., description="API Key (unique)", unique=True, index=True)
|
||||||
name: str = Field(..., description="Key name/description for identification")
|
name: str = Field(..., description="Key name/description for identification")
|
||||||
is_active: bool = Field(default=True, description="Whether the key is active")
|
is_active: bool = Field(default=True, description="Whether the key is active")
|
||||||
|
expires_at: datetime | None = Field(default=None, description="Expiration time; null means never expires")
|
||||||
|
allowed_ips: list[str] | None = Field(
|
||||||
|
default=None,
|
||||||
|
sa_column=Column("allowed_ips", JSON, nullable=True),
|
||||||
|
description="Optional IP allowlist for this key",
|
||||||
|
)
|
||||||
|
rate_limit_qpm: int | None = Field(default=60, description="Per-minute quota for this key")
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
|
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
|
||||||
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
|
||||||
|
|
||||||
|
|
@ -304,6 +311,9 @@ class ApiKeyCreate(SQLModel):
|
||||||
key: str
|
key: str
|
||||||
name: str
|
name: str
|
||||||
is_active: bool = True
|
is_active: bool = True
|
||||||
|
expires_at: datetime | None = None
|
||||||
|
allowed_ips: list[str] | None = None
|
||||||
|
rate_limit_qpm: int | None = 60
|
||||||
|
|
||||||
|
|
||||||
class TemplateVersionStatus(str, Enum):
|
class TemplateVersionStatus(str, Enum):
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,13 @@ API Key management service.
|
||||||
[AC-AISVC-50] Lightweight authentication with in-memory cache.
|
[AC-AISVC-50] Lightweight authentication with in-memory cache.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
from datetime import datetime
|
from collections import deque
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
@ -16,6 +20,25 @@ from app.models.entities import ApiKey, ApiKeyCreate
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CachedApiKeyMeta:
|
||||||
|
"""Cached metadata for API key policy checks."""
|
||||||
|
|
||||||
|
is_active: bool
|
||||||
|
expires_at: datetime | None
|
||||||
|
allowed_ips: set[str] = field(default_factory=set)
|
||||||
|
rate_limit_qpm: int = 60
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ValidationResult:
|
||||||
|
"""Validation output for middleware auth + policy checks."""
|
||||||
|
|
||||||
|
ok: bool
|
||||||
|
reason: str | None = None
|
||||||
|
rate_limit_qpm: int = 60
|
||||||
|
|
||||||
|
|
||||||
class ApiKeyService:
|
class ApiKeyService:
|
||||||
"""
|
"""
|
||||||
[AC-AISVC-50] API Key management service.
|
[AC-AISVC-50] API Key management service.
|
||||||
|
|
@ -28,6 +51,8 @@ class ApiKeyService:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._keys_cache: set[str] = set()
|
self._keys_cache: set[str] = set()
|
||||||
|
self._key_meta: dict[str, CachedApiKeyMeta] = {}
|
||||||
|
self._rate_buckets: dict[str, deque[datetime]] = {}
|
||||||
self._initialized: bool = False
|
self._initialized: bool = False
|
||||||
|
|
||||||
async def initialize(self, session: AsyncSession) -> None:
|
async def initialize(self, session: AsyncSession) -> None:
|
||||||
|
|
@ -35,15 +60,50 @@ class ApiKeyService:
|
||||||
Load all active API keys from database into memory.
|
Load all active API keys from database into memory.
|
||||||
Should be called on application startup.
|
Should be called on application startup.
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
result = await session.execute(
|
result = await session.execute(
|
||||||
select(ApiKey).where(ApiKey.is_active == True)
|
select(ApiKey).where(ApiKey.is_active == True)
|
||||||
)
|
)
|
||||||
keys = result.scalars().all()
|
keys = result.scalars().all()
|
||||||
|
|
||||||
self._keys_cache = {key.key for key in keys}
|
self._keys_cache = {key.key for key in keys}
|
||||||
|
self._key_meta = {
|
||||||
|
key.key: CachedApiKeyMeta(
|
||||||
|
is_active=key.is_active,
|
||||||
|
expires_at=key.expires_at,
|
||||||
|
allowed_ips=set(key.allowed_ips or []),
|
||||||
|
rate_limit_qpm=key.rate_limit_qpm or 60,
|
||||||
|
)
|
||||||
|
for key in keys
|
||||||
|
}
|
||||||
self._initialized = True
|
self._initialized = True
|
||||||
|
|
||||||
logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys into memory")
|
logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys into memory")
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AC-AISVC-50] Full API key schema load failed, fallback to legacy columns: {e}")
|
||||||
|
await session.rollback()
|
||||||
|
|
||||||
|
# Backward-compat fallback for environments without new columns
|
||||||
|
try:
|
||||||
|
result = await session.execute(
|
||||||
|
select(ApiKey.key, ApiKey.is_active).where(ApiKey.is_active == True)
|
||||||
|
)
|
||||||
|
rows = result.all()
|
||||||
|
self._keys_cache = {row[0] for row in rows}
|
||||||
|
self._key_meta = {
|
||||||
|
row[0]: CachedApiKeyMeta(
|
||||||
|
is_active=bool(row[1]),
|
||||||
|
expires_at=None,
|
||||||
|
allowed_ips=set(),
|
||||||
|
rate_limit_qpm=60,
|
||||||
|
)
|
||||||
|
for row in rows
|
||||||
|
}
|
||||||
|
self._initialized = True
|
||||||
|
logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys in legacy compatibility mode")
|
||||||
|
except Exception as fallback_error:
|
||||||
|
self._initialized = False
|
||||||
|
logger.error(f"[AC-AISVC-50] API key initialization failed in both full/legacy mode: {fallback_error}")
|
||||||
|
|
||||||
def validate_key(self, key: str) -> bool:
|
def validate_key(self, key: str) -> bool:
|
||||||
"""
|
"""
|
||||||
|
|
@ -61,6 +121,41 @@ class ApiKeyService:
|
||||||
|
|
||||||
return key in self._keys_cache
|
return key in self._keys_cache
|
||||||
|
|
||||||
|
def validate_key_with_context(self, key: str, client_ip: str | None) -> ValidationResult:
|
||||||
|
"""Validate key and policy constraints: expiration, IP allowlist, and per-minute rate."""
|
||||||
|
if not self._initialized:
|
||||||
|
return ValidationResult(ok=False, reason="service_not_initialized")
|
||||||
|
|
||||||
|
if key not in self._keys_cache:
|
||||||
|
return ValidationResult(ok=False, reason="invalid_key")
|
||||||
|
|
||||||
|
meta = self._key_meta.get(key)
|
||||||
|
if not meta or not meta.is_active:
|
||||||
|
return ValidationResult(ok=False, reason="inactive_key")
|
||||||
|
|
||||||
|
now = datetime.utcnow()
|
||||||
|
if meta.expires_at and now > meta.expires_at:
|
||||||
|
return ValidationResult(ok=False, reason="expired_key")
|
||||||
|
|
||||||
|
if meta.allowed_ips and client_ip and client_ip not in meta.allowed_ips:
|
||||||
|
return ValidationResult(ok=False, reason="ip_not_allowed")
|
||||||
|
|
||||||
|
self._evict_stale_rate_entries(key, now)
|
||||||
|
bucket = self._rate_buckets.setdefault(key, deque())
|
||||||
|
limit = meta.rate_limit_qpm or 60
|
||||||
|
if len(bucket) >= limit:
|
||||||
|
return ValidationResult(ok=False, reason="rate_limited", rate_limit_qpm=limit)
|
||||||
|
|
||||||
|
bucket.append(now)
|
||||||
|
return ValidationResult(ok=True, rate_limit_qpm=limit)
|
||||||
|
|
||||||
|
def _evict_stale_rate_entries(self, key: str, now: datetime) -> None:
|
||||||
|
"""Keep only requests in the latest 60 seconds for token bucket emulation."""
|
||||||
|
bucket = self._rate_buckets.setdefault(key, deque())
|
||||||
|
threshold = now - timedelta(seconds=60)
|
||||||
|
while bucket and bucket[0] < threshold:
|
||||||
|
bucket.popleft()
|
||||||
|
|
||||||
def generate_key(self) -> str:
|
def generate_key(self) -> str:
|
||||||
"""
|
"""
|
||||||
Generate a new secure API key.
|
Generate a new secure API key.
|
||||||
|
|
@ -89,6 +184,9 @@ class ApiKeyService:
|
||||||
key=key_create.key,
|
key=key_create.key,
|
||||||
name=key_create.name,
|
name=key_create.name,
|
||||||
is_active=key_create.is_active,
|
is_active=key_create.is_active,
|
||||||
|
expires_at=key_create.expires_at,
|
||||||
|
allowed_ips=key_create.allowed_ips,
|
||||||
|
rate_limit_qpm=key_create.rate_limit_qpm or 60,
|
||||||
)
|
)
|
||||||
|
|
||||||
session.add(api_key)
|
session.add(api_key)
|
||||||
|
|
@ -97,6 +195,12 @@ class ApiKeyService:
|
||||||
|
|
||||||
if api_key.is_active:
|
if api_key.is_active:
|
||||||
self._keys_cache.add(api_key.key)
|
self._keys_cache.add(api_key.key)
|
||||||
|
self._key_meta[api_key.key] = CachedApiKeyMeta(
|
||||||
|
is_active=api_key.is_active,
|
||||||
|
expires_at=api_key.expires_at,
|
||||||
|
allowed_ips=set(api_key.allowed_ips or []),
|
||||||
|
rate_limit_qpm=api_key.rate_limit_qpm or 60,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"[AC-AISVC-50] Created API key: {api_key.name}")
|
logger.info(f"[AC-AISVC-50] Created API key: {api_key.name}")
|
||||||
return api_key
|
return api_key
|
||||||
|
|
@ -108,8 +212,14 @@ class ApiKeyService:
|
||||||
Returns:
|
Returns:
|
||||||
The created ApiKey or None if keys already exist
|
The created ApiKey or None if keys already exist
|
||||||
"""
|
"""
|
||||||
|
try:
|
||||||
result = await session.execute(select(ApiKey).limit(1))
|
result = await session.execute(select(ApiKey).limit(1))
|
||||||
existing = result.scalar_one_or_none()
|
existing = result.scalar_one_or_none()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[AC-AISVC-50] Full schema query failed in create_default_key, using fallback: {e}")
|
||||||
|
await session.rollback()
|
||||||
|
result = await session.execute(select(ApiKey.key).limit(1))
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
|
||||||
if existing:
|
if existing:
|
||||||
return None
|
return None
|
||||||
|
|
@ -126,6 +236,12 @@ class ApiKeyService:
|
||||||
await session.refresh(api_key)
|
await session.refresh(api_key)
|
||||||
|
|
||||||
self._keys_cache.add(api_key.key)
|
self._keys_cache.add(api_key.key)
|
||||||
|
self._key_meta[api_key.key] = CachedApiKeyMeta(
|
||||||
|
is_active=api_key.is_active,
|
||||||
|
expires_at=getattr(api_key, 'expires_at', None),
|
||||||
|
allowed_ips=set(getattr(api_key, 'allowed_ips', []) or []),
|
||||||
|
rate_limit_qpm=getattr(api_key, 'rate_limit_qpm', 60) or 60,
|
||||||
|
)
|
||||||
|
|
||||||
logger.info(f"[AC-AISVC-50] Created default API key: {api_key.key}")
|
logger.info(f"[AC-AISVC-50] Created default API key: {api_key.key}")
|
||||||
return api_key
|
return api_key
|
||||||
|
|
@ -165,6 +281,8 @@ class ApiKeyService:
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
self._keys_cache.discard(key_value)
|
self._keys_cache.discard(key_value)
|
||||||
|
self._key_meta.pop(key_value, None)
|
||||||
|
self._rate_buckets.pop(key_value, None)
|
||||||
|
|
||||||
logger.info(f"[AC-AISVC-50] Deleted API key: {api_key.name}")
|
logger.info(f"[AC-AISVC-50] Deleted API key: {api_key.name}")
|
||||||
return True
|
return True
|
||||||
|
|
@ -210,8 +328,16 @@ class ApiKeyService:
|
||||||
|
|
||||||
if is_active:
|
if is_active:
|
||||||
self._keys_cache.add(api_key.key)
|
self._keys_cache.add(api_key.key)
|
||||||
|
self._key_meta[api_key.key] = CachedApiKeyMeta(
|
||||||
|
is_active=api_key.is_active,
|
||||||
|
expires_at=api_key.expires_at,
|
||||||
|
allowed_ips=set(api_key.allowed_ips or []),
|
||||||
|
rate_limit_qpm=api_key.rate_limit_qpm or 60,
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
self._keys_cache.discard(api_key.key)
|
self._keys_cache.discard(api_key.key)
|
||||||
|
self._key_meta.pop(api_key.key, None)
|
||||||
|
self._rate_buckets.pop(api_key.key, None)
|
||||||
|
|
||||||
logger.info(f"[AC-AISVC-50] Toggled API key {api_key.name}: active={is_active}")
|
logger.info(f"[AC-AISVC-50] Toggled API key {api_key.name}: active={is_active}")
|
||||||
return api_key
|
return api_key
|
||||||
|
|
@ -234,6 +360,8 @@ class ApiKeyService:
|
||||||
Reload all API keys from database into memory.
|
Reload all API keys from database into memory.
|
||||||
"""
|
"""
|
||||||
self._keys_cache.clear()
|
self._keys_cache.clear()
|
||||||
|
self._key_meta.clear()
|
||||||
|
self._rate_buckets.clear()
|
||||||
await self.initialize(session)
|
await self.initialize(session)
|
||||||
logger.info("[AC-AISVC-50] API key cache reloaded")
|
logger.info("[AC-AISVC-50] API key cache reloaded")
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue