- Implement multi-format document support (PDF, XLSX, CSV, PPTX, TXT, images)
- Add S3-compatible storage service with tenant isolation
- Create document organization service with hierarchical folders and tagging
- Implement advanced document processing with table/chart extraction
- Add batch upload capabilities (up to 50 files)
- Create comprehensive document validation and security scanning
- Implement automatic metadata extraction and categorization
- Add document version control system
- Update DEVELOPMENT_PLAN.md to mark Week 2 as completed
- Add WEEK2_COMPLETION_SUMMARY.md with detailed implementation notes
- All tests passing (6/6) — 100% success rate
188 lines
6.4 KiB
Python
188 lines
6.4 KiB
Python
"""
|
|
Tenant middleware for automatic tenant context handling.
|
|
"""
|
|
import contextlib
import functools
import inspect
import ipaddress
import logging
from typing import Optional

import jwt
from fastapi import HTTPException, Request, status
from fastapi.responses import JSONResponse

from app.core.config import settings
from app.core.database import get_db
from app.models.tenant import Tenant
|
|
|
|
# Module-level logger; reports rejected tenants and tenant-validation errors.
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
class TenantMiddleware:
    """ASGI middleware that resolves and validates the tenant of each HTTP request.

    For HTTP requests whose path is not skipped, a tenant ID is looked up in
    this order: the ``tenant_id`` claim of a Bearer JWT, the ``X-Tenant-ID``
    header, the ``tenant_id`` query parameter, and finally the subdomain (only
    when ``settings.ENABLE_SUBDOMAIN_TENANTS`` is truthy).

    If a tenant ID is found, it must match an active ``Tenant`` row. Otherwise
    a 403 JSON response is sent and the wrapped app is not called. A valid ID
    is stored under ``scope["state"]["tenant_id"]``, the dict that backs
    Starlette's ``request.state`` (read by ``get_current_tenant``).

    Requests without any tenant context, skipped paths, and non-HTTP scopes
    (e.g. websocket, lifespan) are forwarded to the wrapped app unchanged.
    """

    # Paths that bypass tenant resolution entirely (matched by prefix).
    # NOTE(review): prefix matching also covers e.g. "/healthz" or
    # "/docs-internal"; switch to segment-boundary matching if that is unwanted.
    _SKIP_PATH_PREFIXES = (
        "/health",
        "/docs",
        "/openapi.json",
        "/auth/login",
        "/auth/register",
        "/auth/refresh",
        "/admin/tenants",  # Allow tenant management endpoints
        "/metrics",
        "/favicon.ico",
    )

    # First host labels that are never interpreted as tenant identifiers.
    _RESERVED_SUBDOMAINS = frozenset({"www", "api", "admin", "app"})

    def __init__(self, app):
        # The downstream ASGI application every request is delegated to.
        self.app = app

    async def __call__(self, scope, receive, send):
        """ASGI entry point: resolve/validate the tenant, then call ``self.app``."""
        # Only HTTP requests carry tenant context; everything else passes through.
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        request = Request(scope, receive)

        # Skip tenant processing for certain endpoints
        if self._should_skip_tenant(request.url.path):
            await self.app(scope, receive, send)
            return

        tenant_id = await self._extract_tenant_context(request)
        if tenant_id:
            # Validate tenant exists and is active before exposing it to handlers.
            if not await self._validate_tenant(tenant_id):
                response = JSONResponse(
                    status_code=status.HTTP_403_FORBIDDEN,
                    content={"detail": "Invalid or inactive tenant"},
                )
                await response(scope, receive, send)
                return

            # ``scope`` is a plain dict: merge into any state an outer layer or
            # the ASGI server already placed there instead of replacing it.
            scope.setdefault("state", {})["tenant_id"] = tenant_id

        await self.app(scope, receive, send)

    def _should_skip_tenant(self, path: str) -> bool:
        """Return True if *path* starts with any entry of ``_SKIP_PATH_PREFIXES``."""
        return path.startswith(self._SKIP_PATH_PREFIXES)

    async def _extract_tenant_context(self, request: Request) -> Optional[str]:
        """Return the first non-empty tenant ID found on *request*, else ``None``.

        Sources in priority order: JWT claim, ``X-Tenant-ID`` header,
        ``tenant_id`` query parameter, subdomain. Later sources are only
        consulted (and awaited) when earlier ones yield a falsy value.
        """
        # NOTE(review): the header and query-parameter sources are not
        # authenticated, so a client without a JWT can select any active
        # tenant unless handlers cross-check it against the caller's identity.
        return (
            await self._extract_from_token(request)
            or request.headers.get("X-Tenant-ID")
            or request.query_params.get("tenant_id")
            or await self._extract_from_subdomain(request)
            or None
        )

    async def _extract_from_token(self, request: Request) -> Optional[str]:
        """Return the ``tenant_id`` claim of a Bearer JWT, or ``None``.

        The token is decoded with ``settings.SECRET_KEY`` and
        ``settings.ALGORITHM``. A missing/non-Bearer header, or a token for
        which ``jwt.decode`` raises ``InvalidTokenError``, yields ``None`` so
        resolution can fall back to the other sources.
        """
        auth_header = request.headers.get("Authorization")
        if not auth_header or not auth_header.startswith("Bearer "):
            return None

        try:
            token = auth_header.split(" ")[1]
            payload = jwt.decode(
                token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM]
            )
            return payload.get("tenant_id")
        except (jwt.InvalidTokenError, IndexError, KeyError):
            return None

    async def _extract_from_subdomain(self, request: Request) -> Optional[str]:
        """Return the tenant ID encoded in the subdomain, or ``None``.

        ``tenant1.example.com`` -> ``"tenant1"``. Only active when
        ``settings.ENABLE_SUBDOMAIN_TENANTS`` is truthy. The port is ignored
        and the host is lower-cased (host names are case-insensitive). Bare IP
        addresses, hosts with fewer than three labels, and reserved labels
        (``www``, ``api``, ``admin``, ``app``) never produce a tenant.
        """
        # A Settings object without this flag disables the feature instead of
        # raising AttributeError on every request that reaches this fallback.
        if not getattr(settings, "ENABLE_SUBDOMAIN_TENANTS", False):
            return None

        # Drop any ":port" suffix. IPv6 literals ("[::1]:8000") reduce to "["
        # and are rejected by the label-count check below.
        hostname = request.headers.get("host", "").split(":", 1)[0].strip().lower()

        # "10.0.0.5" has four dot-separated parts but no subdomain.
        try:
            ipaddress.ip_address(hostname)
        except ValueError:
            pass
        else:
            return None

        parts = hostname.split(".")
        if len(parts) < 3:
            return None

        subdomain = parts[0]
        if not subdomain or subdomain in self._RESERVED_SUBDOMAINS:
            return None
        return subdomain

    async def _validate_tenant(self, tenant_id: str) -> bool:
        """Return True if a ``Tenant`` with this id and status "active" exists.

        Any exception while obtaining the session, querying, or releasing the
        session is logged and treated as "invalid" (the caller answers 403).
        """
        # NOTE(review): this is a synchronous DB query on the event-loop
        # thread; consider offloading it (e.g. run_in_threadpool) if it shows
        # up as request latency.
        try:
            # Closing the get_db() generator deterministically runs whatever
            # cleanup follows its ``yield`` (presumably closing the session),
            # and only after the query has finished.
            with contextlib.closing(get_db()) as db_gen:
                db = next(db_gen)
                tenant = (
                    db.query(Tenant)
                    .filter(Tenant.id == tenant_id, Tenant.status == "active")
                    .first()
                )
        except Exception:
            # %r escapes control characters in client-supplied IDs.
            logger.exception("Error validating tenant %r", tenant_id)
            return False

        if not tenant:
            logger.warning("Invalid or inactive tenant: %r", tenant_id)
            return False

        return True
