Files
virtual_board_member/app/core/middleware.py
2025-08-07 16:11:14 -04:00

205 lines
6.6 KiB
Python

"""
Middleware components for the Virtual Board Member AI System.
"""
import time
from typing import Callable
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from prometheus_client import Counter, Histogram
import structlog
from app.core.config import settings
logger = structlog.get_logger()

# Prometheus metrics, created once when this module is imported and updated
# per request by PrometheusMiddleware (label values are supplied there).
REQUEST_COUNT = Counter(
    name="http_requests_total",
    documentation="Total HTTP requests",
    labelnames=("method", "endpoint", "status"),
)
REQUEST_LATENCY = Histogram(
    name="http_request_duration_seconds",
    documentation="HTTP request latency",
    labelnames=("method", "endpoint"),
)
class RequestLoggingMiddleware(BaseHTTPMiddleware):
    """Middleware for logging HTTP requests.

    Emits an "HTTP request started" event before the request is handled.
    It then emits either an "HTTP request completed" event (status code and
    duration in seconds), or, if the downstream application raises, an
    "HTTP request failed" event with the traceback. In the failure case the
    exception is re-raised unchanged.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # perf_counter is monotonic, so durations are not distorted by
        # wall-clock adjustments the way time.time() differences can be.
        start_time = time.perf_counter()
        method = request.method
        # NOTE(review): the full URL includes the query string; confirm no
        # secrets (tokens, API keys) are ever passed as query parameters.
        url = str(request.url)

        logger.info(
            "HTTP request started",
            method=method,
            url=url,
            client_ip=request.client.host if request.client else None,
            user_agent=request.headers.get("user-agent"),
        )

        try:
            response = await call_next(request)
        except Exception:
            # Without this, a failing request leaves only a "started" event in
            # the logs. Record the failure, then let it propagate as before.
            logger.exception(
                "HTTP request failed",
                method=method,
                url=url,
                duration=time.perf_counter() - start_time,
            )
            raise

        logger.info(
            "HTTP request completed",
            method=method,
            url=url,
            status_code=response.status_code,
            duration=time.perf_counter() - start_time,
        )
        return response
class PrometheusMiddleware(BaseHTTPMiddleware):
    """Middleware for Prometheus metrics.

    For every request it increments REQUEST_COUNT (labels: method, endpoint,
    status) and observes REQUEST_LATENCY (labels: method, endpoint).
    Requests whose handler raises are recorded with status 500, and the
    exception is then re-raised.
    """

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        # perf_counter is monotonic: latency is unaffected by wall-clock jumps.
        start_time = time.perf_counter()
        try:
            response = await call_next(request)
        except Exception:
            self._record(request, 500, time.perf_counter() - start_time)
            raise
        self._record(request, response.status_code, time.perf_counter() - start_time)
        return response

    @staticmethod
    def _endpoint_label(request: Request) -> str:
        """Return a low-cardinality endpoint label for *request*.

        Prefers the matched route's path template (e.g. ``/users/{user_id}``),
        so path parameters do not create one time series per distinct value.
        Falls back to the raw URL path (which never contains the query
        string) when the scope holds no route with a string ``path``.
        """
        # NOTE(review): relies on the router storing the matched route in
        # scope["route"] (FastAPI's APIRoute does this). The scope dict is
        # shared with call_next, so it is visible here after the call.
        # Confirm for the FastAPI version in use. Unmatched paths (404s) still
        # use the raw path and can grow label cardinality; consider a fixed
        # placeholder label for them.
        route = request.scope.get("route")
        template = getattr(route, "path", None)
        return template if isinstance(template, str) else request.url.path

    @classmethod
    def _record(cls, request: Request, status_code: int, duration: float) -> None:
        """Update both metrics for one finished request."""
        endpoint = cls._endpoint_label(request)
        REQUEST_COUNT.labels(
            method=request.method,
            endpoint=endpoint,
            status=status_code,
        ).inc()
        REQUEST_LATENCY.labels(
            method=request.method,
            endpoint=endpoint,
        ).observe(duration)
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
    """Attach a fixed set of browser security headers to every response.

    A Content-Security-Policy header is also set, but only when
    ``settings.is_production`` is true at the time of the request.
    Existing values for any of these headers are overwritten.
    """

    # (header, value) pairs written onto every response, in this order.
    # NOTE(review): "X-XSS-Protection: 1; mode=block" is deprecated. OWASP and
    # MDN recommend "0" (or omitting it) because the legacy XSS filter can
    # itself introduce vulnerabilities.
    _STATIC_HEADERS = (
        ("X-Content-Type-Options", "nosniff"),
        ("X-Frame-Options", "DENY"),
        ("X-XSS-Protection", "1; mode=block"),
        ("Referrer-Policy", "strict-origin-when-cross-origin"),
        ("Permissions-Policy", "geolocation=(), microphone=(), camera=()"),
    )

    # Production-only policy.
    # NOTE(review): 'unsafe-inline' and 'unsafe-eval' in script-src remove
    # most of CSP's XSS protection. Tighten if the frontend allows it.
    _PRODUCTION_CSP = (
        "default-src 'self'; "
        "script-src 'self' 'unsafe-inline' 'unsafe-eval'; "
        "style-src 'self' 'unsafe-inline'; "
        "img-src 'self' data: https:; "
        "font-src 'self' data:; "
        "connect-src 'self' https:; "
        "frame-ancestors 'none';"
    )

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        response = await call_next(request)
        outgoing = response.headers
        for header_name, header_value in self._STATIC_HEADERS:
            outgoing[header_name] = header_value
        if settings.is_production:
            outgoing["Content-Security-Policy"] = self._PRODUCTION_CSP
        return response
class RateLimitMiddleware(BaseHTTPMiddleware):
    """In-memory, per-client sliding-window rate limiter.

    Each client is keyed by ``request.client.host``, or ``"unknown"`` when the
    connection carries no client address. A client may make at most
    ``requests_per_minute`` requests in any trailing 60-second window.
    Requests over the limit receive a plain-text 429 response with
    ``Retry-After: 60`` and are not forwarded to the application.

    Counters live in ``self.request_counts`` on this middleware instance, so
    each process enforces its own independent limit.

    NOTE(review): behind a reverse proxy ``request.client.host`` is probably
    the proxy's address, which would put all users in one bucket. Confirm
    how the server is deployed (e.g. uvicorn ``--proxy-headers``).
    """

    # Length of the sliding window, in seconds.
    _WINDOW_SECONDS = 60

    def __init__(self, app, requests_per_minute: int = 100):
        super().__init__(app)
        self.requests_per_minute = requests_per_minute
        # Client IP -> time.monotonic() timestamps of admitted requests.
        self.request_counts = {}

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        client_ip = request.client.host if request.client else "unknown"
        # A monotonic clock keeps the window correct even if the wall clock is
        # stepped (NTP correction, manual change) while the server is running.
        current_time = time.monotonic()

        # Clean old entries
        self._clean_old_entries(current_time)

        if not self._check_rate_limit(client_ip, current_time):
            logger.warning(
                "Rate limit exceeded",
                client_ip=client_ip,
                requests_per_minute=self.requests_per_minute,
            )
            return Response(
                content="Rate limit exceeded",
                status_code=429,
                headers={"Retry-After": "60"},
            )

        # Count the request *before* awaiting the application. The check and
        # the record run back-to-back with no ``await`` between them, so
        # concurrent in-flight requests from one client are all counted and
        # cannot collectively exceed the limit. Requests whose handler raises
        # are counted too.
        self._record_request(client_ip, current_time)
        return await call_next(request)

    def _clean_old_entries(self, current_time: float) -> None:
        """Drop timestamps older than the window and forget idle clients."""
        cutoff_time = current_time - self._WINDOW_SECONDS
        for client_ip in list(self.request_counts):
            recent = [ts for ts in self.request_counts[client_ip] if ts > cutoff_time]
            if recent:
                self.request_counts[client_ip] = recent
            else:
                del self.request_counts[client_ip]

    def _check_rate_limit(self, client_ip: str, current_time: float) -> bool:
        """Return True if *client_ip* is still under its limit at *current_time*.

        A client with no history counts as having zero requests in the
        window, so the same comparison applies to everyone. With
        ``requests_per_minute <= 0`` every request is rejected.
        """
        timestamps = self.request_counts.get(client_ip, [])
        requests_in_window = sum(
            1 for ts in timestamps if current_time - ts < self._WINDOW_SECONDS
        )
        return requests_in_window < self.requests_per_minute

    def _record_request(self, client_ip: str, current_time: float) -> None:
        """Append *current_time* to *client_ip*'s request history."""
        self.request_counts.setdefault(client_ip, []).append(current_time)
