ai-robot-core/ai-service/app/core/middleware.py

292 lines
11 KiB
Python

"""
Middleware for AI Service.
[AC-AISVC-10, AC-AISVC-12, AC-AISVC-50] X-Tenant-Id header validation, tenant context injection, and API Key authentication.
"""
import logging
import re
import uuid
from collections.abc import Callable
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from app.core.exceptions import ErrorCode, ErrorResponse
from app.core.tenant import clear_tenant_context, set_tenant_context
logger = logging.getLogger(__name__)

# HTTP header names read or written by the middlewares in this module.
TENANT_ID_HEADER = "X-Tenant-Id"
API_KEY_HEADER = "X-API-Key"
ACCEPT_HEADER = "Accept"
SSE_CONTENT_TYPE = "text/event-stream"
REQUEST_ID_HEADER = "X-Request-Id"

# Prompt template protected variable names injected by system/runtime.
# These are reserved for internal orchestration and should not be overridden by user input.
PROMPT_PROTECTED_VARIABLES = {
    "available_tools",
    "query",
    "history",
    "internal_protocol",
    "output_contract",
}

# Tenant IDs have the form "<name>@ash@<year>", e.g. "szmp@ash@2026".
# The year uses an explicit ASCII [0-9] class on purpose: in Python 3 a plain
# ``\d`` also matches non-ASCII Unicode digits (e.g. Arabic-Indic "٢٠٢٦"),
# which would otherwise be accepted and persisted as the tenant year.
TENANT_ID_PATTERN = re.compile(r'^[^@]+@ash@[0-9]{4}$')

# Paths exempt from X-API-Key validation (health and API-doc endpoints, plus
# the share-chat endpoint). How a request path is compared against these
# entries is decided by ApiKeyMiddleware._should_skip_api_key.
PATHS_SKIP_API_KEY = {
    "/health",
    "/ai/health",
    "/docs",
    "/redoc",
    "/openapi.json",
    "/favicon.ico",
    "/openapi/v1/share/chat",
}

# Paths exempt from X-Tenant-Id validation. How a request path is compared
# against these entries is decided by TenantContextMiddleware._should_skip_tenant.
PATHS_SKIP_TENANT = {
    "/health",
    "/ai/health",
    "/favicon.ico",
    "/openapi/v1/share/chat",
}
def validate_tenant_id_format(tenant_id: str) -> bool:
    """
    [AC-AISVC-10] Validate tenant ID format: name@ash@year

    Examples: szmp@ash@2026, abc123@ash@2025

    ``fullmatch`` is used so the entire string must conform. With ``match``,
    the pattern's trailing ``$`` can also match just before a final newline,
    which would let an ID with a trailing newline pass validation.

    Args:
        tenant_id: Raw tenant identifier to check.

    Returns:
        True if ``tenant_id`` matches ``TENANT_ID_PATTERN`` in its entirety.
    """
    return TENANT_ID_PATTERN.fullmatch(tenant_id) is not None
def parse_tenant_id(tenant_id: str) -> tuple[str, str]:
    """
    [AC-AISVC-10] Split a tenant ID of the form name@ash@year into its parts.

    The ID is split on "@" and the first and third segments are returned.
    Callers are expected to have validated the format first; an ID with
    fewer than three "@"-separated segments raises IndexError.

    Returns: (name, year)
    """
    segments = tenant_id.split("@")
    name = segments[0]
    year = segments[2]
    return name, year
