75 lines
2.4 KiB
Python
75 lines
2.4 KiB
Python
"""
|
|
Middleware for AI Service.
|
|
[AC-AISVC-10, AC-AISVC-12] X-Tenant-Id header validation and tenant context injection.
|
|
"""
|
|
|
|
import logging
|
|
from typing import Callable
|
|
|
|
from fastapi import Request, Response, status
|
|
from fastapi.responses import JSONResponse
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
from app.core.exceptions import ErrorCode, ErrorResponse, MissingTenantIdException
|
|
from app.core.tenant import clear_tenant_context, set_tenant_context
|
|
|
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Request header that carries the tenant identifier.
TENANT_ID_HEADER = "X-Tenant-Id"

# Header and media type used to decide between SSE streaming and JSON responses.
ACCEPT_HEADER = "Accept"
SSE_CONTENT_TYPE = "text/event-stream"
|
|
|
|
|
|
class TenantContextMiddleware(BaseHTTPMiddleware):
    """
    [AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header.

    Injects tenant context into request state for downstream processing.

    Per-request flow:
    - The tenant context is reset with ``clear_tenant_context()`` on entry.
    - Requests to ``/ai/health`` are passed straight to the app without any
      tenant validation.
    - A missing or whitespace-only ``X-Tenant-Id`` header produces a 400 JSON
      error (code ``ErrorCode.MISSING_TENANT_ID``); the app is not called.
    - Otherwise the stripped tenant id is stored with ``set_tenant_context``
      and on ``request.state.tenant_id``. The context is cleared again after
      the downstream app returns or raises.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Validate the tenant header and run the downstream app with tenant context.

        Args:
            request: The incoming request.
            call_next: Starlette callable that invokes the downstream app.

        Returns:
            The downstream response, or a 400 ``JSONResponse`` when the
            ``X-Tenant-Id`` header is missing or blank.
        """
        clear_tenant_context()

        if request.url.path == "/ai/health":
            return await call_next(request)

        # Normalize once: surrounding whitespace is not part of the tenant id,
        # and a missing header is treated the same as an empty one.
        tenant_id = (request.headers.get(TENANT_ID_HEADER) or "").strip()

        if not tenant_id:
            logger.warning("[AC-AISVC-12] Missing or empty X-Tenant-Id header")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content=ErrorResponse(
                    code=ErrorCode.MISSING_TENANT_ID.value,
                    message="Missing required header: X-Tenant-Id",
                ).model_dump(exclude_none=True),
            )

        try:
            # Set the context inside the try so the finally-clear also runs
            # if anything fails between setting it and calling the app.
            set_tenant_context(tenant_id)
            request.state.tenant_id = tenant_id

            # Lazy %-style args: the message is only formatted if INFO is enabled.
            logger.info("[AC-AISVC-10] Tenant context set: tenant_id=%s", tenant_id)

            return await call_next(request)
        finally:
            clear_tenant_context()
|
|
|
|
|
|
def _media_range_quality(params: str) -> float:
    """
    Return the ``q`` weight declared in one Accept media range's parameters.

    ``params`` is the text after the first ``;`` of a single Accept entry,
    e.g. ``" charset=utf-8; q=0.5"``. Returns 1.0 when no ``q`` parameter is
    present or when its value is not a valid number. This way a malformed
    weight never disables a media type the client explicitly listed.
    """
    for param in params.split(";"):
        name, _, value = param.partition("=")
        if name.strip().lower() == "q":
            try:
                return float(value.strip())
            except ValueError:
                return 1.0
    return 1.0


def is_sse_request(request: Request) -> bool:
    """
    [AC-AISVC-06] Check if the request expects SSE streaming response.

    Based on Accept header: text/event-stream indicates SSE mode.

    The header is split into comma-separated media ranges (RFC 9110,
    section 12.5.1). A range selects SSE mode when both of these hold:
    - its media type equals ``SSE_CONTENT_TYPE`` case-insensitively (media
      types are case-insensitive);
    - it is not explicitly refused with ``q=0``.

    Wildcards such as ``*/*`` or ``text/*`` do not select SSE mode.
    """
    accept_header = request.headers.get(ACCEPT_HEADER, "")
    wanted = SSE_CONTENT_TYPE.lower()
    for media_range in accept_header.split(","):
        media_type, _, params = media_range.partition(";")
        if media_type.strip().lower() == wanted and _media_range_quality(params) > 0:
            return True
    return False
|
|
|
|
|
|
def get_response_mode(request: Request) -> str:
    """
    [AC-AISVC-06] Determine response mode based on Accept header.

    Delegates the header inspection to ``is_sse_request``. Returns
    'streaming' when that reports an SSE request, and 'json' otherwise.
    """
    if is_sse_request(request):
        return "streaming"
    return "json"
|