privaterelay/middleware.py (137 lines of code) (raw):

import binascii
import os
import re
import time
from collections.abc import Callable
from datetime import UTC, datetime

from django.conf import settings
from django.http import HttpRequest, HttpResponse
from django.shortcuts import redirect

import markus
from csp.middleware import CSPMiddleware
from whitenoise.middleware import WhiteNoiseMiddleware

from privaterelay.utils import glean_logger

metrics = markus.get_metrics()

# To find all the URL paths that serve HTML which need the CSP nonce:
# python manage.py collectstatic
# find staticfiles -type f -name 'index.html'
CSP_NONCE_COOKIE_PATHS = [
    "/",
    "/contains-tracker-warning/",
    "/flags/",
    "/faq/",
    "/vpn-relay/waitlist/",
    "/accounts/settings/",
    "/accounts/profile/",
    "/accounts/account_inactive/",
    "/vpn-relay-welcome/",
    "/phone/waitlist/",
    "/phone/",
    "/404/",
    "/tracker-report/",
    "/premium/waitlist/",
    "/premium/",
]


class EagerNonceCSPMiddleware(CSPMiddleware):
    # We need a nonce to use Google Tag Manager with a safe CSP:
    # https://developers.google.com/tag-platform/security/guides/csp
    # django-csp only includes the nonce value in the CSP header if the csp_nonce
    # attribute is accessed:
    # https://django-csp.readthedocs.io/en/latest/nonce.html
    # That works for urls served by Django views that access the attribute but it
    # doesn't work for urls that are served by views which don't access the attribute.
    # (e.g., Whitenoise)
    # So, to ensure django-csp includes the nonce value in the CSP header of every
    # response, we override the default CSPMiddleware with this middleware. If the
    # request is for one of the HTML urls, this middleware sets the request.csp_nonce
    # attribute and adds a cookie for the React app to get the nonce value for scripts.

    def process_request(self, request):
        """Generate a nonce eagerly for requests to the known HTML paths.

        The nonce is 16 random bytes from ``os.urandom``, hex-encoded, and is
        stored on ``request._csp_nonce``. Requests for any other path are left
        untouched.
        """
        if request.path in CSP_NONCE_COOKIE_PATHS:
            request_nonce = binascii.hexlify(os.urandom(16)).decode("ascii")
            request._csp_nonce = request_nonce

    def process_response(self, request, response):
        """Run the parent CSP processing, then expose the nonce in a cookie.

        For the known HTML paths the nonce stored by ``process_request`` is
        copied into a ``csp_nonce`` cookie (Secure, SameSite=Strict).
        """
        response = super().process_response(request, response)
        if request.path in CSP_NONCE_COOKIE_PATHS:
            response.set_cookie(
                "csp_nonce", request._csp_nonce, secure=True, samesite="Strict"
            )
        return response


class RedirectRootIfLoggedIn:
    """Redirect requests for "/" that carry a session cookie to the profile page."""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        """Return a redirect for logged-in visitors to "/", else the normal response.

        Only the presence of the session cookie is checked here; the session
        itself is not validated. Any query string is carried over to the
        redirect target.
        """
        # To prevent showing a flash of the landing page when a user is logged
        # in, use a server-side redirect to send them to the dashboard,
        # rather than handling that on the client-side:
        if request.path == "/" and settings.SESSION_COOKIE_NAME in request.COOKIES:
            # PEP 3333 allows a WSGI server to omit QUERY_STRING entirely, so
            # read it with a default instead of indexing (which would raise
            # KeyError and turn the landing page into a server error).
            raw_query = request.META.get("QUERY_STRING", "")
            query_string = "?" + raw_query if raw_query else ""
            return redirect("accounts/profile/" + query_string)

        return self.get_response(request)


class AddDetectedCountryToRequestAndResponseHeaders:
    """Copy the client's region onto the request and response objects.

    The region is read from ``X-Client-Region``, looked up first in the request
    headers and then in the query parameters; a query parameter takes
    precedence when both are present. The value is attached as a ``country``
    attribute on both the request and the response.
    """

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        """Attach ``country`` to the request and response when a region is supplied."""
        region_key = "X-Client-Region"
        region_dict = None
        if region_key in request.headers:
            region_dict = request.headers
        # Checked second on purpose: a query parameter overrides the header.
        if region_key in request.GET:
            region_dict = request.GET
        if not region_dict:
            # No region information was supplied; pass the request through.
            return self.get_response(request)

        country = region_dict.get(region_key)

        request.country = country
        response = self.get_response(request)
        response.country = country
        return response


