userbeacon/jwt_auth_backend.py (82 lines of code) (raw):
import logging
import jwt
from django.contrib.auth.models import User
from django.conf import settings
from django.contrib.auth import PermissionDenied, authenticate
from cryptography.x509 import load_pem_x509_certificate
from cryptography.hazmat.backends import default_backend
from rest_framework.authentication import BaseAuthentication
from rest_framework.exceptions import AuthenticationFailed
import traceback
# Module-level logger named after this module; used by the classes below.
logger = logging.getLogger(__name__)
class JwtAuth(object):
    """
    Authentication backend that validates a JWT passed in as the `token` credential.

    The verification key is taken from settings.JWT_CERTIFICATE_PATH. If that value
    starts with "http" it is treated as a JWKS URL, otherwise as the path of a
    PEM-encoded X.509 certificate on the local filesystem.
    """

    @staticmethod
    def load_local_public_key():
        """
        Read the PEM certificate at settings.JWT_CERTIFICATE_PATH and return its public key.

        Any failure is logged and then re-raised unchanged to the caller.
        """
        try:
            with open(settings.JWT_CERTIFICATE_PATH, "r") as certfile:
                cert_raw = certfile.read().encode("ASCII")
            cert = load_pem_x509_certificate(cert_raw, default_backend())
            return cert.public_key()
        except Exception as e:
            logger.error('Could not read certificate: ' + str(e))
            raise

    @staticmethod
    def load_remote_public_key(token):
        """
        Return the signing key for `token` from the JWKS endpoint at settings.JWT_CERTIFICATE_PATH.

        The token is inspected (unverified) to select the key, so a malformed
        token makes this raise rather than return.
        """
        # NOTE(review): a new PyJWKClient is built on every call, so any caching the
        # client does internally is not reused between requests - confirm this is intended.
        jwks_url = settings.JWT_CERTIFICATE_PATH
        jwks_client = jwt.PyJWKClient(jwks_url)
        return jwks_client.get_signing_key_from_jwt(token).key

    @staticmethod
    def _extract_username(claims):
        """
        Return the username from a decoded claims mapping.

        Looks at "username" first and falls back to "preferred_username".
        Returns None (after logging a warning) when neither claim is present.
        """
        username = claims.get("username")  # adfs(deprecated) uses this
        if username is None:
            username = claims.get("preferred_username")  # keycloak and Azure AD use this
        if username is None:
            logger.warning("Could not get username from claims set, expect problems")
        return username

    def authenticate(self, request, **credentials):
        """
        Validate the JWT supplied as the `token` credential and build a User from its claims.

        :param request: the current request; not used by this backend
        :param credentials: keyword credentials; only `token` is read
        :return: an unsaved User populated from the claims, or None when no token was
                 supplied, the token has expired, or it was issued for another audience
        :raises PermissionDenied: when the token cannot be decoded or any other
                 unexpected error occurs while validating it
        """
        token = credentials.get("token", None)
        if not token:
            return None  # no token supplied, so this backend has nothing to check

        # The token is a bearer credential, so its value is deliberately kept out of the logs.
        logger.debug("JwtAuth got token")

        use_jwks = settings.JWT_CERTIFICATE_PATH.startswith("http")
        # A local certificate that can't be read is a server-side problem, so that
        # error is allowed to propagate instead of being reported as a bad token.
        public_key = None if use_jwks else self.load_local_public_key()

        try:
            if use_jwks:
                # Key lookup parses the token, so it must sit inside the try block:
                # a malformed token is then handled like any other decode failure.
                public_key = self.load_remote_public_key(token)
            decoded = jwt.decode(token,
                                 key=public_key,
                                 algorithms=["RS256"],
                                 audience=getattr(settings, "JWT_EXPECTED_AUDIENCE", None),
                                 issuer=getattr(settings, "JWT_EXPECTED_ISSUER", None))
            logger.debug("JwtAuth success")
            # NOTE(review): every holder of a valid token is given staff and superuser rights.
            return User(
                username=self._extract_username(decoded),
                first_name=decoded.get("first_name"),
                last_name=decoded.get("family_name"),
                email=decoded.get("email"),
                is_staff=True,
                is_active=True,
                is_superuser=True  # until we have groups added in to the JWT claim
            )
        except jwt.exceptions.DecodeError as e:
            logger.error("Could not decode provided JWT: {0}".format(e))
            raise PermissionDenied()
        except jwt.exceptions.ExpiredSignatureError:
            logger.error("Token signature has expired")
            return None
        except jwt.exceptions.InvalidAudienceError as e:
            logger.error("Token was for another audience: {0}".format(e))
            return None
        except Exception:
            # format_exc() takes no exception argument; it reports the one being handled
            logger.error("Unexpected error decoding JWT: {0}".format(traceback.format_exc()))
            raise PermissionDenied()

    def get_user(self, token):
        """
        Re-run authenticate() with the given value as the token and return the result.

        NOTE(review): the argument is treated as a JWT here, not as a user id - confirm
        that callers pass a token.
        """
        return self.authenticate(None, token=token)
class JwtRestAuth(BaseAuthentication):
    """
    this class is a REST-framework compatible authentication class
    which calls out to our authentication backend via django.authenticate
    """

    def authenticate(self, request):
        """
        Authenticate a request carrying an "Authorization: Bearer <token>" header.

        :param request: the incoming request; only request.META is read
        :return: None when no bearer header is present (authentication not attempted),
                 otherwise a (user, "jwt") tuple
        :raises AuthenticationFailed: when a bearer token was supplied but rejected
        """
        auth_header = request.META.get("HTTP_AUTHORIZATION", None)
        if not (isinstance(auth_header, str) and auth_header.startswith("Bearer ")):
            return None  # authentication not attempted

        try:
            # strip the 7-character "Bearer " prefix to get the raw token
            user_model = authenticate(request, token=auth_header[7:])
        except PermissionDenied:
            raise AuthenticationFailed

        # django.authenticate reports a rejected credential by returning None, so that
        # must be turned into a failure here rather than handed back as (None, "jwt").
        if user_model is None:
            raise AuthenticationFailed
        return user_model, "jwt"