# elasticapm/contrib/serverless/aws.py
# BSD 3-Clause License
#
# Copyright (c) 2019, Elasticsearch BV
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
import base64
import datetime
import functools
import json
import os
import platform
import re
import time
import urllib
import warnings
from typing import Optional, TypeVar
from urllib.parse import urlencode
import elasticapm
from elasticapm.base import Client, get_client
from elasticapm.conf import constants
from elasticapm.traces import execution_context
from elasticapm.utils import encoding, get_name_from_func, nested_key
from elasticapm.utils.disttracing import TraceParent
from elasticapm.utils.logging import get_logger
# Trigger sources whose transactions carry HTTP request/response context
SERVERLESS_HTTP_REQUEST = ("api", "elb")
logger = get_logger("elasticapm.serverless")
# True only until the first invocation in this runtime begins (cold start detection)
COLD_START = True
# Guards against calling elasticapm.instrument() more than once per runtime
INSTRUMENTED = False
# Flipped to False once the APM Lambda Extension rejects a partial transaction
REGISTER_PARTIAL_TRANSACTIONS = True
# Placeholder type for the opaque AWS Lambda context object
_AWSLambdaContextT = TypeVar("_AWSLambdaContextT")
def capture_serverless(func: Optional[callable] = None, **kwargs) -> callable:
    """
    Decorator for instrumenting AWS Lambda functions.

    Example usage:

        from elasticapm import capture_serverless

        @capture_serverless
        def handler(event, context):
            return {"statusCode": r.status_code, "body": "Success!"}

    Please note that when using the APM Layer and setting AWS_LAMBDA_EXEC_WRAPPER this is not required and
    the handler would be instrumented automatically.
    """
    if func is None:
        # Called as `@capture_serverless(...)` — return a decorator that
        # remembers the keyword arguments.
        return functools.partial(capture_serverless, **kwargs)

    if kwargs:
        warnings.warn(
            PendingDeprecationWarning(
                "Passing keyword arguments to capture_serverless will be deprecated in the next major release."
            )
        )

    # Optional transaction-name override; everything else configures the client
    name = kwargs.pop("name", None)
    kwargs = prep_kwargs(kwargs)

    global INSTRUMENTED
    client = get_client()
    if not client:
        client = Client(**kwargs)
    if not INSTRUMENTED and not client.config.debug and client.config.instrument and client.config.enabled:
        elasticapm.instrument()
        INSTRUMENTED = True

    @functools.wraps(func)
    def wrapper(*call_args, **call_kwargs):
        # AWS invokes handlers as handler(event, context); anything else
        # (e.g. direct test invocation) falls back to empty placeholders.
        if len(call_args) >= 2:
            event, context = call_args[0:2]
        else:
            event, context = {}, {}
        # Config is re-checked per invocation so runtime toggles are honored
        if client.config.debug or not client.config.instrument or not client.config.enabled:
            return func(*call_args, **call_kwargs)
        with _lambda_transaction(func, name, client, event, context) as sls:
            sls.response = func(*call_args, **call_kwargs)
            return sls.response

    return wrapper
def prep_kwargs(kwargs=None):
    """
    Fill *kwargs* in place with client defaults suited to AWS Lambda.

    Background threads other than the transport are disabled, central config
    is off, and the service name/version default to the Lambda function's
    name/version unless already set via kwargs or environment variables.
    Returns the (mutated) kwargs dict.
    """
    if kwargs is None:
        kwargs = {}
    env = os.environ
    # Disable all background threads except for transport
    kwargs["metrics_interval"] = "0ms"
    kwargs["breakdown_metrics"] = False
    # Allow users to override metrics sets
    if "metrics_sets" not in kwargs and "ELASTIC_APM_METRICS_SETS" not in env:
        kwargs["metrics_sets"] = []
    kwargs["central_config"] = False
    kwargs["cloud_provider"] = "none"
    kwargs["framework_name"] = "AWS Lambda"
    if "service_name" not in kwargs and "ELASTIC_APM_SERVICE_NAME" not in env:
        kwargs["service_name"] = env["AWS_LAMBDA_FUNCTION_NAME"]
    if "service_version" not in kwargs and "ELASTIC_APM_SERVICE_VERSION" not in env:
        kwargs["service_version"] = env.get("AWS_LAMBDA_FUNCTION_VERSION")
    return kwargs
