"""
Tenant middleware for automatic tenant context handling.
"""

import functools
import ipaddress
import logging
from typing import Optional

from fastapi import Request, HTTPException, status
from fastapi.responses import JSONResponse
import jwt

from app.core.config import settings
from app.models.tenant import Tenant
from app.core.database import get_db

logger = logging.getLogger(__name__)


class TenantMiddleware:
    """ASGI middleware that resolves and validates the tenant of each HTTP request.

    For ``http`` scopes whose path is not in ``SKIP_PATHS``, the tenant ID is
    resolved (see ``_extract_tenant_context``) and stored in
    ``scope["state"]["tenant_id"]``, where ``get_current_tenant`` reads it back
    through ``request.state``. If the resolved tenant is not an existing, active
    tenant, a 403 JSON response is sent instead of calling the wrapped app.
    Requests with no tenant information, and non-HTTP scopes, pass through
    unchanged.
    """

    # Paths that bypass tenant processing. An entry matches the path exactly or
    # as a leading path segment (``/docs`` covers ``/docs/oauth2-redirect`` but
    # not ``/docsfoo``).
    SKIP_PATHS = (
        "/health",
        "/docs",
        "/openapi.json",
        "/auth/login",
        "/auth/register",
        "/auth/refresh",
        "/admin/tenants",  # Allow tenant management endpoints
        "/metrics",
        "/favicon.ico",
    )

    # Leftmost host labels that are never treated as tenant identifiers.
    RESERVED_SUBDOMAINS = frozenset({"www", "api", "admin", "app"})

    def __init__(self, app):
        self.app = app

    async def __call__(self, scope, receive, send):
        """ASGI entry point; only ``http`` scopes receive tenant processing."""
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope, receive)

        if self._should_skip_tenant(request.url.path):
            await self.app(scope, receive, send)
            return

        tenant_id = await self._extract_tenant_context(request)
        if tenant_id:
            # ``scope`` is a plain dict: getattr() would never find "state" and
            # would replace any state dict already placed there by outer
            # layers, so reuse it instead.
            scope.setdefault("state", {})["tenant_id"] = tenant_id

            if not await self._validate_tenant(tenant_id):
                response = JSONResponse(
                    status_code=status.HTTP_403_FORBIDDEN,
                    content={"detail": "Invalid or inactive tenant"},
                )
                await response(scope, receive, send)
                return

        await self.app(scope, receive, send)

    def _should_skip_tenant(self, path: str) -> bool:
        """Return True if tenant processing should be skipped for ``path``.

        Matching is segment-aware so that unrelated routes which merely share a
        textual prefix (e.g. ``/healthz-admin``) do not bypass tenant handling.
        """
        return any(
            path == prefix or path.startswith(prefix.rstrip("/") + "/")
            for prefix in self.SKIP_PATHS
        )

    async def _extract_tenant_context(self, request: Request) -> Optional[str]:
        """Resolve the tenant ID for ``request``, or return None if none is given.

        Sources are tried in priority order and the first non-empty value wins:
        1. ``tenant_id`` claim of a valid ``Authorization: Bearer`` JWT,
        2. ``X-Tenant-ID`` header,
        3. ``tenant_id`` query parameter,
        4. leftmost subdomain of the Host header (if enabled).
        """
        # NOTE(review): sources 2-4 are not authenticated, so a client without a
        # token can select any active tenant unless a later layer checks that
        # the caller belongs to it. Confirm this is intended.
        tenant_id = await self._extract_from_token(request)
        if tenant_id:
            return tenant_id

        tenant_id = request.headers.get("X-Tenant-ID")
        if tenant_id:
            return tenant_id

        tenant_id = request.query_params.get("tenant_id")
        if tenant_id:
            return tenant_id

        tenant_id = await self._extract_from_subdomain(request)
        if tenant_id:
            return tenant_id

        return None

    async def _extract_from_token(self, request: Request) -> Optional[str]:
        """Return the ``tenant_id`` claim of the request's Bearer JWT, if any.

        Returns None when there is no Bearer token, the token fails
        verification with ``settings.SECRET_KEY`` / ``settings.ALGORITHM``, or
        the payload has no ``tenant_id`` claim.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith("Bearer "):
            return None

        _, _, token = auth_header.partition(" ")
        try:
            payload = jwt.decode(
                token.strip(), settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
            )
        except jwt.InvalidTokenError:
            return None
        return payload.get("tenant_id")

    async def _extract_from_subdomain(self, request: Request) -> Optional[str]:
        """Return the tenant ID encoded in the Host subdomain, if enabled.

        ``tenant1.example.com`` (optionally with ``:port``) yields ``tenant1``.
        Returns None when subdomain routing is disabled or not configured, the
        host has fewer than three labels, the host is an IP address, or the
        leftmost label is one of ``RESERVED_SUBDOMAINS``.
        """
        # A missing setting means "not configured" rather than a crash.
        if not getattr(settings, "ENABLE_SUBDOMAIN_TENANTS", False):
            return None

        host = request.headers.get("host", "")
        # Strip any ":port" suffix; DNS names are case-insensitive.
        hostname = host.split(":", 1)[0].strip().lower()

        # Numeric hosts such as 127.0.0.1 have dots but no subdomain.
        try:
            ipaddress.ip_address(hostname)
        except ValueError:
            pass
        else:
            return None

        labels = hostname.split(".")
        if len(labels) < 3:
            return None

        subdomain = labels[0]
        if not subdomain or subdomain in self.RESERVED_SUBDOMAINS:
            return None
        return subdomain

    async def _validate_tenant(self, tenant_id: str) -> bool:
        """Return True iff a ``Tenant`` with this id exists with status "active".

        Any error while querying is logged and treated as invalid (False).
        """
        # NOTE(review): this is a synchronous DB query executed on the event
        # loop; consider offloading it if it shows up as request latency.
        db_gen = None
        try:
            db_gen = get_db()
            db = next(db_gen)
            tenant = (
                db.query(Tenant)
                .filter(Tenant.id == tenant_id, Tenant.status == "active")
                .first()
            )
        except Exception as e:
            logger.exception("Error validating tenant %s: %s", tenant_id, e)
            return False
        finally:
            # next() alone never resumes the generator, so get_db()'s own
            # cleanup would not run; closing the generator triggers it.
            close = getattr(db_gen, "close", None)
            if close is not None:
                try:
                    close()
                except Exception:
                    logger.exception("Error releasing DB session for tenant check")

        if not tenant:
            logger.warning("Invalid or inactive tenant: %s", tenant_id)
            return False
        return True


def get_current_tenant(request: Request) -> Optional[str]:
    """Return the tenant ID stored on ``request.state``, or None if unset."""
    return getattr(request.state, "tenant_id", None)


def require_tenant():
    """Return a decorator that rejects calls made without a tenant context.

    Use as ``@require_tenant()`` on an ``async`` function. The wrapped function
    must receive the current ``Request`` as a positional or keyword argument
    (a keyword named ``request`` is checked first). All arguments are forwarded
    unchanged.

    Raises:
        HTTPException(400): no ``Request`` argument was found, or the request
            has no tenant ID (see ``get_current_tenant``).
    """

    def decorator(func):
        # functools.wraps exposes the endpoint's real signature (via
        # __wrapped__), which FastAPI uses to decide what to inject.
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            request = kwargs.get("request")
            if not isinstance(request, Request):
                request = next(
                    (
                        arg
                        for arg in (*args, *kwargs.values())
                        if isinstance(arg, Request)
                    ),
                    None,
                )

            if request is None:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Request object not found",
                )

            if not get_current_tenant(request):
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Tenant context required",
                )

            # kwargs still contains ``request`` (if it was passed that way), so
            # the wrapped function receives every argument it was called with.
            return await func(*args, **kwargs)

        return wrapper

    return decorator


def tenant_aware_query(query, tenant_id: str):
    """Filter ``query`` to ``tenant_id`` when its model has a ``tenant_id`` column.

    The model is taken from ``query.model`` when the query object provides it,
    otherwise from the first entity in ``query.column_descriptions`` (the
    standard SQLAlchemy ``Query`` API). Queries whose model has no
    ``tenant_id`` attribute are returned unchanged.

    Raises:
        AttributeError: the model cannot be determined. This fails closed
            rather than silently returning an unfiltered query.
    """
    try:
        model = query.model
    except AttributeError:
        descriptions = getattr(query, "column_descriptions", None) or []
        model = descriptions[0].get("entity") if descriptions else None
        if model is None:
            raise AttributeError(
                "tenant_aware_query: cannot determine the queried model"
            ) from None

    if hasattr(model, "tenant_id"):
        return query.filter(model.tenant_id == tenant_id)
    return query


def tenant_aware_create(data: dict, tenant_id: str):
    """Set ``data["tenant_id"]`` to ``tenant_id`` unless already present.

    ``data`` is modified in place and also returned.
    """
    # NOTE(review): a tenant_id already in ``data`` wins; if ``data`` can come
    # from client input, that allows writes into another tenant. Confirm.
    data.setdefault("tenant_id", tenant_id)
    return data