|
|
|
|
def get_current_tenant(request: Request) -> Optional[str]:
    """Return the tenant ID stored on ``request.state``, or ``None`` if absent."""
    state = request.state
    try:
        return state.tenant_id
    except AttributeError:
        # Nothing stored a tenant ID on this request's state.
        return None
|
|
|
|
def require_tenant():
    """Build a decorator that only runs the wrapped coroutine when a tenant is set.

    The ``Request`` is taken from the ``request`` keyword argument or, failing
    that, from the first positional argument that is a ``Request``. The
    decorated coroutine function is awaited only if ``get_current_tenant``
    returns a truthy value for that request.

    If the request arrived as the ``request`` keyword, it is passed on to the
    decorated function only when that function declares a ``request``
    parameter. Functions that do not declare one are called without it.

    ``functools.wraps`` gives the wrapper *func*'s name and docstring, and
    (via ``__wrapped__``) its signature. FastAPI derives parameter injection
    from that signature, so a decorated endpoint should declare
    ``request: Request`` to receive the request.

    Raises (from the returned wrapper):
        HTTPException: 400 "Request object not found" when no request is
            found, or 400 "Tenant context required" when it has no tenant ID.
    """

    def decorator(func):
        # Decide once, at decoration time, whether ``func`` can accept
        # ``request=...``.
        try:
            parameters = inspect.signature(func).parameters
        except (TypeError, ValueError):  # signature not introspectable
            parameters = {}
        request_param = parameters.get("request")
        forward_request = request_param is not None and request_param.kind in (
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
            inspect.Parameter.KEYWORD_ONLY,
        )

        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            request = kwargs.pop("request", None)
            request_from_kwargs = bool(request)

            if not request:
                # Try to find request in args
                request = next(
                    (arg for arg in args if isinstance(arg, Request)), None
                )

            if not request:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Request object not found",
                )

            tenant_id = get_current_tenant(request)
            if not tenant_id:
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail="Tenant context required",
                )

            # Give the request back if it was consumed from kwargs and ``func``
            # declares it; otherwise ``func`` would miss a required argument.
            if request_from_kwargs and forward_request:
                kwargs["request"] = request

            return await func(*args, **kwargs)

        return wrapper

    return decorator
