elasticapm/instrumentation/packages/dbapi2.py (195 lines of code) (raw):
# BSD 3-Clause License
#
# Copyright (c) 2019, Elasticsearch BV
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""Provides classes to instrument dbapi2 providers
https://www.python.org/dev/peps/pep-0249/
"""
import re
import string
import wrapt
from elasticapm.instrumentation.packages.base import AbstractInstrumentedModule
from elasticapm.traces import capture_span
from elasticapm.utils.encoding import force_text, shorten
class Literal(object):
def __init__(self, literal_type, content) -> None:
self.literal_type = literal_type
self.content = content
def __eq__(self, other):
return isinstance(other, Literal) and self.literal_type == other.literal_type and self.content == other.content
def __repr__(self):
return "<Literal {}{}{}>".format(self.literal_type, self.content, self.literal_type)
def look_for_table(sql, keyword):
tokens = tokenize(sql)
table_name = _scan_for_table_with_tokens(tokens, keyword)
if isinstance(table_name, Literal):
table_name = table_name.content.strip(table_name.literal_type)
return table_name
def _scan_for_table_with_tokens(tokens, keyword):
seen_keyword = False
for idx, lexeme in scan(tokens):
if seen_keyword:
if lexeme == "(":
return _scan_for_table_with_tokens(tokens[idx:], keyword)
else:
return lexeme
if isinstance(lexeme, str) and lexeme.upper() == keyword:
seen_keyword = True
def tokenize(sql):
# split on anything that is not a word character or a square bracket, excluding dots
return [t for t in re.split(r"([^\w.\[\]])", sql) if t != ""]
def scan(tokens):
literal_start_idx = None
literal_started = None
prev_was_escape = False
lexeme = []
digits = set(string.digits)
i = 0
while i < len(tokens):
token = tokens[i]
if literal_start_idx:
if prev_was_escape:
prev_was_escape = False
lexeme.append(token)
else:
if token == literal_started:
if literal_started == "'" and len(tokens) > i + 1 and tokens[i + 1] == "'": # double quotes
i += 1
lexeme.append("'")
else:
yield i, Literal(literal_started, "".join(lexeme))
literal_start_idx = None
literal_started = None
lexeme = []
else:
if token == "\\":
prev_was_escape = token
else:
prev_was_escape = False
lexeme.append(token)
elif literal_start_idx is None:
if token in ["'", '"', "`"]:
literal_start_idx = i
literal_started = token
elif token == "$":
# exclude query parameters that have a digit following the dollar
if True and len(tokens) > i + 1 and tokens[i + 1] in digits:
yield i, token
i += 1
continue
# Postgres can use arbitrary characters between two $'s as a
# literal separation token, e.g.: $fish$ literal $fish$
# This part will detect that and skip over the literal.
try:
# Closing dollar of the opening quote,
# i.e. the second $ in the first $fish$
closing_dollar_idx = tokens.index("$", i + 1)
except ValueError:
pass
else:
quote = tokens[i : closing_dollar_idx + 1]
length = len(quote)
# Opening dollar of the closing quote,
# i.e. the first $ in the second $fish$
closing_quote_idx = closing_dollar_idx + 1
while True:
try:
closing_quote_idx = tokens.index("$", closing_quote_idx)
except ValueError:
break
if tokens[closing_quote_idx : closing_quote_idx + length] == quote:
yield i, Literal(
"".join(quote), "".join(tokens[closing_dollar_idx + 1 : closing_quote_idx])
)
i = closing_quote_idx + length
break
closing_quote_idx += 1
else:
if token != " ":
yield i, token
i += 1
if lexeme:
yield i, lexeme
def extract_signature(sql):
"""
Extracts a minimal signature from a given SQL query
:param sql: the SQL statement
:return: a string representing the signature
"""
sql = force_text(sql)
sql = sql.strip()
first_space = sql.find(" ")
if first_space < 0:
return sql
second_space = sql.find(" ", first_space + 1)
sql_type = sql[0:first_space].upper()
if sql_type in ["INSERT", "DELETE"]:
keyword = "INTO" if sql_type == "INSERT" else "FROM"
sql_type = sql_type + " " + keyword
object_name = look_for_table(sql, keyword)
elif sql_type in ["CREATE", "DROP"]:
# 2nd word is part of SQL type
sql_type = sql_type + sql[first_space:second_space]
object_name = ""
elif sql_type == "UPDATE":
object_name = look_for_table(sql, "UPDATE")
elif sql_type == "SELECT":
# Name is first table
try:
sql_type = "SELECT FROM"
object_name = look_for_table(sql, "FROM")
except Exception:
object_name = ""
elif sql_type in ["EXEC", "EXECUTE"]:
sql_type = "EXECUTE"
end = second_space if second_space > first_space else len(sql)
object_name = sql[first_space + 1 : end]
elif sql_type == "CALL":
first_paren = sql.find("(", first_space)
end = first_paren if first_paren > first_space else len(sql)
procedure_name = sql[first_space + 1 : end].rstrip(";")
object_name = procedure_name + "()"
else:
# No name
object_name = ""
signature = " ".join(filter(bool, [sql_type, object_name]))
return signature
QUERY_ACTION = "query"
EXEC_ACTION = "exec"
PROCEDURE_STATEMENTS = ["EXEC", "EXECUTE", "CALL"]
def extract_action_from_signature(signature, default):
if signature.split(" ")[0] in PROCEDURE_STATEMENTS:
return EXEC_ACTION
return default
class CursorProxy(wrapt.ObjectProxy):
provider_name = None
DML_QUERIES = ("INSERT", "DELETE", "UPDATE")
def __init__(self, wrapped, destination_info=None) -> None:
super(CursorProxy, self).__init__(wrapped)
self._self_destination_info = destination_info or {}
def callproc(self, procname, params=None):
return self._trace_sql(self.__wrapped__.callproc, procname, params, action=EXEC_ACTION)
def execute(self, sql, params=None):
return self._trace_sql(self.__wrapped__.execute, sql, params)
def executemany(self, sql, param_list):
return self._trace_sql(self.__wrapped__.executemany, sql, param_list)
def _bake_sql(self, sql):
"""
Method to turn the "sql" argument into a string. Most database backends simply return
the given object, as it is already a string
"""
return sql
def _trace_sql(self, method, sql, params, action=QUERY_ACTION):
sql_string = self._bake_sql(sql)
if action == EXEC_ACTION:
signature = sql_string + "()"
else:
signature = self.extract_signature(sql_string)
action = extract_action_from_signature(signature, action)
# Truncate sql_string to 10000 characters to prevent large queries from
# causing an error to APM server.
sql_string = shorten(sql_string, string_length=10000)
with capture_span(
signature,
span_type="db",
span_subtype=self.provider_name,
span_action=action,
extra={
"db": {"type": "sql", "statement": sql_string, "instance": getattr(self, "_self_database", None)},
"destination": self._self_destination_info,
},
skip_frames=1,
leaf=True,
) as span:
if params is None:
result = method(sql)
else:
result = method(sql, params)
# store "rows affected", but only for DML queries like insert/update/delete
if span and self.rowcount not in (-1, None) and signature.startswith(self.DML_QUERIES):
span.update_context("db", {"rows_affected": self.rowcount})
return result
def extract_signature(self, sql):
raise NotImplementedError()
class ConnectionProxy(wrapt.ObjectProxy):
cursor_proxy = CursorProxy
def __init__(self, wrapped, destination_info=None) -> None:
super(ConnectionProxy, self).__init__(wrapped)
self._self_destination_info = destination_info
def cursor(self, *args, **kwargs):
return self.cursor_proxy(self.__wrapped__.cursor(*args, **kwargs), self._self_destination_info)
class DbApi2Instrumentation(AbstractInstrumentedModule):
connect_method = None
def call(self, module, method, wrapped, instance, args, kwargs):
return ConnectionProxy(wrapped(*args, **kwargs))
def call_if_sampling(self, module, method, wrapped, instance, args, kwargs):
# Contrasting to the superclass implementation, we *always* want to
# return a proxied connection, even if there is no ongoing elasticapm
# transaction yet. This ensures that we instrument the cursor once
# the transaction started.
return self.call(module, method, wrapped, instance, args, kwargs)