292 lines
11 KiB
Python
292 lines
11 KiB
Python
"""
|
|
Middleware for AI Service.
|
|
[AC-AISVC-10, AC-AISVC-12, AC-AISVC-50] X-Tenant-Id header validation, tenant context injection, and API Key authentication.
|
|
"""
|
|
|
|
import logging
|
|
import re
|
|
import uuid
|
|
from collections.abc import Callable
|
|
|
|
from fastapi import Request, Response, status
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from app.core.exceptions import ErrorCode, ErrorResponse
|
|
from app.core.tenant import clear_tenant_context, set_tenant_context
|
|
|
|
logger = logging.getLogger(__name__)

# Header names and media type read/written by the middlewares and helpers below.
TENANT_ID_HEADER = "X-Tenant-Id"
API_KEY_HEADER = "X-API-Key"
ACCEPT_HEADER = "Accept"
SSE_CONTENT_TYPE = "text/event-stream"
REQUEST_ID_HEADER = "X-Request-Id"

# Prompt template protected variable names injected by system/runtime.
# These are reserved for internal orchestration and should not be overridden by user input.
PROMPT_PROTECTED_VARIABLES = {
    "available_tools",
    "query",
    "history",
    "internal_protocol",
    "output_contract",
}

# Tenant IDs have the form "<name>@ash@<4-digit year>", e.g. "szmp@ash@2026".
# The year uses [0-9] rather than \d: in str patterns \d matches every Unicode
# decimal digit (e.g. Arabic-Indic "٢٠٢٦"), which would let such years through.
TENANT_ID_PATTERN = re.compile(r"^[^@]+@ash@[0-9]{4}$")

# Request paths that ApiKeyMiddleware lets through without X-API-Key validation.
PATHS_SKIP_API_KEY = {
    "/health",
    "/ai/health",
    "/docs",
    "/redoc",
    "/openapi.json",
    "/favicon.ico",
    "/openapi/v1/share/chat",
}

# Request paths that TenantContextMiddleware lets through without X-Tenant-Id validation.
PATHS_SKIP_TENANT = {
    "/health",
    "/ai/health",
    "/favicon.ico",
    "/openapi/v1/share/chat",
}
|
|
|
|
|
|
def validate_tenant_id_format(tenant_id: str) -> bool:
    """
    [AC-AISVC-10] Validate tenant ID format: name@ash@year

    Examples: szmp@ash@2026, abc123@ash@2025

    The whole string must conform (``fullmatch``). With ``match`` the
    pattern's ``$`` anchor also matches just before a trailing newline, so a
    value such as ``"szmp@ash@2026\\n"`` would be reported as valid.

    Args:
        tenant_id: Tenant identifier to check. Surrounding whitespace is not
            stripped here; TenantContextMiddleware strips the header first.

    Returns:
        True if ``tenant_id`` fully matches ``TENANT_ID_PATTERN``.
    """
    return TENANT_ID_PATTERN.fullmatch(tenant_id) is not None
|
|
|
|
|
|
def parse_tenant_id(tenant_id: str) -> tuple[str, str]:
    """
    [AC-AISVC-10] Parse tenant ID into name and year.

    Args:
        tenant_id: Tenant identifier in ``name@ash@year`` form,
            e.g. ``"szmp@ash@2026"``.

    Returns:
        ``(name, year)``, e.g. ``("szmp", "2026")``. The middle segment is not
        returned.

    Raises:
        ValueError: If ``tenant_id`` is not exactly three ``@``-separated
            parts. Such input is rejected explicitly instead of surfacing as an
            IndexError or being sliced into the wrong name/year.
    """
    parts = tenant_id.split("@")
    if len(parts) != 3:
        raise ValueError(
            f"Invalid tenant ID {tenant_id!r}: expected name@ash@year"
        )
    name, _middle, year = parts
    return name, year