|
|
|
|
def _query_model(query):
    """Return the mapped class *query* selects from, or ``None`` if unknown.

    An explicit ``query.model`` attribute is used when present. Otherwise the
    first non-``None`` ``entity`` from the public ``column_descriptions`` is
    used (available on SQLAlchemy ``Query`` and 2.0-style ``Select``).
    """
    model = getattr(query, "model", None)
    if model is not None:
        return model
    for description in getattr(query, "column_descriptions", None) or ():
        entity = description.get("entity")
        if entity is not None:
            return entity
    return None


def tenant_aware_query(query, tenant_id: str):
    """Restrict *query* to rows of *tenant_id* when its model has a ``tenant_id``.

    Models without a ``tenant_id`` attribute are treated as global and the
    query is returned unchanged. Only the query's primary model is filtered;
    additional joined entities are not.

    Raises:
        AttributeError: if the query's model cannot be determined. An
            unscoped query is never returned silently in that case.
    """
    model = _query_model(query)
    if model is None:
        raise AttributeError(
            "tenant_aware_query: cannot determine the model of the given query"
        )
    if hasattr(model, "tenant_id"):
        return query.filter(model.tenant_id == tenant_id)
    return query
|
|
|
|
def tenant_aware_create(data: dict, tenant_id: str):
    """Stamp *tenant_id* onto a create payload that has no ``tenant_id`` key yet.

    The dict is modified in place and also returned. An existing ``tenant_id``
    entry, whatever its value (including ``None``), is left untouched.
    """
    # NOTE(review): a caller-supplied tenant_id takes precedence over the
    # request's tenant; if *data* comes from client input, that allows
    # writing into another tenant.
    if "tenant_id" in data:
        return data
    data["tenant_id"] = tenant_id
    return data
|