services/utils/__init__.py (248 lines of code) (raw):

import json
import re
import sys
import os
import traceback
from multidict import MultiDict
from urllib.parse import urlencode, quote
from aiohttp import web
from enum import Enum
from functools import wraps
from typing import Dict
import logging
import psycopg2
from packaging.version import Version, parse
from importlib import metadata

USE_SEPARATE_READER_POOL = os.environ.get("USE_SEPARATE_READER_POOL", "0") in ["True", "true", "1"]

version = metadata.version("metadata_service")
METADATA_SERVICE_VERSION = version
METADATA_SERVICE_HEADER = 'METADATA_SERVICE_VERSION'

# The latest commit hash of the repository, if set as an environment variable.
SERVICE_COMMIT_HASH = os.environ.get("BUILD_COMMIT_HASH", None)
# Build time of service, if set as an environment variable.
SERVICE_BUILD_TIMESTAMP = os.environ.get("BUILD_TIMESTAMP", None)

# Setup log level based on environment variable
log_level = os.environ.get('LOGLEVEL', 'INFO').upper()
logging.basicConfig(level=log_level)

# E.g. set to http://localhost:3000 to allow CORS coming from a UI served from there
# Setting to '*' to be maximally loose.
ORIGIN_TO_ALLOW_CORS_FROM = os.environ.get('ORIGIN_TO_ALLOW_CORS_FROM', None)

# Number of bytes requested from the request stream per read call.
_BODY_READ_CHUNK_SIZE = 64 * 1024

# Sentinel that distinguishes "no traceback supplied" from an explicit ``None``.
_TRACEBACK_NOT_PROVIDED = object()


async def read_body(request_content):
    """Read a request stream to EOF and return its content parsed as JSON.

    ``request_content`` must expose ``at_eof()`` and an awaitable ``read(n)``.
    The collected bytes are decoded as UTF-8 before parsing; decoding and
    JSON parsing errors propagate to the caller.
    """
    byte_array = bytearray()
    while not request_content.at_eof():
        # Read in large chunks; the chunk size does not affect the result.
        byte_array.extend(await request_content.read(_BODY_READ_CHUNK_SIZE))

    return json.loads(byte_array.decode("utf-8"))


def get_traceback_str():
    """Get the traceback as a string.

    The result combines the current call stack (without this function's own
    frame) with the traceback of the exception currently being handled, as
    reported by ``sys.exc_info()``, followed by the exception line.
    """
    exc_type, exc_value, exc_tb = sys.exc_info()
    # Drop the last entry, which is the frame of this helper itself.
    call_stack = traceback.extract_stack()[:-1]
    full_tb = call_stack + traceback.extract_tb(exc_tb)
    exc_line = traceback.format_exception_only(exc_type, exc_value)
    return "\n".join(
        [
            "Traceback (most recent call last):",
            "".join(traceback.format_list(full_tb)),
            "".join(exc_line),
        ]
    )


def http_500(msg, id, traceback_str=_TRACEBACK_NOT_PROVIDED):
    """Build a JSON "Internal Server Error" (HTTP 500) response.

    Parameters:
        msg: value placed in the ``detail`` field of the body.
        id: identifier of the error, placed in the ``id`` field.
        traceback_str: traceback text for the ``traceback`` field. When the
            argument is omitted, the traceback is generated at call time
            with ``get_traceback_str()``. An explicitly passed value
            (including ``None``) is used unchanged.

    Returns:
        The response produced by ``web_response(500, body)``.
    """
    # The default must be computed per call: a default expression in the
    # signature would be evaluated only once, when the module is imported.
    if traceback_str is _TRACEBACK_NOT_PROVIDED:
        traceback_str = get_traceback_str()

    # NOTE: worth considering if we want to expose tracebacks in the future in the api messages.
    body = {
        'id': id,
        'traceback': traceback_str,
        'detail': msg,
        'status': 500,
        'title': 'Internal Server Error',
        'type': 'about:blank'
    }
    return web_response(500, body)


