userbeacon/jwt_auth_backend.py (82 lines of code) (raw):

import logging import jwt from django.contrib.auth.models import User from django.conf import settings from django.contrib.auth import PermissionDenied, authenticate from cryptography.x509 import load_pem_x509_certificate from cryptography.hazmat.backends import default_backend from rest_framework.authentication import BaseAuthentication from rest_framework.exceptions import AuthenticationFailed import traceback logger = logging.getLogger(__name__) class JwtAuth(object): @staticmethod def load_local_public_key(): try: with open(settings.JWT_CERTIFICATE_PATH, "r") as certfile: cert_raw = certfile.read().encode("ASCII") cert = load_pem_x509_certificate(cert_raw, default_backend()) return cert.public_key() except Exception as e: logger.error('Could not read certificate: ' + str(e)) raise @staticmethod def load_remote_public_key(token): jwks_url = settings.JWT_CERTIFICATE_PATH jwks_client = jwt.PyJWKClient(jwks_url) return jwks_client.get_signing_key_from_jwt(token).key @staticmethod def _extract_username(claims): username = claims.get("username") #adfs(deprecated) uses this if username is None: username = claims.get("preferred_username") #keycloak and Azure AD use this if username is None: logger.warning("Could not get username from claims set, expect problems") return username def authenticate(self, request, **credentials): token = credentials.get("token", None) if token: logger.debug("JwtAuth got token {0}".format(token)) if not settings.JWT_CERTIFICATE_PATH.startswith("http"): public_key = self.load_local_public_key() else: public_key = self.load_remote_public_key(token) try: decoded = jwt.decode(token, key=public_key, algorithms=["RS256"], audience=getattr(settings, "JWT_EXPECTED_AUDIENCE", None), issuer=getattr(settings, "JWT_EXPECTED_ISSUER", None)) logger.debug("JwtAuth success") return User( username=self._extract_username(decoded), first_name=decoded.get("first_name"), last_name=decoded.get("family_name"), email=decoded.get("email"), is_staff=True, is_active=True, is_superuser=True #until we have groups added in to the JWT claim ) except jwt.exceptions.DecodeError as e: logger.error("Could not decode provided JWT: {0}".format(e)) raise PermissionDenied() except jwt.exceptions.ExpiredSignatureError: logger.error("Token signature has expired") except jwt.exceptions.InvalidAudienceError: logger.error("Token was for another audience: {0}".format()) except Exception as e: logger.error("Unexpected error decoding JWT: {0}".format(traceback.format_exc(e))) raise PermissionDenied() def get_user(self, token): return self.authenticate(None, token=token) class JwtRestAuth(BaseAuthentication): """ this class is a REST-framework compatible authentication class which calls out to our authentication backend via django.authenticate """ def authenticate(self, request): auth_header = request.META.get("HTTP_AUTHORIZATION", None) if isinstance(auth_header, str) and auth_header.startswith("Bearer "): try: user_model = authenticate(request, token=auth_header[7:]) return user_model, "jwt" except PermissionDenied: raise AuthenticationFailed else: return None #authentication not attempted