class ApiKeyMiddleware(BaseHTTPMiddleware):
    """
    [AC-AISVC-50] Middleware to validate API Key for all requests.

    Features:
    - Validates the X-API-Key header via the API key service
      (``validate_key_with_context``), passing the client IP along
    - Lazily initializes the API key service from a fresh DB session if it
      has not been initialized yet
    - Skips validation for paths in PATHS_SKIP_API_KEY (exact match or a
      "/"-separated sub-path of an entry)
    - Returns 401 for a missing or invalid API key, 429 when rate limited
    - Sets X-Request-Id on every response produced after the skip check
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        path = request.url.path
        if self._should_skip_api_key(path):
            return await call_next(request)

        # Reuse the caller's request id when supplied so logs can be
        # correlated across hops; otherwise mint a fresh one.
        request_id = request.headers.get(REQUEST_ID_HEADER) or str(uuid.uuid4())
        request.state.request_id = request_id

        api_key = (request.headers.get(API_KEY_HEADER) or "").strip()
        if not api_key:
            logger.warning(
                f"[AC-AISVC-50] Missing X-API-Key header for {path}, request_id={request_id}"
            )
            return self._error_response(
                status.HTTP_401_UNAUTHORIZED,
                ErrorCode.UNAUTHORIZED.value,
                "Missing required header: X-API-Key",
                request_id,
            )

        # Deferred import, like the other call-time imports in this module
        # (presumably to avoid an import cycle at start-up — confirm before hoisting).
        from app.services.api_key import get_api_key_service

        service = get_api_key_service()
        if not service._initialized:
            await self._lazy_initialize_service(service)

        client_ip = request.client.host if request.client else None
        tenant_id = request.headers.get(TENANT_ID_HEADER, "")
        validation = service.validate_key_with_context(api_key, client_ip=client_ip)

        if validation.ok:
            response = await call_next(request)
            response.headers[REQUEST_ID_HEADER] = request_id
            return response

        if validation.reason == "rate_limited":
            logger.warning(
                f"[AC-AISVC-50] Rate limited: path={path}, tenant={tenant_id}, "
                f"ip={client_ip}, qpm={validation.rate_limit_qpm}, request_id={request_id}"
            )
            # NOTE(review): HTTP 429 is paired with the SERVICE_UNAVAILABLE error
            # code; if ErrorCode gains a dedicated rate-limit code, prefer it.
            return self._error_response(
                status.HTTP_429_TOO_MANY_REQUESTS,
                ErrorCode.SERVICE_UNAVAILABLE.value,
                "Rate limit exceeded",
                request_id,
                details=[{"reason": "rate_limited", "limit_qpm": validation.rate_limit_qpm}],
            )

        logger.warning(
            f"[AC-AISVC-50] API key validation failed: reason={validation.reason}, "
            f"path={path}, tenant={tenant_id}, ip={client_ip}, request_id={request_id}"
        )
        return self._error_response(
            status.HTTP_401_UNAUTHORIZED,
            ErrorCode.UNAUTHORIZED.value,
            "Invalid API key",
            request_id,
            details=[{"reason": validation.reason}],
        )

    @staticmethod
    def _error_response(status_code: int, code, message: str, request_id: str, details=None) -> JSONResponse:
        """Build a JSON ErrorResponse with the X-Request-Id header attached.

        ``details`` is only forwarded to ErrorResponse when given, so the
        model's own default applies otherwise (as in the original inline code).
        """
        payload = {"code": code, "message": message}
        if details is not None:
            payload["details"] = details
        response = JSONResponse(
            status_code=status_code,
            content=ErrorResponse(**payload).model_dump(exclude_none=True),
        )
        response.headers[REQUEST_ID_HEADER] = request_id
        return response

    @staticmethod
    async def _lazy_initialize_service(service) -> None:
        """Attempt to initialize the API key service from a new DB session.

        Failures are logged (with traceback) rather than raised, so the
        request continues to key validation either way.
        """
        logger.warning("[AC-AISVC-50] API key service not initialized, attempting lazy initialization...")
        try:
            from app.core.database import async_session_maker

            async with async_session_maker() as session:
                await service.initialize(session)

            if not service._initialized:
                logger.error("[AC-AISVC-50] API key service lazy initialization failed")
            else:
                key_count = len(service._keys_cache)
                if key_count > 0:
                    logger.info(f"[AC-AISVC-50] API key service lazy initialized with {key_count} keys")
                else:
                    logger.warning("[AC-AISVC-50] API key service initialized but no keys found in database")
        except Exception as e:
            logger.exception(f"[AC-AISVC-50] Failed to initialize API key service: {e}")

    def _should_skip_api_key(self, path: str) -> bool:
        """Check if the path should skip API key validation.

        A path is skipped when it equals an entry of PATHS_SKIP_API_KEY or is a
        sub-path of one (e.g. "/docs/oauth2-redirect" under "/docs"). Matching
        on a "/" segment boundary stops unrelated routes that merely share a
        textual prefix (e.g. "/healthz-admin" vs "/health") from bypassing
        authentication.
        """
        return any(
            path == skip_path or path.startswith(f"{skip_path}/")
            for skip_path in PATHS_SKIP_API_KEY
        )
class TenantContextMiddleware(BaseHTTPMiddleware):
    """
    [AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header.

    Injects tenant context into request state for downstream processing.
    Validates the tenant ID format and, for paths under /admin/ or /ai/,
    auto-creates the tenant row if it does not exist. Auto-creation is best
    effort: failures are logged and the request still proceeds.
    Paths in PATHS_SKIP_TENANT (exact match or a "/"-separated sub-path of an
    entry) bypass all of this.
    """

    # Path prefixes for which the tenant row is auto-created when missing.
    _AUTO_CREATE_PREFIXES = ("/admin/", "/ai/")

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # Reset any tenant context left over in the current context before
        # deciding anything about this request.
        clear_tenant_context()
        path = request.url.path
        if self._should_skip_tenant(path):
            return await call_next(request)

        tenant_id = (request.headers.get(TENANT_ID_HEADER) or "").strip()
        if not tenant_id:
            logger.warning("[AC-AISVC-12] Missing or empty X-Tenant-Id header")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content=ErrorResponse(
                    code=ErrorCode.MISSING_TENANT_ID.value,
                    message="Missing required header: X-Tenant-Id",
                ).model_dump(exclude_none=True),
            )

        if not validate_tenant_id_format(tenant_id):
            logger.warning(f"[AC-AISVC-10] Invalid tenant ID format: {tenant_id}")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content=ErrorResponse(
                    code=ErrorCode.INVALID_TENANT_ID.value,
                    message="Invalid tenant ID format. Expected: name@ash@year (e.g., szmp@ash@2026)",
                ).model_dump(exclude_none=True),
            )

        if path.startswith(self._AUTO_CREATE_PREFIXES):
            try:
                await self._ensure_tenant_exists(request, tenant_id)
            except Exception as e:
                # Best effort: a failure to persist the tenant must not block the request.
                logger.exception(f"[AC-AISVC-10] Failed to ensure tenant exists: {e}")

        set_tenant_context(tenant_id)
        request.state.tenant_id = tenant_id
        logger.info(f"[AC-AISVC-10] Tenant context set: tenant_id={tenant_id}, path={path}")
        try:
            logger.info(f"[MIDDLEWARE] Calling next handler for path={path}")
            response = await call_next(request)
            logger.info(f"[MIDDLEWARE] Response received for path={path}, status={response.status_code}")
        except Exception as e:
            logger.exception(f"[MIDDLEWARE] Exception in call_next for path={path}: {type(e).__name__}: {e}")
            raise
        finally:
            # Always drop the tenant context once the downstream handler is done.
            clear_tenant_context()
        return response

    def _should_skip_tenant(self, path: str) -> bool:
        """Check if the path should skip tenant validation.

        A path is skipped when it equals an entry of PATHS_SKIP_TENANT or is a
        sub-path of one. Matching on a "/" segment boundary stops unrelated
        routes that merely share a textual prefix (e.g. "/healthz-admin" vs
        "/health") from bypassing tenant validation.
        """
        return any(
            path == skip_path or path.startswith(f"{skip_path}/")
            for skip_path in PATHS_SKIP_TENANT
        )

    async def _ensure_tenant_exists(self, request: Request, tenant_id: str) -> None:
        """
        [AC-AISVC-10] Ensure tenant exists in database, create if not.

        Args:
            request: Incoming request (not used by the current implementation).
            tenant_id: Tenant ID already validated as name@ash@year.
        """
        from sqlalchemy import select
        from sqlalchemy.exc import IntegrityError

        from app.core.database import async_session_maker
        from app.models.entities import Tenant

        name, year = parse_tenant_id(tenant_id)
        async with async_session_maker() as session:
            stmt = select(Tenant).where(Tenant.tenant_id == tenant_id)
            result = await session.execute(stmt)
            if result.scalar_one_or_none() is not None:
                logger.debug(f"[AC-AISVC-10] Tenant already exists: {tenant_id}")
                return

            session.add(Tenant(tenant_id=tenant_id, name=name, year=year))
            try:
                await session.commit()
            except IntegrityError as e:
                # Concurrent first requests for the same new tenant can both pass
                # the SELECT above and race on this INSERT. Presumably tenant_id
                # carries a unique constraint (TODO confirm), in which case the
                # loser lands here while the tenant row already exists.
                await session.rollback()
                logger.warning(
                    f"[AC-AISVC-10] Tenant insert rejected (likely created concurrently): {tenant_id}: {e}"
                )
                return
            logger.info(f"[AC-AISVC-10] Auto-created new tenant: {tenant_id} (name={name}, year={year})")
def is_sse_request(request: Request) -> bool:
    """
    [AC-AISVC-06] Check if the request expects SSE streaming response.

    The Accept header is split into comma-separated media ranges. For each
    range, parameters such as ";charset=" or ";q=" are dropped. The request is
    SSE when any remaining type/subtype equals ``text/event-stream``, compared
    case-insensitively because media type names are case-insensitive.
    A ``q=0`` weight is not honoured; such a range still counts as SSE.
    """
    accept_header = request.headers.get(ACCEPT_HEADER, "")
    wanted = SSE_CONTENT_TYPE.lower()
    return any(
        media_range.split(";", 1)[0].strip().lower() == wanted
        for media_range in accept_header.split(",")
    )
def get_response_mode(request: Request) -> str:
    """
    [AC-AISVC-06] Pick the response mode for a request from its Accept header.

    Returns:
        "streaming" when ``is_sse_request`` reports an SSE request,
        otherwise "json".
    """
    if is_sse_request(request):
        return "streaming"
    return "json"