callouts/python/extproc/service/callout_server.py (212 lines of code) (raw):

# Copyright 2024 Google LLC. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """SDK for service callout servers. Provides a customizeable, out of the box, service callout server. Takes in service callouts and performs header and body transformations. Bundled with an optional health check server. Can be set up to use ssl certificates. """ from concurrent import futures from http.server import BaseHTTPRequestHandler from http.server import HTTPServer import logging import ssl from typing import Iterator, Union from typing import Iterable from envoy.service.ext_proc.v3.external_processor_pb2 import HttpBody from envoy.service.ext_proc.v3.external_processor_pb2 import HttpHeaders from envoy.service.ext_proc.v3.external_processor_pb2 import BodyResponse from envoy.service.ext_proc.v3.external_processor_pb2 import HeadersResponse from envoy.service.ext_proc.v3.external_processor_pb2 import ImmediateResponse from envoy.service.ext_proc.v3.external_processor_pb2 import ProcessingRequest from envoy.service.ext_proc.v3.external_processor_pb2 import ProcessingResponse from envoy.service.ext_proc.v3.external_processor_pb2_grpc import ( add_ExternalProcessorServicer_to_server,) from envoy.service.ext_proc.v3.external_processor_pb2_grpc import ( ExternalProcessorServicer,) import grpc from google.protobuf.struct_pb2 import Struct from grpc import ServicerContext def _addr_to_str(address: tuple[str, int]) -> str: """Take in an address tuple and returns a formated ip string. Args: address: Address to transform. Returns: str: f'{address[0]}:{address[1]}' """ return f'{address[0]}:{address[1]}' class HealthCheckService(BaseHTTPRequestHandler): """Server for responding to health check pings.""" def do_GET(self) -> None: """Returns an empty page with 200 status code.""" self.send_response(200) self.end_headers() class CalloutServer: """Server wrapper for managing callout servers and processing callouts. Attributes: address: Address that the main secure server will attempt to connect to, defaults to default_ip:443. port: If specified, overrides the port of address. health_check_address: The health check serving address, defaults to default_ip:80. health_check_port: If set, overrides the port of health_check_address. combined_health_check: If True, does not create a separate health check server. secure_health_check: If True, will use HTTPS as the protocol of the health check server. Requires cert_chain_path and private_key_path to be set. plaintext_address: The non-authenticated address to listen to, defaults to default_ip:8080. plaintext_port: If set, overrides the port of plaintext_address. disable_plaintext: If true, disables the plaintext address of the server. default_ip: If left None, defaults to '0.0.0.0'. cert_chain: PEM Certificate chain used to authenticate secure connections, required for secure servers. cert_chain_path: Relative file path to the cert_chain. private_key: PEM private key of the server. private_key_path: Relative file path pointing to a file containing private_key data. server_thread_count: Threads allocated to the main grpc service. """ def __init__( self, address: tuple[str, int] | None = None, port: int | None = None, health_check_address: tuple[str, int] | None = None, health_check_port: int | None = None, combined_health_check: bool = False, secure_health_check: bool = False, plaintext_address: tuple[str, int] | None = None, plaintext_port: int | None = None, disable_plaintext: bool = False, default_ip: str | None = None, cert_chain: bytes | None = None, cert_chain_path: str | None = './extproc/ssl_creds/chain.pem', private_key: bytes | None = None, private_key_path: str = './extproc/ssl_creds/privatekey.pem', server_thread_count: int = 2, ): self._setup = False self._shutdown = False self._closed = False self._health_check_server: HTTPServer | None = None default_ip = default_ip or '0.0.0.0' self.address: tuple[str, int] = address or (default_ip, 443) if port: self.address = (self.address[0], port) self.plaintext_address: tuple[str, int] | None = None if not disable_plaintext: self.plaintext_address = plaintext_address or (default_ip, 8080) if plaintext_port: self.plaintext_address = (self.plaintext_address[0], plaintext_port) self.health_check_address: tuple[str, int] | None = None if not combined_health_check: self.health_check_address = health_check_address or (default_ip, 80) if health_check_port: self.health_check_address = (self.health_check_address[0], health_check_port) def _read_cert_file(path: str | None) -> bytes | None: if path: with open(path, 'rb') as file: return file.read() return None self.server_thread_count = server_thread_count self.secure_health_check = secure_health_check # Read cert data. self.private_key = private_key or _read_cert_file(private_key_path) self.cert_chain = cert_chain or _read_cert_file(cert_chain_path) if secure_health_check: if not private_key_path: logging.error("Secure health check requires a private_key_path.") return if not cert_chain_path: logging.error("Secure health check requires a cert_chain_path.") return self.health_check_ssl_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) self.health_check_ssl_context.load_cert_chain(certfile=cert_chain_path, keyfile=private_key_path) self._callout_server = _GRPCCalloutService(self) def run(self) -> None: """Start all requested servers and listen for new connections; blocking.""" self._start_servers() self._setup = True try: self._loop_server() except KeyboardInterrupt: logging.info('Server interrupted') finally: self._stop_servers() self._closed = True def _start_servers(self) -> None: """Start the requested servers.""" if self.health_check_address: self._health_check_server = HTTPServer(self.health_check_address, HealthCheckService) protocol = 'HTTP' if self.secure_health_check: protocol = 'HTTPS' self._health_check_server.socket = ( self.health_check_ssl_context.wrap_socket( sock=self._health_check_server.socket,)) logging.info('%s health check server bound to %s.', protocol, _addr_to_str(self.health_check_address)) self._callout_server.start() def _stop_servers(self) -> None: """Close the sockets of all servers, and trigger shutdowns.""" if self._health_check_server: self._health_check_server.server_close() self._health_check_server.shutdown() logging.info('Health check server stopped.') if self._callout_server: self._callout_server.stop() def _loop_server(self) -> None: """Loop server forever, calling shutdown will cause the server to stop.""" # We chose the main serving thread based on what server configuration # was requested. Defaults to the health check thread. if self._health_check_server: logging.info("Health check server started.") self._health_check_server.serve_forever() else: # If the only server requested is a grpc callout server, we wait on the grpc server. self._callout_server.loop() def shutdown(self) -> None: """Tell the server to shutdown, ending all serving threads.""" if self._health_check_server: self._health_check_server.shutdown() if self._callout_server: self._callout_server.stop() def process( self, callout: ProcessingRequest, context: ServicerContext, ) -> ProcessingResponse: """Process incomming callouts. Args: callout: The incomming callout. context: Stream context on the callout. Yields: ProcessingResponse: A response for the incoming callout. """ if callout.HasField('request_headers'): match self.on_request_headers(callout.request_headers, context): case ProcessingResponse() as processing_response: return processing_response case ImmediateResponse() as immediate_headers: return ProcessingResponse(immediate_response=immediate_headers) case HeadersResponse() | None as header_response: return ProcessingResponse(request_headers=header_response) case _: logging.warn("MALFORMED CALLOUT %s", callout) elif callout.HasField('response_headers'): return ProcessingResponse(response_headers=self.on_response_headers( callout.response_headers, context)) elif callout.HasField('request_body'): match self.on_request_body(callout.request_body, context): case ImmediateResponse() as immediate_body: return ProcessingResponse(immediate_response=immediate_body) case BodyResponse() | None as body_response: return ProcessingResponse(request_body=body_response) case _: logging.warn("MALFORMED CALLOUT %s", callout) elif callout.HasField('response_body'): return ProcessingResponse( response_body=self.on_response_body(callout.response_body, context)) return ProcessingResponse() def on_request_headers( self, headers: HttpHeaders, # pylint: disable=unused-argument context: ServicerContext # pylint: disable=unused-argument ) -> Union[None, HeadersResponse, ImmediateResponse, ProcessingResponse]: """Process incoming request headers. Args: headers: Request headers to process. context: RPC context of the incoming callout. Returns: Optional header modification object or a complete response. """ return None def on_response_headers( self, headers: HttpHeaders, # pylint: disable=unused-argument context: ServicerContext # pylint: disable=unused-argument ) -> Union[None, HeadersResponse]: """Process incoming response headers. Args: headers: Response headers to process. context: RPC context of the incoming callout. Returns: Optional header modification object. """ return None def on_request_body( self, body: HttpBody, # pylint: disable=unused-argument context: ServicerContext # pylint: disable=unused-argument ) -> Union[None, BodyResponse, ImmediateResponse]: """Process an incoming request body. Args: headers: Request body to process. context: RPC context of the incoming callout. Returns: Optional body modification object. """ return None def on_response_body( self, body: HttpBody, # pylint: disable=unused-argument context: ServicerContext # pylint: disable=unused-argument ) -> Union[None, BodyResponse]: """Process an incoming response body. Args: headers: Response body to process. context: RPC context of the incoming callout. Returns: Optional body modification object. """ return None class _GRPCCalloutService(ExternalProcessorServicer): """GRPC based Callout server implementation.""" def __init__(self, processor, *args, **kwargs): self._processor = processor self._server = grpc.server( futures.ThreadPoolExecutor(max_workers=processor.server_thread_count)) add_ExternalProcessorServicer_to_server(self, self._server) server_credentials = grpc.ssl_server_credentials( private_key_certificate_chain_pairs=[(processor.private_key, processor.cert_chain)]) address_str = _addr_to_str(processor.address) self._server.add_secure_port(address_str, server_credentials) self._start_msg = f'GRPC callout server started, listening on {address_str}.' if processor.plaintext_address: plaintext_address_str = _addr_to_str(processor.plaintext_address) self._server.add_insecure_port(plaintext_address_str) self._start_msg += f' (secure) and {plaintext_address_str} (plaintext)' def stop(self) -> None: self._server.stop(grace=10) self._server.wait_for_termination(timeout=10) logging.info('GRPC server stopped.') def loop(self) -> None: self._server.wait_for_termination() def start(self) -> None: self._server.start() logging.info(self._start_msg) def Process( self, callout_iterator: Iterable[ProcessingRequest], context: ServicerContext, ) -> Iterator[ProcessingResponse]: """Process the client callout.""" for callout in callout_iterator: yield self._processor.process(callout, context)