def should_normalize_headers(event: dict) -> bool:
    """
    Helper to decide if we should normalize headers or not depending on the event.

    Even if the documentation says that headers are lowercased it's not always the case for format version 1.0
    https://docs.aws.amazon.com/apigateway/latest/developerguide/http-api-develop-integrations-lambda.html
    """
    ctx = event.get("requestContext", {})
    if "http" in ctx:
        # Payload format 2.0 already delivers lowercased headers
        return False
    # ELB events and API Gateway payload format 1.0 need lowercasing
    return "elb" in ctx or "requestId" in ctx
class _lambda_transaction(object):
    """
    Context manager for creating transactions around AWS Lambda functions.
    Begins and ends a single transaction, waiting for the transport to flush
    before releasing the context.
    """
    def __init__(
        self, func: callable, name: Optional[str], client: Client, event: dict, context: _AWSLambdaContextT
    ) -> None:
        # func: the wrapped Lambda handler
        self.func = func
        # name: explicit transaction name override, else derived from the handler
        self.name = name or get_name_from_func(func)
        # event/context: the two positional arguments AWS passes to the handler
        self.event = event
        self.context = context
        # Populated by the caller after the handler returns (see capture_serverless)
        self.response = None
        self.client = client
    def __enter__(self):
        """
        Transaction setup.

        Detects the trigger (API Gateway v1/v2, ELB, S3, SNS, SQS) from the
        shape of the event, derives the transaction name/type and trace
        context, begins the transaction, attaches metadata/context, and
        eagerly registers a partial transaction with the Lambda extension.
        """
        if not isinstance(self.event, dict):
            # When `event` is not a dict, it's likely the output of another AWS
            # service like Step Functions, and is unlikely to be standardized
            # in any way. We just have to rely on our defaults in this case.
            self.event = {}
        headers = self.event.get("headers") or {}
        if headers and should_normalize_headers(self.event):
            # ELB / API Gateway v1 events can deliver mixed-case header names
            normalized_headers = {k.lower(): v for k, v in headers.items()}
        else:
            normalized_headers = headers
        trace_parent = TraceParent.from_headers(normalized_headers)
        global COLD_START
        cold_start = COLD_START
        COLD_START = False
        self.source = "other"
        transaction_type = "request"
        transaction_name = os.environ.get("AWS_LAMBDA_FUNCTION_NAME", self.name)
        # Presence of an HTTP method (in any of the three known locations)
        # marks this as an HTTP-triggered invocation
        self.httpmethod = (
            nested_key(self.event, "requestContext", "httpMethod")
            or nested_key(self.event, "requestContext", "http", "method")
            or nested_key(self.event, "httpMethod")
        )
        if self.httpmethod:  # http request
            if nested_key(self.event, "requestContext", "elb"):
                self.source = "elb"
                # ELB events carry no route information
                resource = "unknown route"
            elif nested_key(self.event, "requestContext", "httpMethod"):
                self.source = "api"
                # API v1
                resource = "/{}{}".format(
                    nested_key(self.event, "requestContext", "stage"),
                    nested_key(self.event, "requestContext", "resourcePath"),
                )
            else:
                self.source = "api"
                # API v2
                route_key = nested_key(self.event, "requestContext", "routeKey")
                # "$default"-style keys are used verbatim; "GET /path" keys keep only the path part
                route_key = f"/{route_key}" if route_key.startswith("$") else route_key.split(" ", 1)[-1]
                resource = "/{}{}".format(
                    nested_key(self.event, "requestContext", "stage"),
                    route_key,
                )
            transaction_name = "{} {}".format(self.httpmethod, resource)
        elif "Records" in self.event and len(self.event["Records"]) == 1:
            # Single-record S3/SNS/SQS events get a source-specific name;
            # batches fall through to the generic function-name default
            record = self.event["Records"][0]
            if record.get("eventSource") == "aws:s3":  # S3
                self.source = "s3"
                transaction_name = "{} {}".format(record["eventName"], record["s3"]["bucket"]["name"])
            elif record.get("EventSource") == "aws:sns":  # SNS (note the capital E in EventSource)
                self.source = "sns"
                transaction_type = "messaging"
                # ARN field 5 is the topic name
                transaction_name = "RECEIVE {}".format(record["Sns"]["TopicArn"].split(":")[5])
            elif record.get("eventSource") == "aws:sqs":  # SQS
                self.source = "sqs"
                transaction_type = "messaging"
                # ARN field 5 is the queue name
                transaction_name = "RECEIVE {}".format(record["eventSourceARN"].split(":")[5])
        if "Records" in self.event:
            # Span links from traceparent message attributes, capped at 1000 records
            links = [
                TraceParent.from_string(record["messageAttributes"]["traceparent"]["stringValue"])
                for record in self.event["Records"][:1000]
                if "messageAttributes" in record and "traceparent" in record["messageAttributes"]
            ]
        else:
            links = []
        self.client.begin_transaction(transaction_type, trace_parent=trace_parent, links=links)
        elasticapm.set_transaction_name(transaction_name, override=False)
        if self.source in SERVERLESS_HTTP_REQUEST:
            # Lazy lambda so the request data is only computed if/when serialized
            elasticapm.set_context(
                lambda: get_data_from_request(
                    self.event,
                    capture_body=self.client.config.capture_body in ("transactions", "all"),
                    capture_headers=self.client.config.capture_headers,
                ),
                "request",
            )
        self.set_metadata_and_context(cold_start)
        self.send_partial_transaction()
        return self
    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """
        Transaction teardown.

        Captures the response context and HTTP result, reports any unhandled
        exception, ends the transaction, and flushes the transport so data is
        delivered before the Lambda runtime is frozen.
        """
        if self.response and isinstance(self.response, dict):
            elasticapm.set_context(
                lambda: get_data_from_response(self.response, capture_headers=self.client.config.capture_headers),
                "response",
            )
            if "statusCode" in self.response:
                try:
                    # Bucket the result as "HTTP 2xx"/"HTTP 4xx"/etc.
                    result = "HTTP {}xx".format(int(self.response["statusCode"]) // 100)
                    elasticapm.set_transaction_result(result, override=False)
                    if result == "HTTP 5xx":
                        elasticapm.set_transaction_outcome(outcome="failure", override=False)
                except ValueError:
                    logger.warning("Lambda function's statusCode was not formed as an int. Assuming 5xx result.")
                    elasticapm.set_transaction_result("HTTP 5xx", override=False)
                    elasticapm.set_transaction_outcome(outcome="failure", override=False)
        if exc_val:
            self.client.capture_exception(exc_info=(exc_type, exc_val, exc_tb), handled=False)
            if self.source in SERVERLESS_HTTP_REQUEST:
                # An unhandled exception in an HTTP-triggered handler surfaces as a 500
                elasticapm.set_transaction_result("HTTP 5xx", override=False)
                elasticapm.set_transaction_outcome(http_status_code=500, override=False)
                elasticapm.set_context({"status_code": 500}, "response")
            else:
                elasticapm.set_transaction_result("failure", override=False)
                elasticapm.set_transaction_outcome(outcome="failure", override=False)
        self.client.end_transaction()
        # Collect any custom+prometheus metrics if enabled
        self.client.metrics.collect()
        try:
            logger.debug("Flushing elasticapm data")
            self.client._transport.flush()
            logger.debug("Flush complete")
        except ValueError:
            logger.warning("Flush timed out")
    def set_metadata_and_context(self, coldstart: bool) -> None:
        """
        Process the metadata and context fields for this request.

        Builds faas/cloud/service/message context objects from the trigger
        event (per `self.source`) and registers agent-level metadata.
        """
        metadata = {}
        cloud_context = {"origin": {"provider": "aws"}}
        service_context = {}
        message_context = {}
        faas = {}
        faas["coldstart"] = coldstart
        faas["trigger"] = {"type": "other"}
        faas["execution"] = self.context.aws_request_id
        arn = self.context.invoked_function_arn
        if len(arn.split(":")) > 7:
            # Strip the alias/version suffix so the ARN identifies the function itself
            arn = ":".join(arn.split(":")[:7])
        faas["id"] = arn
        faas["name"] = os.environ.get("AWS_LAMBDA_FUNCTION_NAME")
        faas["version"] = os.environ.get("AWS_LAMBDA_FUNCTION_VERSION")
        if self.source == "api":
            faas["trigger"]["type"] = "http"
            faas["trigger"]["request_id"] = self.event["requestContext"]["requestId"]
            service_context["origin"] = {"name": self.event["requestContext"]["domainName"]}
            service_context["origin"]["id"] = self.event["requestContext"]["apiId"]
            service_context["origin"]["version"] = self.event.get("version", "1.0")
            cloud_context["origin"] = {}
            cloud_context["origin"]["service"] = {"name": "api gateway"}
            if ".lambda-url." in self.event["requestContext"]["domainName"]:
                # Lambda function URLs reuse the API Gateway event shape
                cloud_context["origin"]["service"]["name"] = "lambda url"
            cloud_context["origin"]["account"] = {"id": self.event["requestContext"]["accountId"]}
            cloud_context["origin"]["provider"] = "aws"
        elif self.source == "elb":
            # Target group ARN field 5 is "targetgroup/<name>/<id>"; field 4 the
            # account id, field 3 the region
            elb_target_group_arn = self.event["requestContext"]["elb"]["targetGroupArn"]
            faas["trigger"]["type"] = "http"
            service_context["origin"] = {"name": elb_target_group_arn.split(":")[5].split("/")[1]}
            service_context["origin"]["id"] = elb_target_group_arn
            cloud_context["origin"] = {}
            cloud_context["origin"]["service"] = {"name": "elb"}
            cloud_context["origin"]["account"] = {"id": elb_target_group_arn.split(":")[4]}
            cloud_context["origin"]["region"] = elb_target_group_arn.split(":")[3]
            cloud_context["origin"]["provider"] = "aws"
        elif self.source == "sqs":
            record = self.event["Records"][0]
            faas["trigger"]["type"] = "pubsub"
            faas["trigger"]["request_id"] = record["messageId"]
            service_context["origin"] = {}
            service_context["origin"]["name"] = record["eventSourceARN"].split(":")[5]
            service_context["origin"]["id"] = record["eventSourceARN"]
            cloud_context["origin"] = {}
            cloud_context["origin"]["service"] = {"name": "sqs"}
            cloud_context["origin"]["region"] = record["awsRegion"]
            cloud_context["origin"]["account"] = {"id": record["eventSourceARN"].split(":")[4]}
            cloud_context["origin"]["provider"] = "aws"
            message_context["queue"] = {"name": service_context["origin"]["name"]}
            if "SentTimestamp" in record["attributes"]:
                # SentTimestamp is epoch milliseconds; age = now - sent
                message_context["age"] = {"ms": int((time.time() * 1000) - int(record["attributes"]["SentTimestamp"]))}
            if self.client.config.capture_body in ("transactions", "all") and "body" in record:
                message_context["body"] = record["body"]
            if self.client.config.capture_headers and record.get("messageAttributes"):
                # Only string-valued message attributes are captured as headers
                headers = {}
                for k, v in record["messageAttributes"].items():
                    if v and v.get("stringValue"):
                        headers[k] = v.get("stringValue")
                if headers:
                    message_context["headers"] = headers
        elif self.source == "sns":
            record = self.event["Records"][0]
            faas["trigger"]["type"] = "pubsub"
            faas["trigger"]["request_id"] = record["Sns"]["TopicArn"]
            service_context["origin"] = {}
            service_context["origin"]["name"] = record["Sns"]["TopicArn"].split(":")[5]
            service_context["origin"]["id"] = record["Sns"]["TopicArn"]
            service_context["origin"]["version"] = record["EventVersion"]
            service_context["origin"]["service"] = {"name": "sns"}
            cloud_context["origin"] = {}
            cloud_context["origin"]["region"] = record["Sns"]["TopicArn"].split(":")[3]
            cloud_context["origin"]["account"] = {"id": record["Sns"]["TopicArn"].split(":")[4]}
            cloud_context["origin"]["provider"] = "aws"
            message_context["queue"] = {"name": service_context["origin"]["name"]}
            if "Timestamp" in record["Sns"]:
                # NOTE(review): the ISO8601 "Z" timestamp is compared against
                # naive datetime.now() — assumes the runtime clock is UTC
                # (true on AWS Lambda); confirm if reused elsewhere.
                message_context["age"] = {
                    "ms": int(
                        (
                            datetime.datetime.now()
                            - datetime.datetime.strptime(record["Sns"]["Timestamp"], r"%Y-%m-%dT%H:%M:%S.%fZ")
                        ).total_seconds()
                        * 1000
                    )
                }
            if self.client.config.capture_body in ("transactions", "all") and "Message" in record["Sns"]:
                message_context["body"] = record["Sns"]["Message"]
            if self.client.config.capture_headers and record["Sns"].get("MessageAttributes"):
                # Only "String"-typed message attributes are captured as headers
                headers = {}
                for k, v in record["Sns"]["MessageAttributes"].items():
                    if v and v.get("Type") == "String":
                        headers[k] = v.get("Value")
                if headers:
                    message_context["headers"] = headers
        elif self.source == "s3":
            record = self.event["Records"][0]
            faas["trigger"]["type"] = "datasource"
            faas["trigger"]["request_id"] = record["responseElements"]["x-amz-request-id"]
            service_context["origin"] = {}
            service_context["origin"]["name"] = record["s3"]["bucket"]["name"]
            service_context["origin"]["id"] = record["s3"]["bucket"]["arn"]
            service_context["origin"]["version"] = record["eventVersion"]
            cloud_context["origin"] = {}
            cloud_context["origin"]["service"] = {"name": "s3"}
            cloud_context["origin"]["region"] = record["awsRegion"]
            cloud_context["origin"]["provider"] = "aws"
        metadata["service"] = {}
        metadata["service"]["name"] = self.client.config.service_name
        metadata["service"]["framework"] = {"name": "AWS Lambda"}
        metadata["service"]["runtime"] = {
            "name": os.environ.get("AWS_EXECUTION_ENV"),
            "version": platform.python_version(),
        }
        metadata["service"]["version"] = self.client.config.service_version
        metadata["service"]["node"] = {"configured_name": os.environ.get("AWS_LAMBDA_LOG_STREAM_NAME")}
        # This is the one piece of metadata that requires deep merging. We add it manually
        # here to avoid having to deep merge in _transport.add_metadata()
        node = nested_key(self.client.get_service_info(), "node")
        if node:
            metadata["service"]["node"] = node
        metadata["cloud"] = {}
        metadata["cloud"]["provider"] = "aws"
        metadata["cloud"]["region"] = os.environ.get("AWS_REGION")
        metadata["cloud"]["service"] = {"name": "lambda"}
        # ARN field 4 is the account id
        metadata["cloud"]["account"] = {"id": arn.split(":")[4]}
        elasticapm.set_context(cloud_context, "cloud")
        elasticapm.set_context(service_context, "service")
        # faas doesn't actually belong in context, but we handle this in to_dict
        elasticapm.set_context(faas, "faas")
        if message_context:
            elasticapm.set_context(message_context, "message")
        self.client.add_extra_metadata(metadata)
    def send_partial_transaction(self) -> None:
        """
        We synchronously send the (partial) transaction to the Lambda Extension
        so that the transaction can be reported even if the lambda runtime times
        out before we can report the transaction.
        This is pretty specific to the HTTP transport. If we ever add other
        transports, we will need to clean this up.
        """
        global REGISTER_PARTIAL_TRANSACTIONS
        # Only attempt this when the extension env var is set and the agent is
        # talking to the local proxy the extension provides
        if (
            REGISTER_PARTIAL_TRANSACTIONS
            and os.environ.get("ELASTIC_APM_LAMBDA_APM_SERVER")
            and ("localhost" in self.client.config.server_url or "127.0.0.1" in self.client.config.server_url)
        ):
            transaction = execution_context.get_transaction()
            transport = self.client._transport
            logger.debug("Sending partial transaction and early metadata to the lambda extension...")
            # ndjson payload: one metadata line followed by the transaction line
            data = transport._json_serializer({"metadata": self.client.build_metadata()}) + "\n"
            data += transport._json_serializer({"transaction": transaction.to_dict()})
            partial_transaction_url = urllib.parse.urljoin(
                self.client.config.server_url
                if self.client.config.server_url.endswith("/")
                else self.client.config.server_url + "/",
                "register/transaction",
            )
            try:
                # NOTE(review): sampling appears to be paused for the duration of
                # this synchronous send — confirm semantics in the traces module
                transaction.pause_sampling = True
                transport.send(
                    data,
                    custom_url=partial_transaction_url,
                    custom_headers={
                        "x-elastic-aws-request-id": self.context.aws_request_id,
                        "Content-Type": "application/vnd.elastic.apm.transaction+ndjson",
                    },
                )
            except Exception as e:
                # An HTTP 4xx/5xx from the extension means it does not support
                # this endpoint — stop trying for the rest of the runtime
                if re.match(r"HTTP [4,5]\d\d", str(e)):
                    REGISTER_PARTIAL_TRANSACTIONS = False
                    logger.info(
                        "APM Lambda Extension does not support partial transactions. "
                        "Disabling partial transaction registration."
                    )
                else:
                    logger.warning("Failed to send partial transaction to APM Lambda Extension", exc_info=True)
            finally:
                transaction.pause_sampling = False
def get_data_from_request(event: dict, capture_body: bool = False, capture_headers: bool = True) -> dict:
    """
    Capture request context data from an API Gateway / ELB event.

    Returns a dict with optional "headers", "method", "body" and "url" keys;
    non-HTTP events yield at most the headers.
    """
    result = {}
    if capture_headers and "headers" in event:
        result["headers"] = event["headers"]
    method = (
        nested_key(event, "requestContext", "httpMethod")
        or nested_key(event, "requestContext", "http", "method")
        or nested_key(event, "httpMethod")
    )
    if not method:
        # No HTTP method anywhere — not an API Gateway/ELB event
        return result
    result["method"] = method
    if "body" in event and method in constants.HTTP_WITH_BODY:
        body = event["body"]
        if capture_body:
            if event.get("isBase64Encoded"):
                body = base64.b64decode(body)
            else:
                # Best effort: keep the raw string if it isn't valid JSON
                try:
                    body = json.loads(body)
                except Exception:
                    pass
        if body is not None:
            result["body"] = body if capture_body else "[REDACTED]"
    result["url"] = get_url_dict(event)
    return result
def get_data_from_response(response: dict, capture_headers: bool = True) -> dict:
    """
    Capture response context data from the value the Lambda handler returned.
    """
    result = {}
    if capture_headers and "headers" in response:
        result["headers"] = response["headers"]
    if "statusCode" not in response:
        return result
    raw_status = response["statusCode"]
    try:
        result["status_code"] = int(raw_status)
    except ValueError:
        # statusCode wasn't formed as an int; treat as a server error.
        # We don't log here, as we will have already logged at
        # transaction.result handling.
        result["status_code"] = 500
    return result
def get_url_dict(event: dict) -> dict:
    """
    Reconstruct the request URL from an API Gateway / ELB event.
    """
    headers = event.get("headers") or {}
    protocol = headers.get("X-Forwarded-Proto", headers.get("x-forwarded-proto", "https"))
    host = headers.get("Host", headers.get("host", ""))
    port = headers.get("X-Forwarded-Port", headers.get("x-forwarded-port"))

    stage = nested_key(event, "requestContext", "stage") or ""
    raw_path = event.get("rawPath", "")
    if stage:
        stage = "/" + stage
        # rawPath carries the stage prefix; keep only the part after it
        raw_path = raw_path.split(stage)[-1]
    path = event.get("path", raw_path)

    if "rawQueryString" in event:
        query = event["rawQueryString"]
    elif event.get("queryStringParameters"):
        params = event["queryStringParameters"]
        if stage:
            # api requires parameters encoding to build correct url
            query = "?" + urlencode(params)
        else:
            # for elb we do not have the stage
            query = "?" + "&".join("{}={}".format(k, v) for k, v in params.items())
    else:
        query = ""

    url_dict = {
        "full": encoding.keyword_field(protocol + "://" + host + stage + path + query),
        "protocol": protocol,
        "hostname": encoding.keyword_field(host),
        "pathname": encoding.keyword_field(stage + path),
    }
    if port:
        url_dict["port"] = port
    if query:
        url_dict["search"] = encoding.keyword_field(query)
    return url_dict