|
|
|
|
|
|
class ApiKeyMiddleware(BaseHTTPMiddleware):
    """
    [AC-AISVC-50] Middleware to validate API Key for all requests.

    Features:
    - Validates the X-API-Key header via the API key service, lazily
      initializing that service from the database if it is not ready yet
    - Skips validation for paths in PATHS_SKIP_API_KEY and their sub-paths
    - Returns 401 for a missing or invalid API key and 429 when rate limited
    - On every non-skipped path, sets X-Request-Id on the response (including
      rejections), reusing the incoming header or generating a UUID4
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Authenticate the request, then forward it or return an error response."""
        path = request.url.path
        if self._should_skip_api_key(path):
            return await call_next(request)

        # Reuse a caller-supplied request id so logs can be correlated across services.
        request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid.uuid4())
        request.state.request_id = request_id

        # Absent and whitespace-only headers are both treated as missing.
        api_key = (request.headers.get(API_KEY_HEADER) or "").strip()
        if not api_key:
            logger.warning(
                "[AC-AISVC-50] Missing X-API-Key header for %s, request_id=%s",
                path,
                request_id,
            )
            return self._error_response(
                status.HTTP_401_UNAUTHORIZED,
                request_id,
                code=ErrorCode.UNAUTHORIZED.value,
                message="Missing required header: X-API-Key",
            )

        # Function-scope import kept from the original layout.
        # NOTE(review): presumably avoids an import cycle — confirm.
        from app.services.api_key import get_api_key_service

        service = get_api_key_service()
        if not service._initialized:
            await self._lazy_initialize(service)

        client_ip = request.client.host if request.client else None
        # Raw, unvalidated header value; used only in log lines below.
        tenant_id = request.headers.get(TENANT_ID_HEADER, "")

        validation = service.validate_key_with_context(api_key, client_ip=client_ip)
        if validation.ok:
            response = await call_next(request)
            response.headers[REQUEST_ID_HEADER] = request_id
            return response

        if validation.reason == "rate_limited":
            logger.warning(
                "[AC-AISVC-50] Rate limited: path=%s, tenant=%s, ip=%s, qpm=%s, request_id=%s",
                path,
                tenant_id,
                client_ip,
                validation.rate_limit_qpm,
                request_id,
            )
            return self._error_response(
                status.HTTP_429_TOO_MANY_REQUESTS,
                request_id,
                code=ErrorCode.SERVICE_UNAVAILABLE.value,
                message="Rate limit exceeded",
                details=[{"reason": "rate_limited", "limit_qpm": validation.rate_limit_qpm}],
            )

        logger.warning(
            "[AC-AISVC-50] API key validation failed: reason=%s, path=%s, tenant=%s, ip=%s, request_id=%s",
            validation.reason,
            path,
            tenant_id,
            client_ip,
            request_id,
        )
        return self._error_response(
            status.HTTP_401_UNAUTHORIZED,
            request_id,
            code=ErrorCode.UNAUTHORIZED.value,
            message="Invalid API key",
            details=[{"reason": validation.reason}],
        )

    @staticmethod
    def _error_response(status_code: int, request_id: str, *, code, message: str, details=None) -> JSONResponse:
        """Build a JSON ErrorResponse with the X-Request-Id header attached."""
        fields = {"code": code, "message": message}
        # Only pass details when given, so ErrorResponse's own default applies otherwise.
        if details is not None:
            fields["details"] = details
        response = JSONResponse(
            status_code=status_code,
            content=ErrorResponse(**fields).model_dump(exclude_none=True),
        )
        response.headers[REQUEST_ID_HEADER] = request_id
        return response

    @staticmethod
    async def _lazy_initialize(service) -> None:
        """Best-effort load of the API key service from the database; failures are logged, not raised."""
        logger.warning("[AC-AISVC-50] API key service not initialized, attempting lazy initialization...")
        try:
            from app.core.database import async_session_maker

            async with async_session_maker() as session:
                await service.initialize(session)
                if not service._initialized:
                    logger.error("[AC-AISVC-50] API key service lazy initialization failed")
                elif len(service._keys_cache) > 0:
                    logger.info(
                        "[AC-AISVC-50] API key service lazy initialized with %d keys",
                        len(service._keys_cache),
                    )
                else:
                    logger.warning("[AC-AISVC-50] API key service initialized but no keys found in database")
        except Exception as e:
            # Deliberately swallowed: validation below still runs and will reject the key
            # if the service is unusable. logger.exception keeps the traceback.
            logger.exception("[AC-AISVC-50] Failed to initialize API key service: %s", e)

    def _should_skip_api_key(self, path: str) -> bool:
        """
        Check if the path should skip API key validation.

        A path is skipped when it equals an entry of PATHS_SKIP_API_KEY or lies
        beneath one (``entry + "/"``, e.g. ``/docs/oauth2-redirect``). A bare
        ``startswith`` would also exempt unrelated routes that merely share a
        prefix, such as ``/docs-internal`` or ``/healthz-admin``, from
        authentication.
        """
        return any(
            path == skip_path or path.startswith(skip_path + "/")
            for skip_path in PATHS_SKIP_API_KEY
        )
