feat: add API key management with entity model and service layer [AC-AISVC-APIKEY]

This commit is contained in:
MerCry 2026-03-06 01:10:42 +08:00
parent 5f4bde8752
commit f823e8fb86
3 changed files with 161 additions and 10 deletions

View File

@ -4,6 +4,7 @@ API Key management endpoints.
"""
import logging
from datetime import datetime
from typing import Annotated
from fastapi import APIRouter, Depends, HTTPException, status
@ -26,6 +27,9 @@ class ApiKeyResponse(BaseModel):
key: str = Field(..., description="API key value")
name: str = Field(..., description="API key name")
is_active: bool = Field(..., description="Whether the key is active")
expires_at: str | None = Field(default=None, description="Expiration time")
allowed_ips: list[str] | None = Field(default=None, description="Optional client IP allowlist")
rate_limit_qpm: int | None = Field(default=60, description="Per-minute quota")
created_at: str = Field(..., description="Creation time")
updated_at: str = Field(..., description="Last update time")
@ -42,6 +46,9 @@ class CreateApiKeyRequest(BaseModel):
name: str = Field(..., description="API key name/description")
key: str | None = Field(default=None, description="Custom API key (auto-generated if not provided)")
expires_at: datetime | None = Field(default=None, description="Expiration time; null means never expires")
allowed_ips: list[str] | None = Field(default=None, description="Optional client IP allowlist")
rate_limit_qpm: int | None = Field(default=60, ge=1, le=60000, description="Per-minute quota")
class ToggleApiKeyRequest(BaseModel):
@ -57,6 +64,9 @@ def api_key_to_response(api_key: ApiKey) -> ApiKeyResponse:
key=api_key.key,
name=api_key.name,
is_active=api_key.is_active,
expires_at=api_key.expires_at.isoformat() if api_key.expires_at else None,
allowed_ips=api_key.allowed_ips,
rate_limit_qpm=api_key.rate_limit_qpm,
created_at=api_key.created_at.isoformat(),
updated_at=api_key.updated_at.isoformat(),
)
@ -94,6 +104,9 @@ async def create_api_key(
key=key_value,
name=request.name,
is_active=True,
expires_at=request.expires_at,
allowed_ips=request.allowed_ips,
rate_limit_qpm=request.rate_limit_qpm,
)
api_key = await service.create_key(session, key_create)

View File

@ -294,6 +294,13 @@ class ApiKey(SQLModel, table=True):
key: str = Field(..., description="API Key (unique)", unique=True, index=True)
name: str = Field(..., description="Key name/description for identification")
is_active: bool = Field(default=True, description="Whether the key is active")
expires_at: datetime | None = Field(default=None, description="Expiration time; null means never expires")
allowed_ips: list[str] | None = Field(
default=None,
sa_column=Column("allowed_ips", JSON, nullable=True),
description="Optional IP allowlist for this key",
)
rate_limit_qpm: int | None = Field(default=60, description="Per-minute quota for this key")
created_at: datetime = Field(default_factory=datetime.utcnow, description="Creation time")
updated_at: datetime = Field(default_factory=datetime.utcnow, description="Last update time")
@ -304,6 +311,9 @@ class ApiKeyCreate(SQLModel):
key: str
name: str
is_active: bool = True
expires_at: datetime | None = None
allowed_ips: list[str] | None = None
rate_limit_qpm: int | None = 60
class TemplateVersionStatus(str, Enum):

View File

@ -3,9 +3,13 @@ API Key management service.
[AC-AISVC-50] Lightweight authentication with in-memory cache.
"""
from __future__ import annotations
import logging
import secrets
from datetime import datetime
from collections import deque
from dataclasses import dataclass, field
from datetime import datetime, timedelta
from typing import Optional
from sqlalchemy import select
@ -16,6 +20,25 @@ from app.models.entities import ApiKey, ApiKeyCreate
logger = logging.getLogger(__name__)
@dataclass
class CachedApiKeyMeta:
"""Cached metadata for API key policy checks."""
is_active: bool
expires_at: datetime | None
allowed_ips: set[str] = field(default_factory=set)
rate_limit_qpm: int = 60
@dataclass
class ValidationResult:
"""Validation output for middleware auth + policy checks."""
ok: bool
reason: str | None = None
rate_limit_qpm: int = 60
class ApiKeyService:
"""
[AC-AISVC-50] API Key management service.
@ -28,6 +51,8 @@ class ApiKeyService:
def __init__(self):
self._keys_cache: set[str] = set()
self._key_meta: dict[str, CachedApiKeyMeta] = {}
self._rate_buckets: dict[str, deque[datetime]] = {}
self._initialized: bool = False
async def initialize(self, session: AsyncSession) -> None:
@ -35,15 +60,50 @@ class ApiKeyService:
Load all active API keys from database into memory.
Should be called on application startup.
"""
result = await session.execute(
select(ApiKey).where(ApiKey.is_active == True)
)
keys = result.scalars().all()
try:
result = await session.execute(
select(ApiKey).where(ApiKey.is_active == True)
)
keys = result.scalars().all()
self._keys_cache = {key.key for key in keys}
self._initialized = True
self._keys_cache = {key.key for key in keys}
self._key_meta = {
key.key: CachedApiKeyMeta(
is_active=key.is_active,
expires_at=key.expires_at,
allowed_ips=set(key.allowed_ips or []),
rate_limit_qpm=key.rate_limit_qpm or 60,
)
for key in keys
}
self._initialized = True
logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys into memory")
return
except Exception as e:
logger.warning(f"[AC-AISVC-50] Full API key schema load failed, fallback to legacy columns: {e}")
await session.rollback()
logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys into memory")
# Backward-compat fallback for environments without new columns
try:
result = await session.execute(
select(ApiKey.key, ApiKey.is_active).where(ApiKey.is_active == True)
)
rows = result.all()
self._keys_cache = {row[0] for row in rows}
self._key_meta = {
row[0]: CachedApiKeyMeta(
is_active=bool(row[1]),
expires_at=None,
allowed_ips=set(),
rate_limit_qpm=60,
)
for row in rows
}
self._initialized = True
logger.info(f"[AC-AISVC-50] Loaded {len(self._keys_cache)} API keys in legacy compatibility mode")
except Exception as fallback_error:
self._initialized = False
logger.error(f"[AC-AISVC-50] API key initialization failed in both full/legacy mode: {fallback_error}")
def validate_key(self, key: str) -> bool:
"""
@ -61,6 +121,41 @@ class ApiKeyService:
return key in self._keys_cache
def validate_key_with_context(self, key: str, client_ip: str | None) -> ValidationResult:
"""Validate key and policy constraints: expiration, IP allowlist, and per-minute rate."""
if not self._initialized:
return ValidationResult(ok=False, reason="service_not_initialized")
if key not in self._keys_cache:
return ValidationResult(ok=False, reason="invalid_key")
meta = self._key_meta.get(key)
if not meta or not meta.is_active:
return ValidationResult(ok=False, reason="inactive_key")
now = datetime.utcnow()
if meta.expires_at and now > meta.expires_at:
return ValidationResult(ok=False, reason="expired_key")
if meta.allowed_ips and client_ip and client_ip not in meta.allowed_ips:
return ValidationResult(ok=False, reason="ip_not_allowed")
self._evict_stale_rate_entries(key, now)
bucket = self._rate_buckets.setdefault(key, deque())
limit = meta.rate_limit_qpm or 60
if len(bucket) >= limit:
return ValidationResult(ok=False, reason="rate_limited", rate_limit_qpm=limit)
bucket.append(now)
return ValidationResult(ok=True, rate_limit_qpm=limit)
def _evict_stale_rate_entries(self, key: str, now: datetime) -> None:
"""Keep only requests in the latest 60 seconds for token bucket emulation."""
bucket = self._rate_buckets.setdefault(key, deque())
threshold = now - timedelta(seconds=60)
while bucket and bucket[0] < threshold:
bucket.popleft()
def generate_key(self) -> str:
"""
Generate a new secure API key.
@ -89,6 +184,9 @@ class ApiKeyService:
key=key_create.key,
name=key_create.name,
is_active=key_create.is_active,
expires_at=key_create.expires_at,
allowed_ips=key_create.allowed_ips,
rate_limit_qpm=key_create.rate_limit_qpm or 60,
)
session.add(api_key)
@ -97,6 +195,12 @@ class ApiKeyService:
if api_key.is_active:
self._keys_cache.add(api_key.key)
self._key_meta[api_key.key] = CachedApiKeyMeta(
is_active=api_key.is_active,
expires_at=api_key.expires_at,
allowed_ips=set(api_key.allowed_ips or []),
rate_limit_qpm=api_key.rate_limit_qpm or 60,
)
logger.info(f"[AC-AISVC-50] Created API key: {api_key.name}")
return api_key
@ -108,8 +212,14 @@ class ApiKeyService:
Returns:
The created ApiKey or None if keys already exist
"""
result = await session.execute(select(ApiKey).limit(1))
existing = result.scalar_one_or_none()
try:
result = await session.execute(select(ApiKey).limit(1))
existing = result.scalar_one_or_none()
except Exception as e:
logger.warning(f"[AC-AISVC-50] Full schema query failed in create_default_key, using fallback: {e}")
await session.rollback()
result = await session.execute(select(ApiKey.key).limit(1))
existing = result.scalar_one_or_none()
if existing:
return None
@ -126,6 +236,12 @@ class ApiKeyService:
await session.refresh(api_key)
self._keys_cache.add(api_key.key)
self._key_meta[api_key.key] = CachedApiKeyMeta(
is_active=api_key.is_active,
expires_at=getattr(api_key, 'expires_at', None),
allowed_ips=set(getattr(api_key, 'allowed_ips', []) or []),
rate_limit_qpm=getattr(api_key, 'rate_limit_qpm', 60) or 60,
)
logger.info(f"[AC-AISVC-50] Created default API key: {api_key.key}")
return api_key
@ -165,6 +281,8 @@ class ApiKeyService:
await session.commit()
self._keys_cache.discard(key_value)
self._key_meta.pop(key_value, None)
self._rate_buckets.pop(key_value, None)
logger.info(f"[AC-AISVC-50] Deleted API key: {api_key.name}")
return True
@ -210,8 +328,16 @@ class ApiKeyService:
if is_active:
self._keys_cache.add(api_key.key)
self._key_meta[api_key.key] = CachedApiKeyMeta(
is_active=api_key.is_active,
expires_at=api_key.expires_at,
allowed_ips=set(api_key.allowed_ips or []),
rate_limit_qpm=api_key.rate_limit_qpm or 60,
)
else:
self._keys_cache.discard(api_key.key)
self._key_meta.pop(api_key.key, None)
self._rate_buckets.pop(api_key.key, None)
logger.info(f"[AC-AISVC-50] Toggled API key {api_key.name}: active={is_active}")
return api_key
@ -234,6 +360,8 @@ class ApiKeyService:
Reload all API keys from database into memory.
"""
self._keys_cache.clear()
self._key_meta.clear()
self._rate_buckets.clear()
await self.initialize(session)
logger.info("[AC-AISVC-50] API key cache reloaded")