Initial commit: Virtual Board Member AI System foundation
This commit is contained in:
204
app/core/middleware.py
Normal file
204
app/core/middleware.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""
|
||||
Middleware components for the Virtual Board Member AI System.
|
||||
"""
|
||||
|
||||
import time
|
||||
from typing import Callable
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from prometheus_client import Counter, Histogram
|
||||
import structlog
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
logger = structlog.get_logger()
|
||||
|
||||
# Prometheus metrics
|
||||
REQUEST_COUNT = Counter(
|
||||
"http_requests_total",
|
||||
"Total HTTP requests",
|
||||
["method", "endpoint", "status"]
|
||||
)
|
||||
REQUEST_LATENCY = Histogram(
|
||||
"http_request_duration_seconds",
|
||||
"HTTP request latency",
|
||||
["method", "endpoint"]
|
||||
)
|
||||
|
||||
|
||||
class RequestLoggingMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware for logging HTTP requests."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
start_time = time.time()
|
||||
|
||||
# Log request
|
||||
logger.info(
|
||||
"HTTP request started",
|
||||
method=request.method,
|
||||
url=str(request.url),
|
||||
client_ip=request.client.host if request.client else None,
|
||||
user_agent=request.headers.get("user-agent"),
|
||||
)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate duration
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Log response
|
||||
logger.info(
|
||||
"HTTP request completed",
|
||||
method=request.method,
|
||||
url=str(request.url),
|
||||
status_code=response.status_code,
|
||||
duration=duration,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class PrometheusMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware for Prometheus metrics."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
start_time = time.time()
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Calculate duration
|
||||
duration = time.time() - start_time
|
||||
|
||||
# Extract endpoint (remove query parameters and path parameters)
|
||||
endpoint = request.url.path
|
||||
|
||||
# Record metrics
|
||||
REQUEST_COUNT.labels(
|
||||
method=request.method,
|
||||
endpoint=endpoint,
|
||||
status=response.status_code
|
||||
).inc()
|
||||
|
||||
REQUEST_LATENCY.labels(
|
||||
method=request.method,
|
||||
endpoint=endpoint
|
||||
).observe(duration)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class SecurityHeadersMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware for adding security headers."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
response = await call_next(request)
|
||||
|
||||
# Add security headers
|
||||
response.headers["X-Content-Type-Options"] = "nosniff"
|
||||
response.headers["X-Frame-Options"] = "DENY"
|
||||
response.headers["X-XSS-Protection"] = "1; mode=block"
|
||||
response.headers["Referrer-Policy"] = "strict-origin-when-cross-origin"
|
||||
response.headers["Permissions-Policy"] = "geolocation=(), microphone=(), camera=()"
|
||||
|
||||
# Add CSP header in production
|
||||
if settings.is_production:
|
||||
response.headers["Content-Security-Policy"] = (
|
||||
"default-src 'self'; "
|
||||
"script-src 'self' 'unsafe-inline' 'unsafe-eval'; "
|
||||
"style-src 'self' 'unsafe-inline'; "
|
||||
"img-src 'self' data: https:; "
|
||||
"font-src 'self' data:; "
|
||||
"connect-src 'self' https:; "
|
||||
"frame-ancestors 'none';"
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class RateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Middleware for rate limiting."""
|
||||
|
||||
def __init__(self, app, requests_per_minute: int = 100):
|
||||
super().__init__(app)
|
||||
self.requests_per_minute = requests_per_minute
|
||||
self.request_counts = {}
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
client_ip = request.client.host if request.client else "unknown"
|
||||
current_time = time.time()
|
||||
|
||||
# Clean old entries
|
||||
self._clean_old_entries(current_time)
|
||||
|
||||
# Check rate limit
|
||||
if not self._check_rate_limit(client_ip, current_time):
|
||||
logger.warning(
|
||||
"Rate limit exceeded",
|
||||
client_ip=client_ip,
|
||||
requests_per_minute=self.requests_per_minute
|
||||
)
|
||||
return Response(
|
||||
content="Rate limit exceeded",
|
||||
status_code=429,
|
||||
headers={"Retry-After": "60"}
|
||||
)
|
||||
|
||||
# Process request
|
||||
response = await call_next(request)
|
||||
|
||||
# Record request
|
||||
self._record_request(client_ip, current_time)
|
||||
|
||||
return response
|
||||
|
||||
def _clean_old_entries(self, current_time: float) -> None:
|
||||
"""Remove entries older than 1 minute."""
|
||||
cutoff_time = current_time - 60
|
||||
for client_ip in list(self.request_counts.keys()):
|
||||
self.request_counts[client_ip] = [
|
||||
timestamp for timestamp in self.request_counts[client_ip]
|
||||
if timestamp > cutoff_time
|
||||
]
|
||||
if not self.request_counts[client_ip]:
|
||||
del self.request_counts[client_ip]
|
||||
|
||||
def _check_rate_limit(self, client_ip: str, current_time: float) -> bool:
|
||||
"""Check if client has exceeded rate limit."""
|
||||
if client_ip not in self.request_counts:
|
||||
return True
|
||||
|
||||
requests_in_window = len([
|
||||
timestamp for timestamp in self.request_counts[client_ip]
|
||||
if current_time - timestamp < 60
|
||||
])
|
||||
|
||||
return requests_in_window < self.requests_per_minute
|
||||
|
||||
def _record_request(self, client_ip: str, current_time: float) -> None:
|
||||
"""Record a request for the client."""
|
||||
if client_ip not in self.request_counts:
|
||||
self.request_counts[client_ip] = []
|
||||
|
||||
self.request_counts[client_ip].append(current_time)
|
||||
|
||||
|
||||
class CORSMiddleware(BaseHTTPMiddleware):
|
||||
"""Custom CORS middleware."""
|
||||
|
||||
async def dispatch(self, request: Request, call_next: Callable) -> Response:
|
||||
response = await call_next(request)
|
||||
|
||||
# Add CORS headers
|
||||
origin = request.headers.get("origin")
|
||||
if origin and origin in settings.ALLOWED_HOSTS:
|
||||
response.headers["Access-Control-Allow-Origin"] = origin
|
||||
else:
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
|
||||
response.headers["Access-Control-Allow-Methods"] = "GET, POST, PUT, DELETE, OPTIONS"
|
||||
response.headers["Access-Control-Allow-Headers"] = "Content-Type, Authorization"
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
|
||||
return response
|
||||
Reference in New Issue
Block a user