def handle_exceptions(func):
    """Catch exceptions and return appropriate HTTP error.

    Wraps the coroutine function ``func``. Any ``Exception`` raised by it is
    converted to an ``http_500`` response carrying the exception message.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except Exception as err:
            # pass along an id for the error
            err_id = getattr(err, 'id', None)
            # either use provided traceback from subprocess, or generate trace from current process
            err_trace = getattr(err, 'traceback_str', None) or get_traceback_str()
            if not err_id:
                # Log error only in case it is not a known case.
                err_id = 'generic-error'
                logging.error(err_trace)

            return http_500(str(err), err_id, err_trace)

    return wrapper


def format_response(func):
    """Turn the result of the wrapped coroutine into an aiohttp response.

    The awaited result must expose ``response_code`` and ``body``; the body
    is serialised with ``json.dumps`` and the service version header is
    attached.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        db_response = await func(*args, **kwargs)
        return web.Response(
            status=db_response.response_code,
            body=json.dumps(db_response.body),
            headers=MultiDict({METADATA_SERVICE_HEADER: METADATA_SERVICE_VERSION}),
        )

    return wrapper


def web_response(status: int, body):
    """Create a JSON aiohttp response with the service version header.

    Parameters:
        status: HTTP status code of the response.
        body: object serialised with ``json.dumps`` into the response body.
    """
    headers = MultiDict(
        {"Content-Type": "application/json",
         METADATA_SERVICE_HEADER: METADATA_SERVICE_VERSION})

    if not ORIGIN_TO_ALLOW_CORS_FROM:
        # The aiohttp-cors library actively asserts that this response header
        # is actively NOT set, before setting it to the original request's origin.
        #
        # Therefore, we only want to add this blanket response header iff ORIGIN_TO_ALLOW_CORS_FROM
        # is not configured (default).
        headers["Access-Control-Allow-Origin"] = "*"

    return web.Response(status=status, body=json.dumps(body), headers=headers)


def format_qs(query: Dict[str, str], overwrite=None):
    """Encode ``query`` as a URL query string, including the leading ``?``.

    Entries of ``overwrite`` (if given) replace or extend those of ``query``.
    The characters ``:`` and ``,`` are left unescaped. An empty query
    produces an empty string.
    """
    q = dict(query)
    if overwrite:
        for key in overwrite:
            q[key] = overwrite[key]
    qs = urlencode(q, safe=':,')
    return "?" + qs if qs else qs


def format_baseurl(request: web.BaseRequest):
    """Return the externally visible URL (without query string) of ``request``.

    The ``X-Forwarded-Proto`` / ``X-Forwarded-Host`` headers take precedence
    over the request's own scheme and host. The ``MF_BASEURL`` environment
    variable, when set, replaces the ``scheme://host`` part entirely.
    """
    scheme = request.headers.get("X-Forwarded-Proto") or request.scheme
    host = request.headers.get("X-Forwarded-Host") or request.host
    # Only get the first Forwarded-Host/Proto in case there are more than one
    scheme = scheme.split(",")[0].strip()
    host = host.split(",")[0].strip()

    baseurl = os.environ.get(
        "MF_BASEURL", "{scheme}://{host}".format(scheme=scheme, host=host))
    return "{baseurl}{path}".format(baseurl=baseurl, path=request.path)


def has_heartbeat_capable_version_tag(system_tags):
    """Check client version tag whether it is known to support heartbeats or not.

    Returns ``False`` when no ``metaflow_version:`` tag is present. For 1.x
    versions heartbeats require at least 1.14.0; for every other version at
    least 2.2.12. Any error while inspecting the tags yields ``True``.
    """
    try:
        # only parse for the major.minor.patch version and disregard any trailing bits
        # that might cause issues with comparison.
        version_tags = [tag for tag in system_tags if tag.startswith('metaflow_version:')]
        if not version_tags:
            return False

        # NOTE(review): str.lstrip() removes a *set of characters*, not a prefix.
        # For tags whose version starts with a digit the result equals prefix
        # removal; it is kept unchanged to preserve behaviour for other tags.
        raw_version = version_tags[0].lstrip("metaflow_version:")
        # match versions: major | major.minor | major.minor.patch
        ver_string = re.match(r"(0|\d+)(\.(0|\d+))*", raw_version)[0]
        version = parse(ver_string)

        if version >= Version("1") and version < Version("2"):
            return version >= Version("1.14.0")

        return version >= Version("2.2.12")
    except Exception:
        # Treat non-standard versions as heartbeat-enabled by default
        return True


