awslambdaric/lambda_runtime_client.py (118 lines of code) (raw):
"""
Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
"""
import sys
from awslambdaric import __version__
from .lambda_runtime_exception import FaultException
from .lambda_runtime_marshaller import to_json
ERROR_TYPE_HEADER = "Lambda-Runtime-Function-Error-Type"
def _user_agent():
py_version = (
f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}"
)
pkg_version = __version__
return f"aws-lambda-python/{py_version}-{pkg_version}"
try:
import runtime_client
runtime_client.initialize_client(_user_agent())
except ImportError:
runtime_client = None
from .lambda_runtime_marshaller import LambdaMarshaller
class InvocationRequest(object):
def __init__(self, **kwds):
self.__dict__.update(kwds)
def __eq__(self, other):
return self.__dict__ == other.__dict__
class LambdaRuntimeClientError(Exception):
def __init__(self, endpoint, response_code, response_body):
self.endpoint = endpoint
self.response_code = response_code
self.response_body = response_body
super().__init__(
f"Request to Lambda Runtime '{endpoint}' endpoint failed. Reason: '{response_code}'. Response body: '{response_body}'"
)
class LambdaRuntimeClient(object):
marshaller = LambdaMarshaller()
"""marshaller is a class attribute that determines the unmarshalling and marshalling logic of a function's event
and response. It allows for function authors to override the the default implementation, LambdaMarshaller which
unmarshals and marshals JSON, to an instance of a class that implements the same interface."""
def __init__(self, lambda_runtime_address, use_thread_for_polling_next=False):
self.lambda_runtime_address = lambda_runtime_address
self.use_thread_for_polling_next = use_thread_for_polling_next
if self.use_thread_for_polling_next:
# Conditionally import only for the case when TPE is used in this class.
from concurrent.futures import ThreadPoolExecutor
# Not defining symbol as global to avoid relying on TPE being imported unconditionally.
self.ThreadPoolExecutor = ThreadPoolExecutor
def call_rapid(
self, http_method, endpoint, expected_http_code, payload=None, headers=None
):
# These imports are heavy-weight. They implicitly trigger `import ssl, hashlib`.
# Importing them lazily to speed up critical path of a common case.
import http.client
runtime_connection = http.client.HTTPConnection(self.lambda_runtime_address)
runtime_connection.connect()
if http_method == "GET":
runtime_connection.request(http_method, endpoint)
else:
runtime_connection.request(
http_method, endpoint, to_json(payload), headers=headers
)
response = runtime_connection.getresponse()
response_body = response.read()
if response.code != expected_http_code:
raise LambdaRuntimeClientError(endpoint, response.code, response_body)
def post_init_error(self, error_response_data, error_type_override=None):
import http
endpoint = "/2018-06-01/runtime/init/error"
headers = {
ERROR_TYPE_HEADER: (
error_type_override
if error_type_override
else error_response_data["errorType"]
)
}
self.call_rapid(
"POST", endpoint, http.HTTPStatus.ACCEPTED, error_response_data, headers
)
def restore_next(self):
import http
endpoint = "/2018-06-01/runtime/restore/next"
self.call_rapid("GET", endpoint, http.HTTPStatus.OK)
def report_restore_error(self, restore_error_data):
import http
endpoint = "/2018-06-01/runtime/restore/error"
headers = {ERROR_TYPE_HEADER: FaultException.AFTER_RESTORE_ERROR}
self.call_rapid(
"POST", endpoint, http.HTTPStatus.ACCEPTED, restore_error_data, headers
)
def wait_next_invocation(self):
# Calling runtime_client.next() from a separate thread unblocks the main thread,
# which can then process signals.
if self.use_thread_for_polling_next:
try:
# TPE class is supposed to be registered at construction time and be ready to use.
with self.ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(runtime_client.next)
response_body, headers = future.result()
except Exception as e:
raise FaultException(
FaultException.LAMBDA_RUNTIME_CLIENT_ERROR,
"LAMBDA_RUNTIME Failed to get next invocation: {}".format(str(e)),
None,
)
else:
response_body, headers = runtime_client.next()
return InvocationRequest(
invoke_id=headers.get("Lambda-Runtime-Aws-Request-Id"),
x_amzn_trace_id=headers.get("Lambda-Runtime-Trace-Id"),
invoked_function_arn=headers.get("Lambda-Runtime-Invoked-Function-Arn"),
deadline_time_in_ms=headers.get("Lambda-Runtime-Deadline-Ms"),
client_context=headers.get("Lambda-Runtime-Client-Context"),
cognito_identity=headers.get("Lambda-Runtime-Cognito-Identity"),
content_type=headers.get("Content-Type"),
event_body=response_body,
)
def post_invocation_result(
self, invoke_id, result_data, content_type="application/json"
):
runtime_client.post_invocation_result(
invoke_id,
(
result_data
if isinstance(result_data, bytes)
else result_data.encode("utf-8")
),
content_type,
)
def post_invocation_error(self, invoke_id, error_response_data, xray_fault):
max_header_size = 1024 * 1024 # 1MiB
xray_fault = xray_fault if len(xray_fault.encode()) < max_header_size else ""
runtime_client.post_error(invoke_id, error_response_data, xray_fault)