callouts/python/extproc/example/jwt_auth/service_callout_example.py (63 lines of code) (raw):

# Copyright 2024 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import logging from re import DEBUG from typing import Union import jwt from typing import Union, Any from jwt.exceptions import InvalidTokenError from grpc import ServicerContext from envoy.service.ext_proc.v3 import external_processor_pb2 as service_pb2 from extproc.service import callout_server from extproc.service import callout_tools def extract_jwt_token( request_headers: service_pb2.HttpHeaders, ) -> Union[str, None]: """ Extracts the JWT token from the request headers, specifically looking for the 'Authorization' header and parsing out the token part. Args: request_headers (service_pb2.HttpHeaders): The HTTP headers received in the request. Returns: str: The extracted JWT token if found, otherwise None. Example: Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6... -> Returns: eyJhbGciOiJIUzI1NiIsInR5cCI6... """ decoded_headers = ( header.raw_value.decode('utf-8') or header.value for header in request_headers.headers.headers if header.key.lower() == 'authorization' ) result = next(decoded_headers, None) if result is None: return result return result.strip().split(' ')[-1] def validate_jwt_token( key: bytes, request_headers: service_pb2.HttpHeaders, algorithm: str, context: ServicerContext, ) -> Union[Any, None]: """ Validates the JWT token extracted from the request headers using a specified public key and algorithm. If valid, returns the decoded JWT payload; otherwise, logs an error and returns None. Args: key (bytes): The public key used for token validation. request_headers (service_pb2.HttpHeaders): The HTTP headers received in the request, used to extract the JWT token. algorithm (str): The algorithm with which the JWT was signed (e.g., 'RS256'). context: RPC context of the incoming callout. Returns: dict | None: The decoded JWT if validation is successful, None if the token is invalid or an error occurs. Raises: InvalidTokenError: If the token is invalid or decoding fails. Example: Authorization: Bearer eyJhbGciOiJIUzI1NiIsInR5cCI6... -> Returns: {'sub': '1234567890', 'name': 'John Doe', 'iat': 1712173461, 'exp': 2075658261} """ jwt_token = extract_jwt_token(request_headers) if jwt_token is None: callout_tools.deny_callout(context, 'No Authorization token found.') return None try: decoded = jwt.decode(jwt_token, key, algorithms=[algorithm]) logging.info('Approved - Decoded Values: %s', decoded) return decoded except InvalidTokenError: return None class CalloutServerExample(callout_server.CalloutServer): """Example callout server. For request header callouts we provide a mutation to add multiple headers based on the decoded fields for example '{decoded-name: John Doe}', and to clear the route cache if the JWT Authorization is valid. A valid token example value can be found below. Valid Token for RS256: eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiYWRtaW4iOnRydWUsImlhdCI6MTcxMjE3MzQ2MSwiZXhwIjoyMDc1NjU4MjYxfQ.Vv-Lwn1z8BbVBGm-T1EKxv6T3XKCeRlvRrRmdu8USFdZUoSBK_aThzwzM2T8hlpReYsX9YFdJ3hMfq6OZTfHvfPLXvAt7iSKa03ZoPQzU8bRGzYy8xrb0ZQfrejGfHS5iHukzA8vtI2UAJ_9wFQiY5_VGHOBv9116efslbg-_gItJ2avJb0A0yr5uUwmE336rYEwgm4DzzfnTqPt8kcJwkONUsjEH__mePrva1qDT4qtfTPQpGa35TW8n9yZqse3h1w3xyxUfJd3BlDmoz6pQp2CvZkhdQpkWA1bnwpdqSDC7bHk4tYX6K5Q19na-2ff7gkmHZHJr0G9e_vAhQiE5w """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._load_public_key('./extproc/ssl_creds/publickey.pem') def _load_public_key(self, path: str) -> None: with open(path, 'rb') as key_file: self.public_key = key_file.read() def on_request_headers( self, headers: service_pb2.HttpHeaders, context: ServicerContext ) -> Union[service_pb2.HeadersResponse, None]: """Deny token if validation fails and return an error message. See :py:meth:`callouts.python.extproc.service.callout_tools.deny_request` for more information. If the token is valid, apply a header mutation. See :py:meth:`callouts.python.extproc.service.callout_tools.add_header_mutation` for more information. See base method: :py:meth:`callouts.python.extproc.service.callout_server.CalloutServer.on_request_headers`. """ logging.debug(headers) decoded = validate_jwt_token(self.public_key, headers, 'RS256', context) if decoded is not None: decoded_items = [ ('decoded-' + key, str(value)) for key, value in decoded.items() ] return callout_tools.add_header_mutation( add=decoded_items, clear_route_cache=True ) else: callout_tools.deny_callout(context, 'Authorization token is invalid.') if __name__ == '__main__': # Useful command line args. args = callout_tools.add_command_line_args().parse_args() # Set the logging debug level. logging.basicConfig(level=logging.DEBUG) # Run the gRPC service. CalloutServerExample(**vars(args)).run()