class ResponseMetrics:
    """Emit a ``response`` timing metric for each request when statsd is enabled."""

    # Matches the dockerflow endpoints, which are reported under their own view name.
    re_dockerflow = re.compile(r"/__(version|heartbeat|lbheartbeat)__/?$")

    def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]) -> None:
        self.get_response = get_response
        # Only used to ask whether a path is a static file (see is_staticfile).
        self.middleware = RelayStaticFilesMiddleware()

    def __call__(self, request: HttpRequest) -> HttpResponse:
        """Handle the request and record how long it took, in milliseconds.

        The metric is tagged with the response status code, the view name and
        the request method. Nothing is measured when ``settings.STATSD_ENABLED``
        is false.
        """
        if not settings.STATSD_ENABLED:
            return self.get_response(request)

        # Use a monotonic, high-resolution clock for the elapsed time: the wall
        # clock (time.time) can jump when the system clock is adjusted, which
        # would skew the measurement or even make it negative.
        start_time = time.perf_counter()
        response = self.get_response(request)
        delta = time.perf_counter() - start_time

        view_name = self._get_metric_view_name(request)
        metrics.timing(
            "response",
            value=delta * 1000.0,
            tags=[
                f"status:{response.status_code}",
                f"view:{view_name}",
                f"method:{request.method}",
            ],
        )
        return response

    def _get_metric_view_name(self, request: HttpRequest) -> str:
        """Return the name used for the ``view`` tag of the timing metric.

        In order of preference: the dotted path of the resolved view, a
        ``dockerflow.django.views.*`` name for the dockerflow endpoints,
        ``<static_file>`` for paths served by the static files middleware, and
        ``<unknown_view>`` for everything else.
        """
        if request.resolver_match:
            view = request.resolver_match.func
            if hasattr(view, "view_class"):
                # Wrapped with rest_framework.decorators.api_view
                return f"{view.__module__}.{view.view_class.__name__}"
            return f"{view.__module__}.{view.__name__}"
        if match := self.re_dockerflow.match(request.path_info):
            return f"dockerflow.django.views.{match[1]}"
        if self.middleware.is_staticfile(request.path_info):
            return "<static_file>"
        return "<unknown_view>"


class StoreFirstVisit:
    """Set a ``first_visit`` cookie the first time a logged-in user is seen."""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        """Add the ``first_visit`` cookie when it is missing and the user is logged in."""
        response = self.get_response(request)
        first_visit = request.COOKIES.get("first_visit")
        if first_visit is None and not request.user.is_anonymous:
            # The cookie value is the current time in UTC.
            response.set_cookie("first_visit", datetime.now(UTC))
        return response


class RelayStaticFilesMiddleware(WhiteNoiseMiddleware):
    """Customize WhiteNoiseMiddleware for Relay.

    The WhiteNoiseMiddleware serves static files and sets headers. In
    production, the files are read from staticfiles/staticfiles.json,
    and files with hashes in the name are treated as immutable with
    10-year cache timeouts.

    This class also treats Next.js output files (already hashed) as immutable.
    """

    def immutable_file_test(self, path, url):
        """
        Determine whether given URL represents an immutable file (i.e. a
        file with a hash of its contents as part of its name) which can
        therefore be cached forever.

        All files outputted by next.js are hashed and immutable
        """
        if not url.startswith(self.static_prefix):
            return False
        name = url[len(self.static_prefix) :]
        if name.startswith("_next/static/"):
            return True
        # Anything else falls back to WhiteNoise's own test.
        return super().immutable_file_test(path, url)

    def is_staticfile(self, path_info: str) -> bool:
        """
        Returns True if this file is served by the middleware.

        This uses the logic from whitenoise.middleware.WhiteNoiseMiddleware.__call__:
        https://github.com/evansd/whitenoise/blob/220a98894495d407424e80d85d49227a5cf97e1b/src/whitenoise/middleware.py#L117-L124
        """
        if self.autorefresh:
            static_file = self.find_file(path_info)
        else:
            static_file = self.files.get(path_info)
        return static_file is not None


class GleanApiAccessMiddleware:
    """Log a Glean "API accessed" event for every request under ``/api/``."""

    def __init__(self, get_response):
        self.get_response = get_response

    def __call__(self, request):
        """Log the API access (when applicable) before handling the request."""
        if request.path.startswith("/api/"):
            glean_logger().log_api_accessed(request)
        return self.get_response(request)