class CORSMiddleware(BaseHTTPMiddleware):
    """Custom CORS middleware driven by ``settings.ALLOWED_HOSTS``.

    * If the ``Origin`` is listed in ``settings.ALLOWED_HOSTS``, it is echoed
      back in ``Access-Control-Allow-Origin`` together with
      ``Access-Control-Allow-Credentials: true``.
    * Any other request gets ``Access-Control-Allow-Origin: *`` and no
      credentials header. The CORS spec forbids allowing credentials with a
      wildcard origin, and browsers reject that combination.
    * CORS preflights (``OPTIONS`` carrying ``Origin`` and
      ``Access-Control-Request-Method``) are answered here with an empty
      200 response. If routed to the application, they would likely get a
      non-2xx status (e.g. 405) unless a route handles OPTIONS, and a
      non-2xx preflight fails in the browser.
    * ``Vary: Origin`` is added because the allow-origin header depends on
      the request's ``Origin``. Without it, shared caches could serve one
      origin's response to another.

    NOTE(review): assumes ALLOWED_HOSTS contains full origins such as
    "https://app.example.com". Bare host names never equal an Origin header
    value. Confirm against settings.
    """

    _ALLOW_METHODS = "GET, POST, PUT, DELETE, OPTIONS"
    _ALLOW_HEADERS = "Content-Type, Authorization"

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        origin = request.headers.get("origin")
        is_preflight = (
            request.method == "OPTIONS"
            and origin is not None
            and "access-control-request-method" in request.headers
        )
        if is_preflight:
            response = Response(status_code=200)
        else:
            response = await call_next(request)
        self._apply_cors_headers(response, origin)
        return response

    @classmethod
    def _apply_cors_headers(cls, response: Response, origin: "str | None") -> None:
        """Set the CORS response headers appropriate for *origin*."""
        headers = response.headers
        if origin and origin in settings.ALLOWED_HOSTS:
            headers["Access-Control-Allow-Origin"] = origin
            # Credentials may only be allowed for an explicit origin.
            headers["Access-Control-Allow-Credentials"] = "true"
        else:
            # NOTE(review): this wildcard lets any site read non-credentialed
            # responses, which largely defeats the ALLOWED_HOSTS allow-list.
            # It is kept so existing clients keep working; consider sending no
            # CORS headers for unknown origins instead.
            headers["Access-Control-Allow-Origin"] = "*"
            if "access-control-allow-credentials" in headers:
                del headers["access-control-allow-credentials"]
        headers["Access-Control-Allow-Methods"] = cls._ALLOW_METHODS
        headers["Access-Control-Allow-Headers"] = cls._ALLOW_HEADERS

        # Merge "Origin" into any Vary header the application already set.
        vary = headers.get("vary")
        if not vary:
            headers["Vary"] = "Origin"
        elif "origin" not in (token.strip().lower() for token in vary.split(",")):
            headers["Vary"] = f"{vary}, Origin"