ai-robot-core/ai-service/app/core/middleware.py

75 lines
2.4 KiB
Python

"""
Middleware for AI Service.
[AC-AISVC-10, AC-AISVC-12] X-Tenant-Id header validation and tenant context injection.
"""
import logging
from typing import Callable
from fastapi import Request, Response, status
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
from app.core.exceptions import ErrorCode, ErrorResponse, MissingTenantIdException
from app.core.tenant import clear_tenant_context, set_tenant_context
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Name of the request header that carries the tenant identifier.
TENANT_ID_HEADER = "X-Tenant-Id"
# Name of the request header inspected to decide the response mode.
ACCEPT_HEADER = "Accept"
# Media type whose presence in the Accept header marks a request as SSE.
SSE_CONTENT_TYPE = "text/event-stream"
class TenantContextMiddleware(BaseHTTPMiddleware):
    """
    [AC-AISVC-10, AC-AISVC-12] Middleware to extract and validate X-Tenant-Id header.

    Requests to ``/ai/health`` bypass validation. Every other request must
    carry a non-blank ``X-Tenant-Id`` header; otherwise a 400 JSON error is
    returned without calling the downstream handler. For valid requests the
    whitespace-stripped tenant id is passed to ``set_tenant_context`` and
    stored on ``request.state.tenant_id`` for downstream processing.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        """
        Validate the tenant header, then run the rest of the request pipeline.

        Args:
            request: Incoming HTTP request.
            call_next: Callable that forwards the request downstream and
                returns its response.

        Returns:
            The downstream response, or a 400 ``JSONResponse`` built from
            ``ErrorResponse`` when the tenant header is missing or blank.
        """
        # Reset any tenant context before handling this request.
        clear_tenant_context()

        # The health endpoint is exempt from tenant validation.
        if request.url.path == "/ai/health":
            return await call_next(request)

        # Normalise once: a missing header and a whitespace-only header are
        # both treated as absent.
        tenant_id = (request.headers.get(TENANT_ID_HEADER) or "").strip()
        if not tenant_id:
            logger.warning("[AC-AISVC-12] Missing or empty X-Tenant-Id header")
            return JSONResponse(
                status_code=status.HTTP_400_BAD_REQUEST,
                content=ErrorResponse(
                    code=ErrorCode.MISSING_TENANT_ID.value,
                    message="Missing required header: X-Tenant-Id",
                ).model_dump(exclude_none=True),
            )

        set_tenant_context(tenant_id)
        request.state.tenant_id = tenant_id
        # Lazy %-style arguments: the message is only formatted when the
        # INFO level is actually enabled for this logger.
        logger.info("[AC-AISVC-10] Tenant context set: tenant_id=%s", tenant_id)

        try:
            return await call_next(request)
        finally:
            # Always clear, even when the downstream handler raises.
            clear_tenant_context()
def is_sse_request(request: Request) -> bool:
    """
    [AC-AISVC-06] Check if the request expects SSE streaming response.

    The Accept header is split into its comma-separated media ranges; any
    parameters (e.g. ``;q=0.9``) are dropped and the remaining media type is
    compared case-insensitively against ``text/event-stream``.

    Args:
        request: Incoming HTTP request whose headers are inspected.

    Returns:
        True when one of the media ranges in Accept is ``text/event-stream``,
        False otherwise (including when the header is absent).
    """
    accept_header = request.headers.get(ACCEPT_HEADER, "")
    sse_type = SSE_CONTENT_TYPE.lower()
    # Media types are case-insensitive, and a plain substring test would also
    # match unrelated types that merely contain the SSE type, so compare each
    # media range exactly. Quality values (q=...) are not interpreted.
    return any(
        media_range.split(";", 1)[0].strip().lower() == sse_type
        for media_range in accept_header.split(",")
    )
def get_response_mode(request: Request) -> str:
    """
    [AC-AISVC-06] Map a request onto the response mode it should receive.

    Delegates the decision to ``is_sse_request``.

    Args:
        request: Incoming HTTP request.

    Returns:
        ``"streaming"`` when ``is_sse_request(request)`` is true,
        ``"json"`` otherwise.
    """
    if is_sse_request(request):
        return "streaming"
    return "json"