elasticapm/contrib/starlette/__init__.py (131 lines of code) (raw):
# BSD 3-Clause License
#
# Copyright (c) 2012, the Sentry Team, see AUTHORS for more details
# Copyright (c) 2019, Elasticsearch BV
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
from __future__ import absolute_import
import asyncio
import functools
from typing import Dict, Optional
import starlette
from starlette.datastructures import Headers
from starlette.requests import Request
from starlette.routing import Match, Mount
from starlette.types import ASGIApp, Message
import elasticapm
import elasticapm.instrumentation.control
from elasticapm.base import Client, get_client
from elasticapm.conf import constants
from elasticapm.contrib.asyncio.traces import set_context
from elasticapm.contrib.starlette.utils import get_body, get_data_from_request, get_data_from_response
from elasticapm.utils.disttracing import TraceParent
from elasticapm.utils.encoding import long_field
from elasticapm.utils.logging import get_logger
# Logger for this integration. Note that it is registered under the fixed name
# "elasticapm.errors.client" rather than this module's __name__.
logger = get_logger("elasticapm.errors.client")
def make_apm_client(config: Optional[Dict] = None, client_cls=Client, **defaults) -> Client:
    """Build an Elastic APM client tagged as a Starlette integration.

    Args:
        config (dict): Client configuration. All keys must be uppercase. See `elasticapm.conf.Config`.
        client_cls (Client): Class to instantiate; must be `Client` or a subclass of it.
        **defaults: Extra keyword arguments forwarded to ``client_cls``. See `elasticapm.base.Client`.
            Unless the caller supplies ``framework_name``, both ``framework_name="starlette"``
            and ``framework_version`` (the installed ``starlette.__version__``) are filled in.

    Returns:
        Client: a new ``client_cls`` instance built from ``config`` and ``defaults``.
    """
    if "framework_name" in defaults:
        # The caller identified the framework explicitly; forward everything untouched.
        return client_cls(config, **defaults)

    defaults.update(framework_name="starlette", framework_version=starlette.__version__)
    return client_cls(config, **defaults)
