esrally/utils/net.py (213 lines of code) (raw):

# Licensed to Elasticsearch B.V. under one or more contributor # license agreements. See the NOTICE file distributed with # this work for additional information regarding copyright # ownership. Elasticsearch B.V. licenses this file to you under # the Apache License, Version 2.0 (the "License"); you may # not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import functools import logging import os import socket import time import urllib.error from urllib.parse import parse_qs, quote, urlencode, urlparse, urlunparse import certifi import urllib3 from esrally import exceptions from esrally.utils import console, convert _HTTP = None _HTTPS = None def __proxy_manager_from_env(env_var, logger): proxy_url = os.getenv(env_var.lower()) or os.getenv(env_var.upper()) if not proxy_url: env_var = "all_proxy" proxy_url = os.getenv(env_var) or os.getenv(env_var.upper()) if proxy_url and len(proxy_url) > 0: parsed_url = urllib3.util.parse_url(proxy_url) logger.info("Connecting via proxy URL [%s] to the Internet (picked up from the environment variable [%s]).", proxy_url, env_var) return urllib3.ProxyManager( proxy_url, cert_reqs="CERT_REQUIRED", ca_certs=certifi.where(), # appropriate headers will only be set if there is auth info proxy_headers=urllib3.make_headers(proxy_basic_auth=parsed_url.auth), ) else: logger.info("Connecting directly to the Internet (no proxy support) for [%s].", env_var) return urllib3.PoolManager(cert_reqs="CERT_REQUIRED", ca_certs=certifi.where()) def init(): logger = logging.getLogger(__name__) global _HTTP, _HTTPS _HTTP = __proxy_manager_from_env("http_proxy", logger) _HTTPS = __proxy_manager_from_env("https_proxy", logger) class Progress: def __init__(self, msg, accuracy=0): self.p = console.progress() # if we don't show a decimal sign, the maximum width is 3 (max value is 100 (%)). Else its 3 + 1 (for the decimal point) # the accuracy that the user requested. total_width = 3 if accuracy == 0 else 4 + accuracy # sample formatting string: [%5.1f%%] for an accuracy of 1 self.percent_format = "[%%%d.%df%%%%]" % (total_width, accuracy) self.msg = msg def __call__(self, bytes_read, bytes_total): if bytes_total: completed = bytes_read / bytes_total total_as_mb = convert.bytes_to_human_string(bytes_total) self.p.print("%s (%s total size)" % (self.msg, total_as_mb), self.percent_format % (completed * 100)) else: self.p.print(self.msg, ".") def finish(self): self.p.finish() def _fake_import_boto3(): # This function only exists to be mocked in tests to raise an ImportError, in # order to simulate the absence of boto3 pass def _download_from_s3_bucket(bucket_name, bucket_path, local_path, expected_size_in_bytes=None, progress_indicator=None): # pylint: disable=import-outside-toplevel # lazily initialize S3 support - it might not be available try: _fake_import_boto3() import boto3.s3.transfer except ImportError: console.error("S3 support is optional. Install it with `python -m pip install esrally[s3]`") raise class S3ProgressAdapter: def __init__(self, size, progress): self._expected_size_in_bytes = size self._progress = progress self._bytes_read = 0 def __call__(self, bytes_amount): self._bytes_read += bytes_amount self._progress(self._bytes_read, self._expected_size_in_bytes) s3 = boto3.resource("s3") bucket = s3.Bucket(bucket_name) if expected_size_in_bytes is None: expected_size_in_bytes = bucket.Object(bucket_path).content_length progress_callback = S3ProgressAdapter(expected_size_in_bytes, progress_indicator) if progress_indicator else None bucket.download_file(bucket_path, local_path, Callback=progress_callback, Config=boto3.s3.transfer.TransferConfig(use_threads=False)) def _build_gcs_object_url(bucket_name, bucket_path): # / and other special characters must be urlencoded in bucket and object names # ref: https://cloud.google.com/storage/docs/request-endpoints#encoding return functools.reduce( urllib.parse.urljoin, [ "https://storage.googleapis.com/storage/v1/b/", f"{quote(bucket_name.strip('/'), safe='')}/", "o/", f"{quote(bucket_path.strip('/'), safe='')}", "?alt=media", ], ) def _download_from_gcs_bucket(bucket_name, bucket_path, local_path, expected_size_in_bytes=None, progress_indicator=None): # pylint: disable=import-outside-toplevel # lazily initialize Google Cloud Storage support - we might not need it import google.auth import google.auth.transport.requests as tr_requests import google.oauth2.credentials # Using Google Resumable Media as the standard storage library doesn't support progress # (https://github.com/googleapis/python-storage/issues/27) from google.resumable_media.requests import ChunkedDownload ro_scope = "https://www.googleapis.com/auth/devstorage.read_only" access_token = os.environ.get("GOOGLE_AUTH_TOKEN") if access_token: credentials = google.oauth2.credentials.Credentials(token=access_token, scopes=(ro_scope,)) else: # https://google-auth.readthedocs.io/en/latest/user-guide.html credentials, _ = google.auth.default(scopes=(ro_scope,)) transport = tr_requests.AuthorizedSession(credentials) chunk_size = 50 * 1024 * 1024 # 50MB with open(local_path, "wb") as local_fp: media_url = _build_gcs_object_url(bucket_name, bucket_path) download = ChunkedDownload(media_url, chunk_size, local_fp) # allow us to calculate the total bytes download.consume_next_chunk(transport) if not expected_size_in_bytes: expected_size_in_bytes = download.total_bytes while not download.finished: if progress_indicator and download.bytes_downloaded and download.total_bytes: progress_indicator(download.bytes_downloaded, expected_size_in_bytes) download.consume_next_chunk(transport) # show final progress (for large files) or any progress (for files < chunk_size) if progress_indicator and download.bytes_downloaded and expected_size_in_bytes: progress_indicator(download.bytes_downloaded, expected_size_in_bytes) def download_from_bucket(blobstore, url, local_path, expected_size_in_bytes=None, progress_indicator=None): blob_downloader = {"s3": _download_from_s3_bucket, "gs": _download_from_gcs_bucket} logger = logging.getLogger(__name__) bucket_and_path = url[5:] # s3:// or gs:// prefix for now bucket_end_index = bucket_and_path.find("/") bucket = bucket_and_path[:bucket_end_index] # we need to remove the leading "/" bucket_path = bucket_and_path[bucket_end_index + 1 :] logger.info("Downloading from [%s] bucket [%s] and path [%s] to [%s].", blobstore, bucket, bucket_path, local_path) blob_downloader[blobstore](bucket, bucket_path, local_path, expected_size_in_bytes, progress_indicator) return expected_size_in_bytes HTTP_DOWNLOAD_RETRIES = 10 def _download_http(url, local_path, expected_size_in_bytes=None, progress_indicator=None): with ( _request( "GET", url, preload_content=False, enforce_content_length=True, retries=10, timeout=urllib3.Timeout(connect=45, read=240) ) as r, open(local_path, "wb") as out_file, ): if r.status > 299: raise urllib.error.HTTPError( url, r.status, "", None, # type: ignore[arg-type] # TODO remove the below ignore when introducing type hints None, ) try: size_from_content_header = int(r.getheader("Content-Length", "")) if expected_size_in_bytes is None: expected_size_in_bytes = size_from_content_header except ValueError: size_from_content_header = None chunk_size = 2**16 bytes_read = 0 for chunk in r.stream(chunk_size): out_file.write(chunk) bytes_read += len(chunk) if progress_indicator and size_from_content_header: progress_indicator(bytes_read, size_from_content_header) return expected_size_in_bytes def download_http(url, local_path, expected_size_in_bytes=None, progress_indicator=None, *, sleep=time.sleep): logger = logging.getLogger(__name__) for i in range(HTTP_DOWNLOAD_RETRIES + 1): try: return _download_http(url, local_path, expected_size_in_bytes, progress_indicator) except (urllib3.exceptions.ProtocolError, urllib3.exceptions.ReadTimeoutError) as exc: if i == HTTP_DOWNLOAD_RETRIES: raise logger.warning("Retrying after %s", exc) sleep(5) continue def add_url_param_elastic_no_kpi(url): scheme = urllib3.util.parse_url(url).scheme if scheme.startswith("http"): return _add_url_param(url, {"x-elastic-no-kpi": "true"}) else: return url def _add_url_param(url, params): url_parsed = urlparse(url) query = parse_qs(url_parsed.query) query.update(params) return urlunparse( (url_parsed.scheme, url_parsed.netloc, url_parsed.path, url_parsed.params, urlencode(query, doseq=True), url_parsed.fragment) ) def download(url, local_path, expected_size_in_bytes=None, progress_indicator=None): """ Downloads a single file from a URL to the provided local path. :param url: The remote URL specifying one file that should be downloaded. May be either a HTTP, HTTPS, S3 or GS URL. :param local_path: The local file name of the file that should be downloaded. :param expected_size_in_bytes: The expected file size in bytes if known. It will be used to verify that all data have been downloaded. :param progress_indicator A callable that can be use to report progress to the user. It is expected to take two parameters ``bytes_read`` and ``total_bytes``. If not provided, no progress is shown. Note that ``total_bytes`` is derived from the ``Content-Length`` header and not from the parameter ``expected_size_in_bytes`` for downloads via HTTP(S). """ tmp_data_set_path = local_path + ".tmp" try: scheme = urllib3.util.parse_url(url).scheme if scheme in ["s3", "gs"]: expected_size_in_bytes = download_from_bucket(scheme, url, tmp_data_set_path, expected_size_in_bytes, progress_indicator) else: expected_size_in_bytes = download_http(url, tmp_data_set_path, expected_size_in_bytes, progress_indicator) except BaseException: if os.path.isfile(tmp_data_set_path): os.remove(tmp_data_set_path) raise download_size = os.path.getsize(tmp_data_set_path) if expected_size_in_bytes is not None and download_size != expected_size_in_bytes: if os.path.isfile(tmp_data_set_path): os.remove(tmp_data_set_path) raise exceptions.DataError( "Download of [%s] is corrupt. Downloaded [%d] bytes but [%d] bytes are expected. Please retry." % (local_path, download_size, expected_size_in_bytes) ) os.rename(tmp_data_set_path, local_path) def retrieve_content_as_string(url): with _request("GET", url, timeout=urllib3.Timeout(connect=45, read=240)) as response: return response.read().decode("utf-8") def _request(method, url, **kwargs): if not _HTTP or not _HTTPS: init() parsed_url = urllib3.util.parse_url(url) manager = _HTTPS if parsed_url.scheme == "https" else _HTTP return manager.request(method, url, **kwargs) def resolve(hostname_or_ip): if hostname_or_ip and hostname_or_ip.startswith("127"): return hostname_or_ip addrinfo = socket.getaddrinfo(hostname_or_ip, 22, 0, 0, socket.IPPROTO_TCP) for family, _, _, _, sockaddr in addrinfo: # we're interested in the IPv4 address if family == socket.AddressFamily.AF_INET: ip, _ = sockaddr if ip[:3] != "127": return ip return None