# Database configuration helper
# Prioritizes DSN string over individual connection arguments (host,user,...)
#
# Supports prefix for environment variables:
#   prefix=MF_METADATA_DB_ -> MF_METADATA_DB_USER=username
#
# Prioritizes configuration in following order:
#
#   1. Env DSN string (MF_METADATA_DB_DSN="...")
#   2. DSN string as argument (DBConfiguration(dsn="..."))
#   3. Env connection arguments (MF_METADATA_DB_HOST="..." MF_METADATA_DB...)
#   4.
#      Default connection arguments (DBConfiguration(host="..."))
#


class DBType(Enum):
    """Role of the database host a connection should target."""

    # The DB host is a read replica
    READER = 1
    # The DB host is a writer instance
    WRITER = 2


class DBConfiguration(object):
    """Database connection settings resolved from arguments and environment.

    Every setting can be overridden by an environment variable named
    ``<prefix><NAME>`` (e.g. ``MF_METADATA_DB_HOST``); the environment value
    wins over the constructor argument. A valid DSN, when available, is
    returned as-is by ``get_dsn``.
    """

    host: str = None
    port: int = None
    user: str = None
    password: str = None
    database_name: str = None

    # aiopg default pool sizes
    # https://aiopg.readthedocs.io/en/stable/_modules/aiopg/pool.html#create_pool
    pool_min: int = None  # aiopg default: 1
    pool_max: int = None  # aiopg default: 10

    timeout: int = None  # aiopg default: 60 (seconds)

    _dsn: str = None

    # SSL modes for which SSL parameters are emitted; any other configured
    # value results in SSL being treated as disabled.
    _SSL_ENABLED_MODES = ('allow', 'prefer', 'require', 'verify-ca', 'verify-full')

    def __init__(self,
                 dsn: str = None,
                 host: str = "localhost",
                 port: int = 5432,
                 user: str = "postgres",
                 password: str = "postgres",
                 database_name: str = "postgres",
                 ssl_mode: str = "disabled",
                 ssl_cert_path: str = None,
                 ssl_key_path: str = None,
                 ssl_root_cert_path: str = None,
                 prefix="MF_METADATA_DB_",
                 pool_min: int = 1,
                 pool_max: int = 10,
                 timeout: int = 60):
        """Resolve the configuration, preferring environment variables.

        Raises:
            Exception: when no valid DSN is available and any of host, port,
                user, password or database name resolves to ``None``.
        """
        env = os.environ

        self._dsn = env.get(prefix + "DSN", dsn)
        # Check if it is a BAD DSN String.
        # if bad dsn string set self._dsn as None.
        if self._dsn is not None and not self._is_valid_dsn(self._dsn):
            self._dsn = None

        self._host = env.get(prefix + "HOST", host)
        # NOTE(review): with USE_SEPARATE_READER_POOL enabled and no
        # READ_REPLICA_HOST variable set, the reader host is None.
        self._read_replica_host = \
            env.get(prefix + "READ_REPLICA_HOST") if USE_SEPARATE_READER_POOL else self._host
        self._port = int(env.get(prefix + "PORT", port))
        self._user = env.get(prefix + "USER", user)
        self._password = env.get(prefix + "PSWD", password)
        self._database_name = env.get(prefix + "NAME", database_name)
        self._ssl_mode = env.get(prefix + "SSL_MODE", ssl_mode)
        self._ssl_cert_path = env.get(prefix + "SSL_CERT_PATH", ssl_cert_path)
        self._ssl_key_path = env.get(prefix + "SSL_KEY_PATH", ssl_key_path)
        self._ssl_root_cert_path = env.get(prefix + "SSL_ROOT_CERT_PATH", ssl_root_cert_path)

        conn_str_required_values = [
            self._host,
            self._port,
            self._user,
            self._password,
            self._database_name
        ]
        some_conn_str_values_missing = any(v is None for v in conn_str_required_values)
        if self._dsn is None and some_conn_str_values_missing:
            env_values = ', '.join([
                prefix + "HOST",
                prefix + "PORT",
                prefix + "USER",
                prefix + "PSWD",
                prefix + "NAME",
            ])
            dsn_var = prefix + "DSN"
            raise Exception(
                f"Some of the environment variables '{env_values}' are not set. "
                f"Please either set '{env_values}' or {dsn_var}. "
            )

        self.pool_min = int(env.get(prefix + "POOL_MIN", pool_min))
        self.pool_max = int(env.get(prefix + "POOL_MAX", pool_max))
        self.timeout = int(env.get(prefix + "TIMEOUT", timeout))

    @staticmethod
    def _is_valid_dsn(dsn):
        """Return True when psycopg2 can parse ``dsn``, False otherwise."""
        try:
            psycopg2.extensions.parse_dsn(dsn)
        except psycopg2.ProgrammingError:
            # This means that the DSN is unparsable.
            return False
        return True

    def connection_string_url(self, type=None):
        """Build a ``postgresql://`` connection URI from the individual settings.

        Parameters:
            type: ``DBType.WRITER`` or ``None`` targets the main host,
                ``DBType.READER`` targets the read replica host.

        Raises:
            Exception: for any other ``type`` value.
        """
        # postgresql://[user[:password]@][host][:port][/dbname][?param1=value1&...]
        if type is None or type == DBType.WRITER:
            target_host = self._host
        elif type == DBType.READER:
            target_host = self._read_replica_host
        else:
            raise Exception("Unsupported DBType %s" % type)

        # Percent-encode every reserved character, including '/', so that
        # credentials cannot terminate the userinfo part of the URI early.
        user = quote(self._user, safe='')
        password = quote(self._password, safe='')
        base_url = f'postgresql://{user}:{password}@{target_host}:{self._port}/{self._database_name}'

        if self._ssl_mode in self._SSL_ENABLED_MODES:
            ssl_query = f'sslmode={self._ssl_mode}'
            optional_params = (
                ('sslcert', self._ssl_cert_path),
                ('sslkey', self._ssl_key_path),
                ('sslrootcert', self._ssl_root_cert_path),
            )
            for param, value in optional_params:
                if value is not None:
                    ssl_query = f'{ssl_query}&{param}={value}'
        else:
            ssl_query = 'sslmode=disable'

        return f'{base_url}?{ssl_query}'

    def get_dsn(self, type=None):
        """Return the DSN used to connect to the database.

        A configured (valid) DSN string is returned unchanged, regardless of
        ``type``. Otherwise a DSN is assembled with psycopg2 from the
        individual settings; for ``DBType.READER`` the read replica host is
        used in place of the main host.
        """
        if self._dsn is not None:
            return self._dsn

        # SSL parameters are only passed on for the recognised SSL modes.
        ssl_enabled = self._ssl_mode in self._SSL_ENABLED_MODES
        # We assume that everything except the hostname remains the same for a reader.
        # At the moment this is a fair assumption for Postgres read replicas.
        target_host = self._read_replica_host if type == DBType.READER else self._host

        kwargs = {
            'dbname': self._database_name,
            'user': self._user,
            'host': target_host,
            'port': self._port,
            'password': self._password,
            'sslmode': self._ssl_mode if ssl_enabled else None,
            'sslcert': self._ssl_cert_path if ssl_enabled else None,
            'sslkey': self._ssl_key_path if ssl_enabled else None,
            'sslrootcert': self._ssl_root_cert_path if ssl_enabled else None,
        }
        # Unset (None) values are left out of the generated DSN.
        return psycopg2.extensions.make_dsn(**{k: v for k, v in kwargs.items() if v is not None})

    @property
    def port(self):
        return self._port

    @property
    def password(self):
        return self._password

    @property
    def user(self):
        return self._user

    @property
    def database_name(self):
        return self._database_name

    @property
    def host(self):
        return self._host

    @property
    def read_replica_host(self):
        return self._read_replica_host