class ElasticAPM:
    """
    Starlette / FastAPI middleware for Elastic APM capturing.

    >>> apm = make_apm_client({
    >>> 'SERVICE_NAME': 'myapp',
    >>> 'DEBUG': True,
    >>> 'SERVER_URL': 'http://localhost:8200',
    >>> 'CAPTURE_HEADERS': True,
    >>> 'CAPTURE_BODY': 'all'
    >>> })

    >>> app.add_middleware(ElasticAPM, client=apm)

    Pass an arbitrary SERVICE_NAME and SECRET_TOKEN::

    >>> elasticapm = ElasticAPM(app, service_name='myapp', secret_token='asdasdasd')

    Pass an explicit client (don't pass in additional options in this case)::

    >>> elasticapm = ElasticAPM(app, client=client)

    Capture an exception (``capture_exception`` is a coroutine, so await it)::

    >>> try:
    >>>     1 / 0
    >>> except ZeroDivisionError:
    >>>     await elasticapm.capture_exception()

    Capture a message (also a coroutine)::

    >>> await elasticapm.capture_message('hello, world!')
    """

    def __init__(self, app: ASGIApp, client: Optional[Client] = None, **kwargs) -> None:
        """
        Args:
            app (ASGIApp): Starlette app (the ASGI app wrapped by this middleware).
            client (Client): ElasticAPM Client. When omitted, the client returned by
                `get_client()` is reused if there is one; otherwise a new client is
                built by passing ``**kwargs`` to `make_apm_client`.
            **kwargs: Only used when a new client has to be built (see above).
        """
        if client:
            self.client = client
        else:
            self.client = get_client()
            if not self.client:
                self.client = make_apm_client(**kwargs)

        if self.client.config.instrument and self.client.config.enabled:
            elasticapm.instrumentation.control.instrument()

        # If we ever make this a general-use ASGI middleware we should use
        # `asgiref.compatibility.guarantee_single_callable(app)` here
        self.app = app

    async def __call__(self, scope, receive, send):
        """Handle one ASGI connection.

        Scopes whose type is not ``"http"``, and HTTP requests whose path
        `client.should_ignore_url` rejects, are passed straight to the wrapped
        app. Every other request runs inside an APM transaction. When
        ``capture_body`` is not ``"off"``, the request body is read completely
        before the wrapped app is called.

        Args:
            scope: ASGI scope dictionary
            receive: receive awaitable callable
            send: send awaitable callable

        Raises:
            Any exception raised by the wrapped app. It is recorded on the
            transaction and then re-raised.
        """
        # we only handle the http scope, skip anything else.
        if scope["type"] != "http" or self.client.should_ignore_url(scope["path"]):
            await self.app(scope, receive, send)
            return

        @functools.wraps(send)
        async def wrapped_send(message) -> None:
            # When the response starts, record its context, result ("HTTP Nxx") and
            # outcome on the transaction; every message is then forwarded unchanged.
            if message.get("type") == "http.response.start":
                await set_context(
                    lambda: get_data_from_response(message, self.client.config, constants.TRANSACTION), "response"
                )
                result = "HTTP {}xx".format(message["status"] // 100)
                elasticapm.set_transaction_result(result, override=False)
                elasticapm.set_transaction_outcome(http_status_code=message["status"], override=False)
            await send(message)

        _mocked_receive = None
        _request_receive = None

        # begin the transaction before capturing the body to get that time accounted
        trace_parent = TraceParent.from_headers(dict(Headers(scope=scope)))
        self.client.begin_transaction("request", trace_parent=trace_parent)

        if self.client.config.capture_body != "off":
            # When we consume the body from receive, we replace the streaming
            # mechanism with a mocked version -- this workaround came from
            # https://github.com/encode/starlette/issues/495#issuecomment-513138055
            body = []
            while True:
                message = await receive()
                if not message:
                    break
                if message["type"] == "http.request":
                    b = message.get("body", b"")
                    if b:
                        body.append(b)
                    if not message.get("more_body", False):
                        break
                if message["type"] == "http.disconnect":
                    break

            joined_body = b"".join(body)

            # Receive callable for the Request used to build APM context; the
            # body is passed through `long_field` (presumably length-limited).
            async def mocked_receive() -> Message:
                await asyncio.sleep(0)
                return {"type": "http.request", "body": long_field(joined_body)}

            _mocked_receive = mocked_receive

            # Receive callable handed to the wrapped app. Every call returns the
            # full buffered body, so the body can be read more than once; the
            # original `receive` is not consulted again, so the app never sees an
            # `http.disconnect` message through this callable.
            async def request_receive() -> Message:
                await asyncio.sleep(0)
                return {"type": "http.request", "body": joined_body}

            _request_receive = request_receive

        request = Request(scope, receive=_mocked_receive or receive)
        await self._request_started(request)

        # We don't end the transaction here, we rely on the starlette
        # instrumentation of ServerErrorMiddleware to end the transaction
        try:
            await self.app(scope, _request_receive or receive, wrapped_send)
            elasticapm.set_transaction_outcome(constants.OUTCOME.SUCCESS, override=False)
        except Exception:
            await self.capture_exception(
                context={"request": await get_data_from_request(request, self.client.config, constants.ERROR)}
            )
            elasticapm.set_transaction_result("HTTP 5xx", override=False)
            elasticapm.set_transaction_outcome(constants.OUTCOME.FAILURE, override=False)
            elasticapm.set_context({"status_code": 500}, "response")
            raise

    async def capture_exception(self, *args, **kwargs) -> None:
        """Capture the current exception by delegating to `client.capture_exception`.

        Args:
            *args: Forwarded unchanged to `Client.capture_exception`.
            **kwargs: Forwarded unchanged to `Client.capture_exception`.
        """
        self.client.capture_exception(*args, **kwargs)

    async def capture_message(self, *args, **kwargs) -> None:
        """Capture a message by delegating to `client.capture_message`.

        Args:
            *args: Forwarded unchanged to `Client.capture_message`.
            **kwargs: Forwarded unchanged to `Client.capture_message`.
        """
        self.client.capture_message(*args, **kwargs)

    async def _request_started(self, request: Request) -> None:
        """Record the start of request processing on the current transaction.

        Sets the "request" context and names the transaction
        ``"<METHOD> <route path>"``, falling back to the URL path when no route
        matches.

        Args:
            request (Request)
        """
        # When we consume the body, we replace the streaming mechanism with
        # a mocked version -- this workaround came from
        # https://github.com/encode/starlette/issues/495#issuecomment-513138055
        # and we call the workaround here to make sure that regardless of
        # `capture_body` settings, we will have access to the body if we need it.
        if self.client.config.capture_body != "off":
            await get_body(request)
        await set_context(lambda: get_data_from_request(request, self.client.config, constants.TRANSACTION), "request")
        transaction_name = self.get_route_name(request) or request.url.path
        elasticapm.set_transaction_name("{} {}".format(request.method, transaction_name), override=False)

    def get_route_name(self, request: Request) -> Optional[str]:
        """Return the path template of the route matching ``request``, if any.

        Args:
            request (Request): the incoming request.

        Returns:
            The matching route's path (e.g. with ``{param}`` placeholders), or
            ``None`` when nothing matches or no ``routes`` collection is available.
        """
        scope = request.scope
        # `request.app` is `scope["app"]`, which Starlette only sets inside its own
        # `__call__`. When this middleware wraps the Starlette app from the outside
        # (`ElasticAPM(app, ...)`), the key is absent: fall back to the wrapped app
        # instead of raising KeyError and failing the whole request.
        app = scope.get("app", self.app)
        routes = getattr(app, "routes", None)
        if not routes:
            return None
        # Starlette/FastAPI apps expose their router as `.router`; a bare Router
        # carries `redirect_slashes` itself.
        router = getattr(app, "router", app)

        route_name = self._get_route_name(scope, routes)

        # Starlette magically redirects requests if the path matches a route name with a trailing slash
        # appended or removed. To not spam the transaction names list, we do the same here and put these
        # redirects all in the same "redirect trailing slashes" transaction name
        if not route_name and getattr(router, "redirect_slashes", False) and scope["path"] != "/":
            redirect_scope = dict(scope)
            if scope["path"].endswith("/"):
                redirect_scope["path"] = scope["path"][:-1]
                trim = True
            else:
                redirect_scope["path"] = scope["path"] + "/"
                trim = False

            route_name = self._get_route_name(redirect_scope, routes)
            if route_name is not None:
                # Re-apply the requested trailing-slash form to the matched template.
                route_name = route_name + "/" if trim else route_name[:-1]
        return route_name

    def _get_route_name(self, scope, routes, route_name=None):
        """Recursively find the path template of the first route fully matching ``scope``.

        For a fully matching ``Mount`` that has child routes, the mount path is
        concatenated with the matching child's path. If none of its children
        match, the result is ``None``.

        Args:
            scope: ASGI scope to match against.
            routes: Route objects to try, in order.
            route_name: Path of the enclosing mount when recursing.

        Returns:
            The matched path template, or ``None`` when no route fully matches.
        """
        for route in routes:
            match, child_scope = route.matches(scope)
            if match == Match.FULL:
                route_name = route.path
                child_scope = {**scope, **child_scope}
                if isinstance(route, Mount) and route.routes:
                    child_route_name = self._get_route_name(child_scope, route.routes, route_name)
                    if child_route_name is None:
                        route_name = None
                    else:
                        route_name += child_route_name
                return route_name
            elif match == Match.PARTIAL and route_name is None:
                # NOTE(review): a partial match (presumably path matched but e.g. the
                # HTTP method did not) is recorded here but never returned: a later
                # FULL match overwrites it, and otherwise the loop falls through to
                # `return None`. Confirm whether partial matches were meant to name
                # the transaction before relying on this value.
                route_name = route.path
        return None