|
|
|
|
|
|
class TenantContextMiddleware(BaseHTTPMiddleware):
    """
    [AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header.

    Injects tenant context into request state for downstream processing.
    Validates tenant ID format and, for /admin/ and /ai/ paths, auto-creates
    the tenant if it does not exist. Auto-creation failures are logged and
    the request still proceeds.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """Validate X-Tenant-Id, set tenant context around the downstream call, then clear it."""
        # Reset at the start of every request (presumably so a tenant from an
        # earlier request can never be observed here).
        clear_tenant_context()

        path = request.url.path
        if self._should_skip_tenant(path):
            return await call_next(request)

        # Absent and whitespace-only headers are both treated as missing.
        tenant_id = (request.headers.get(TENANT_ID_HEADER) or "").strip()
        if not tenant_id:
            logger.warning("[AC-AISVC-12] Missing or empty X-Tenant-Id header")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content=ErrorResponse(
                    code=ErrorCode.MISSING_TENANT_ID.value,
                    message="Missing required header: X-Tenant-Id",
                ).model_dump(exclude_none=True),
            )

        if not validate_tenant_id_format(tenant_id):
            logger.warning("[AC-AISVC-10] Invalid tenant ID format: %s", tenant_id)
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content=ErrorResponse(
                    code=ErrorCode.INVALID_TENANT_ID.value,
                    message="Invalid tenant ID format. Expected: name@ash@year (e.g., szmp@ash@2026)",
                ).model_dump(exclude_none=True),
            )

        if path.startswith(("/admin/", "/ai/")):
            try:
                await self._ensure_tenant_exists(request, tenant_id)
            except Exception as e:
                # Best effort: a failure to persist the tenant must not block the request.
                logger.exception("[AC-AISVC-10] Failed to ensure tenant exists: %s", e)

        set_tenant_context(tenant_id)
        request.state.tenant_id = tenant_id

        logger.info("[AC-AISVC-10] Tenant context set: tenant_id=%s, path=%s", tenant_id, path)

        try:
            logger.info("[MIDDLEWARE] Calling next handler for path=%s", path)
            response = await call_next(request)
            logger.info("[MIDDLEWARE] Response received for path=%s, status=%s", path, response.status_code)
        except Exception as e:
            # logger.exception appends the traceback; the exception is re-raised unchanged.
            logger.exception("[MIDDLEWARE] Exception in call_next for path=%s: %s: %s", path, type(e).__name__, e)
            raise
        finally:
            clear_tenant_context()

        return response

    def _should_skip_tenant(self, path: str) -> bool:
        """
        Check if the path should skip tenant validation.

        A path is skipped when it equals an entry of PATHS_SKIP_TENANT or lies
        beneath one (``entry + "/"``). A bare ``startswith`` would also exempt
        unrelated routes that merely share a prefix (e.g. ``/healthz-admin``).
        """
        return any(
            path == skip_path or path.startswith(skip_path + "/")
            for skip_path in PATHS_SKIP_TENANT
        )

    async def _ensure_tenant_exists(self, request: Request, tenant_id: str) -> None:
        """
        [AC-AISVC-10] Ensure tenant exists in database, create if not.

        This is check-then-insert, so two concurrent first requests for the same
        tenant can both see no row. If the insert then fails with an
        IntegrityError (assumes a unique constraint on Tenant.tenant_id — TODO
        confirm), the session is rolled back and the row is looked up again; the
        error is re-raised only if the tenant still does not exist.

        Args:
            request: Incoming request (not used by the current implementation).
            tenant_id: Tenant ID already validated as name@ash@year.
        """
        from sqlalchemy import select
        from sqlalchemy.exc import IntegrityError

        from app.core.database import async_session_maker
        from app.models.entities import Tenant

        name, year = parse_tenant_id(tenant_id)
        stmt = select(Tenant).where(Tenant.tenant_id == tenant_id)

        async with async_session_maker() as session:
            existing_tenant = (await session.execute(stmt)).scalar_one_or_none()
            if existing_tenant is not None:
                logger.debug("[AC-AISVC-10] Tenant already exists: %s", tenant_id)
                return

            session.add(Tenant(tenant_id=tenant_id, name=name, year=year))
            try:
                await session.commit()
            except IntegrityError:
                await session.rollback()
                # Another request may have inserted the same tenant in the meantime.
                if (await session.execute(stmt)).scalar_one_or_none() is None:
                    raise
                logger.debug("[AC-AISVC-10] Tenant created concurrently: %s", tenant_id)
                return

        logger.info("[AC-AISVC-10] Auto-created new tenant: %s (name=%s, year=%s)", tenant_id, name, year)
|
|
|
|
|
|
def is_sse_request(request: Request) -> bool:
    """
    [AC-AISVC-06] Check if the request expects SSE streaming response.

    Based on Accept header: text/event-stream indicates SSE mode.

    The header is split into comma-separated media ranges. Each range's type,
    with parameters such as ``q`` or ``charset`` dropped, is compared to
    ``text/event-stream`` case-insensitively, because media types are
    case-insensitive. A plain substring test misses ``Text/Event-Stream`` and
    matches values that merely contain the string. Wildcards such as ``*/*``
    or ``text/*`` are not treated as SSE.

    NOTE: an explicit ``q=0`` on the SSE range is not honoured.
    """
    accept_header = request.headers.get(ACCEPT_HEADER, "")
    wanted = SSE_CONTENT_TYPE.lower()
    return any(
        media_range.split(";", 1)[0].strip().lower() == wanted
        for media_range in accept_header.split(",")
    )
|
|
|
|
|
|
def get_response_mode(request: Request) -> str:
    """
    [AC-AISVC-06] Choose how the response should be delivered.

    Delegates the Accept-header inspection to ``is_sse_request``.

    Returns:
        ``"streaming"`` when the client asked for SSE, otherwise ``"json"``.
    """
    if is_sse_request(request):
        return "streaming"
    return "json"
|