""" Middleware components for the Virtual Board Member AI System. """ import time from typing import Callable from fastapi import Request, Response from starlette.middleware.base import BaseHTTPMiddleware from prometheus_client import Counter, Histogram import structlog from app.core.config import settings logger = structlog.get_logger() # Prometheus metrics REQUEST_COUNT = Counter( "http_requests_total", "Total HTTP requests", ["method", "endpoint", "status"] ) REQUEST_LATENCY = Histogram( "http_request_duration_seconds", "HTTP request latency", ["method", "endpoint"] ) class RequestLoggingMiddleware(BaseHTTPMiddleware): """Middleware for logging HTTP requests.""" async def dispatch(self, request: Request, call_next: Callable) -> Response: start_time = time.time() # Log request logger.info( "HTTP request started", method=request.method, url=str(request.url), client_ip=request.client.host if request.client else None, user_agent=request.headers.get("user-agent"), ) # Process request response = await call_next(request) # Calculate duration duration = time.time() - start_time # Log response logger.info( "HTTP request completed", method=request.method, url=str(request.url), status_code=response.status_code, duration=duration, ) return response class PrometheusMiddleware(BaseHTTPMiddleware): """Middleware for Prometheus metrics.""" async def dispatch(self, request: Request, call_next: Callable) -> Response: start_time = time.time() # Process request response = await call_next(request) # Calculate duration duration = time.time() - start_time # Extract endpoint (remove query parameters and path parameters) endpoint = request.url.path # Record metrics REQUEST_COUNT.labels( method=request.method, endpoint=endpoint, status=response.status_code ).inc() REQUEST_LATENCY.labels( method=request.method, endpoint=endpoint ).observe(duration) return response class SecurityHeadersMiddleware(BaseHTTPMiddleware): """Middleware for adding security headers.""" async def dispatch(self, request: Request, call_next: Callable) -> Response: response = await call_next(request) # Add security headers response.headers["X-Content-Type-Options"] = "nosniff" response.headers["X-Frame-Options"] = "DENY" response.headers["X-XSS-Protection"] = "1; mode=block" response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin" response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()" # Add CSP header in production if settings.is_production: response.headers["Content-Security-Policy"] = ( "default-src 'self'; " "script-src 'self' 'unsafe-inline' 'unsafe-eval'; " "style-src 'self' 'unsafe-inline'; " "img-src 'self' data: https:; " "font-src 'self' data:; " "connect-src 'self' https:; " "frame-ancestors 'none';" ) return response class RateLimitMiddleware(BaseHTTPMiddleware): """Middleware for rate limiting.""" def __init__(self, app, requests_per_minute: int = 100): super().__init__(app) self.requests_per_minute = requests_per_minute self.request_counts = {} async def dispatch(self, request: Request, call_next: Callable) -> Response: client_ip = request.client.host if request.client else "unknown" current_time = time.time() # Clean old entries self._clean_old_entries(current_time) # Check rate limit if not self._check_rate_limit(client_ip, current_time): logger.warning( "Rate limit exceeded", client_ip=client_ip, requests_per_minute=self.requests_per_minute ) return Response( content="Rate limit exceeded", status_code=429, headers={"Retry-After": "60"} ) # Process request response = await call_next(request) # Record request self._record_request(client_ip, current_time) return response def _clean_old_entries(self, current_time: float) -> None: """Remove entries older than 1 minute.""" cutoff_time = current_time - 60 for client_ip in list(self.request_counts.keys()): self.request_counts[client_ip] = [ timestamp for timestamp in self.request_counts[client_ip] if timestamp > cutoff_time ] if not self.request_counts[client_ip]: del self.request_counts[client_ip] def _check_rate_limit(self, client_ip: str, current_time: float) -> bool: """Check if client has exceeded rate limit.""" if client_ip not in self.request_counts: return True requests_in_window = len([ timestamp for timestamp in self.request_counts[client_ip] if current_time - timestamp < 60 ]) return requests_in_window < self.requests_per_minute def _record_request(self, client_ip: str, current_time: float) -> None: """Record a request for the client.""" if client_ip not in self.request_counts: self.request_counts[client_ip] = [] self.request_counts[client_ip].append(current_time) class CORSMiddleware(BaseHTTPMiddleware): """Custom CORS middleware.""" async def dispatch(self, request: Request, call_next: Callable) -> Response: response = await call_next(request) # Add CORS headers origin = request.headers.get("origin") if origin and origin in settings.ALLOWED_HOSTS: response.headers["Access-Control-Allow-Origin"] = origin else: response.headers["Access-Control-Allow-Origin"] = "*" response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS" response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization" response.headers["Access-Control-Allow-Credentials"] = "true" return response