pathology/shared_libs/iap_auth_lib/auth.py (52 lines of code) (raw):
# Copyright 2023 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Validates the JWT header generated from IAP.
See https://cloud.google.com/iap/docs/signed-headers-howto and
https://cloud.google.com/python/docs/getting-started/authenticate-users.
"""
import functools
import http
from typing import Any, Callable, Mapping, Union
from absl import flags
from absl import logging
import flask
from google.auth import exceptions
from google.auth.transport import requests
from google.oauth2 import id_token
from pathology.shared_libs.flags import secret_flag_utils
# Boolean switch consulted by the validate_iap decorator: when the flag is
# False the decorated handler runs without any JWT check. The default is
# whatever secret_flag_utils.get_bool_secret_or_env returns for 'VALIDATE_IAP'
# (called with a fallback of True).
VALIDATE_IAP_FLG = flags.DEFINE_bool(
    name='validate_iap',
    default=secret_flag_utils.get_bool_secret_or_env('VALIDATE_IAP', True),
    help='Whether to validate IAP.',
)
# Audience the IAP JWT is expected to carry; the validate_iap decorator passes
# this flag's value to _is_valid. The default is whatever
# secret_flag_utils.get_secret_or_env returns for 'JWT_AUDIENCE' (called with a
# fallback of None).
# NOTE(review): an earlier comment here referred to "Backend service
# default/dpas-orchestrator-service" in "Project dpas-digipat in organization
# cloudflyer.info". No such value is hard-coded in this definition -- confirm
# whether that note is still relevant.
JWT_AUD_FLG = flags.DEFINE_string(
    name='jwt_audience',
    default=secret_flag_utils.get_secret_or_env('JWT_AUDIENCE', None),
    help='JWT audience of this backend service.',
)
# URL of the public key used to verify the IAP-signed JWT; passed as certs_url
# to id_token.verify_token in _is_valid. Please see
# https://cloud.google.com/iap/docs/signed-headers-howto#verifying_the_jwt_header
# for more details.
_CERTS_URL = 'https://www.gstatic.com/iap/verify/public_key'
# Name of the request header that stores the signed JWT; read by the
# validate_iap decorator. Please see
# https://cloud.google.com/python/docs/getting-started/authenticate-users
# for more details.
IAP_JWT_HEADER = 'X-Goog-IAP-JWT-Assertion'
def _get_flask_headers() -> Mapping[str, str]:
  """Returns the headers of the Flask request currently being handled.

  Kept as a separate function so the header lookup has a single access point.
  """
  current_headers = flask.request.headers
  return current_headers
def validate_iap(request_handler: Callable[..., Any]) -> Callable[..., Any]:
  """Decorator that gates an endpoint behind IAP JWT assertion validation.

  Args:
    request_handler: The endpoint function to protect.

  Returns:
    A wrapper with the same metadata as request_handler that only calls it
    when validation is disabled or the request's IAP JWT is valid.
  """

  @functools.wraps(request_handler)
  def wrapper(*args, **kwargs):
    """Validates the IAP JWT header before delegating to the handler.

    Args:
      *args: Positional arguments forwarded to the wrapped handler.
      **kwargs: Keyword arguments forwarded to the wrapped handler.

    Returns:
      The wrapped handler's result when validation is switched off or the
      token is valid. Otherwise aborts with an UNAUTHORIZED http status code.
    """
    if VALIDATE_IAP_FLG.value:
      # Validation is on: the JWT is read from the request headers and must
      # pass _is_valid against the configured audience.
      iap_assertion = _get_flask_headers().get(IAP_JWT_HEADER)
      if not _is_valid(iap_assertion, JWT_AUD_FLG.value):
        return flask.abort(http.HTTPStatus.UNAUTHORIZED)
    return request_handler(*args, **kwargs)

  return wrapper
def _is_valid(
    iap_jwt_token: Union[str, bytes, None], expected_audience: str
) -> bool:
  """Checks whether an IAP JWT is valid.

  Implemented based on
  https://cloud.google.com/iap/docs/signed-headers-howto#retrieving_the_user_identity.

  Args:
    iap_jwt_token: The contents of the X-Goog-IAP-JWT-Assertion header. May be
      None or empty when the request did not carry the header.
    expected_audience: The Signed Header JWT audience. See
      https://cloud.google.com/iap/docs/signed-headers-howto for details on how
      to get this value.

  Returns:
    True only when a token is present, id_token.verify_token accepts it for
    expected_audience, and the decoded claims contain a non-empty 'email'.
    False in every other case (including verification or transport errors,
    which are logged and not re-raised).
  """
  # The validate_iap wrapper passes None when the header is absent. Reject it
  # here explicitly instead of relying on the verifier to raise for a
  # non-token value, which would log a stack trace for every unauthenticated
  # request (and presumably only after the verifier has done its own setup
  # work, such as fetching certs).
  if not iap_jwt_token:
    logging.error('No IAP JWT provided.')
    return False

  # NOTE(review): google.auth.jwt.decode documents that a None audience skips
  # the audience check -- confirm JWT_AUD_FLG is always set wherever IAP
  # validation is enabled.
  try:
    decoded_jwt = id_token.verify_token(
        iap_jwt_token,
        requests.Request(),
        audience=expected_audience,
        certs_url=_CERTS_URL,
    )
  except (ValueError, exceptions.TransportError):
    # ValueError: the token was rejected by the verifier; TransportError: the
    # certs could not be retrieved. Either way the request is not trusted.
    logging.exception('IAP JWT validation failed.')
    return False

  # A verified token without an email claim is still treated as invalid.
  if not decoded_jwt.get('email'):
    logging.error('No email in IAP JWT.')
    return False
  return True