# esrally/utils/net.py
# Licensed to Elasticsearch B.V. under one or more contributor
# license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright
# ownership. Elasticsearch B.V. licenses this file to you under
# the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import functools
import logging
import os
import socket
import time
import urllib.error
from urllib.parse import parse_qs, quote, urlencode, urlparse, urlunparse

import certifi
import urllib3

from esrally import exceptions
from esrally.utils import console, convert

_HTTP = None
_HTTPS = None


def __proxy_manager_from_env(env_var, logger):
    """
    Returns a pool manager that honors the proxy URL in ``env_var`` (falling back to ``all_proxy``),
    or one that connects directly to the Internet if neither variable is set.
    """
proxy_url = os.getenv(env_var.lower()) or os.getenv(env_var.upper())
if not proxy_url:
env_var = "all_proxy"
proxy_url = os.getenv(env_var) or os.getenv(env_var.upper())
    if proxy_url:
parsed_url = urllib3.util.parse_url(proxy_url)
logger.info("Connecting via proxy URL [%s] to the Internet (picked up from the environment variable [%s]).", proxy_url, env_var)
return urllib3.ProxyManager(
proxy_url,
cert_reqs="CERT_REQUIRED",
ca_certs=certifi.where(),
# appropriate headers will only be set if there is auth info
proxy_headers=urllib3.make_headers(proxy_basic_auth=parsed_url.auth),
)
else:
logger.info("Connecting directly to the Internet (no proxy support) for [%s].", env_var)
return urllib3.PoolManager(cert_reqs="CERT_REQUIRED", ca_certs=certifi.where())


def init():
logger = logging.getLogger(__name__)
global _HTTP, _HTTPS
_HTTP = __proxy_manager_from_env("http_proxy", logger)
_HTTPS = __proxy_manager_from_env("https_proxy", logger)
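
# Illustrative: with https_proxy=http://user:secret@proxy.example.org:3128 in the
# environment, init() builds _HTTPS as a ProxyManager that also sends basic-auth
# proxy headers; if none of http_proxy, https_proxy or all_proxy is set, plain
# PoolManagers connect directly. The proxy host here is hypothetical.
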
class Progress:
def __init__(self, msg, accuracy=0):
self.p = console.progress()
        # If we don't show a decimal point, the maximum width is 3 (the maximum value is 100 (%)).
        # Otherwise it's 3, plus 1 for the decimal point, plus the accuracy that the user requested.
total_width = 3 if accuracy == 0 else 4 + accuracy
# sample formatting string: [%5.1f%%] for an accuracy of 1
self.percent_format = "[%%%d.%df%%%%]" % (total_width, accuracy)
self.msg = msg

    def __call__(self, bytes_read, bytes_total):
if bytes_total:
completed = bytes_read / bytes_total
total_as_mb = convert.bytes_to_human_string(bytes_total)
self.p.print("%s (%s total size)" % (self.msg, total_as_mb), self.percent_format % (completed * 100))
else:
self.p.print(self.msg, ".")

    def finish(self):
self.p.finish()
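
# A minimal usage sketch (message and byte counts are illustrative):
#
#   progress = Progress("[INFO] Downloading data", accuracy=1)
#   progress(bytes_read=512 * 1024, bytes_total=1024 * 1024)  # prints the message plus "[ 50.0%]"
#   progress.finish()
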
def _fake_import_boto3():
# This function only exists to be mocked in tests to raise an ImportError, in
# order to simulate the absence of boto3
pass


def _download_from_s3_bucket(bucket_name, bucket_path, local_path, expected_size_in_bytes=None, progress_indicator=None):
# pylint: disable=import-outside-toplevel
# lazily initialize S3 support - it might not be available
try:
_fake_import_boto3()
import boto3.s3.transfer
except ImportError:
console.error("S3 support is optional. Install it with `python -m pip install esrally[s3]`")
raise

    class S3ProgressAdapter:
def __init__(self, size, progress):
self._expected_size_in_bytes = size
self._progress = progress
self._bytes_read = 0

        def __call__(self, bytes_amount):
self._bytes_read += bytes_amount
self._progress(self._bytes_read, self._expected_size_in_bytes)

    s3 = boto3.resource("s3")
bucket = s3.Bucket(bucket_name)
if expected_size_in_bytes is None:
expected_size_in_bytes = bucket.Object(bucket_path).content_length
progress_callback = S3ProgressAdapter(expected_size_in_bytes, progress_indicator) if progress_indicator else None
bucket.download_file(bucket_path, local_path, Callback=progress_callback, Config=boto3.s3.transfer.TransferConfig(use_threads=False))
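
# Illustrative call (bucket and key are hypothetical); boto3 resolves credentials
# through its usual chain (environment variables, ~/.aws/credentials, instance profiles):
#
#   _download_from_s3_bucket("my-bucket", "corpora/data.json.bz2", "/tmp/data.json.bz2")
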
def _build_gcs_object_url(bucket_name, bucket_path):
# / and other special characters must be urlencoded in bucket and object names
# ref: https://cloud.google.com/storage/docs/request-endpoints#encoding
return functools.reduce(
urllib.parse.urljoin,
[
"https://storage.googleapis.com/storage/v1/b/",
f"{quote(bucket_name.strip('/'), safe='')}/",
"o/",
f"{quote(bucket_path.strip('/'), safe='')}",
"?alt=media",
],
)
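
# Illustrative (bucket and object names are hypothetical):
#
#   _build_gcs_object_url("my-bucket", "corpora/data.json")
#   # -> "https://storage.googleapis.com/storage/v1/b/my-bucket/o/corpora%2Fdata.json?alt=media"
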
def _download_from_gcs_bucket(bucket_name, bucket_path, local_path, expected_size_in_bytes=None, progress_indicator=None):
# pylint: disable=import-outside-toplevel
# lazily initialize Google Cloud Storage support - we might not need it
import google.auth
import google.auth.transport.requests as tr_requests
import google.oauth2.credentials
# Using Google Resumable Media as the standard storage library doesn't support progress
# (https://github.com/googleapis/python-storage/issues/27)
from google.resumable_media.requests import ChunkedDownload
ro_scope = "https://www.googleapis.com/auth/devstorage.read_only"
access_token = os.environ.get("GOOGLE_AUTH_TOKEN")
if access_token:
credentials = google.oauth2.credentials.Credentials(token=access_token, scopes=(ro_scope,))
else:
# https://google-auth.readthedocs.io/en/latest/user-guide.html
credentials, _ = google.auth.default(scopes=(ro_scope,))
transport = tr_requests.AuthorizedSession(credentials)
chunk_size = 50 * 1024 * 1024 # 50MB
with open(local_path, "wb") as local_fp:
media_url = _build_gcs_object_url(bucket_name, bucket_path)
download = ChunkedDownload(media_url, chunk_size, local_fp)
        # consume the first chunk so that download.total_bytes gets populated
download.consume_next_chunk(transport)
if not expected_size_in_bytes:
expected_size_in_bytes = download.total_bytes
while not download.finished:
if progress_indicator and download.bytes_downloaded and download.total_bytes:
progress_indicator(download.bytes_downloaded, expected_size_in_bytes)
download.consume_next_chunk(transport)
# show final progress (for large files) or any progress (for files < chunk_size)
if progress_indicator and download.bytes_downloaded and expected_size_in_bytes:
progress_indicator(download.bytes_downloaded, expected_size_in_bytes)
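
# Illustrative (names are hypothetical): a pre-issued OAuth2 access token bypasses
# application-default credential discovery:
#
#   os.environ["GOOGLE_AUTH_TOKEN"] = "ya29..."  # any valid access token with the read-only scope
#   _download_from_gcs_bucket("my-bucket", "corpora/data.json.bz2", "/tmp/data.json.bz2")
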
def download_from_bucket(blobstore, url, local_path, expected_size_in_bytes=None, progress_indicator=None):
blob_downloader = {"s3": _download_from_s3_bucket, "gs": _download_from_gcs_bucket}
logger = logging.getLogger(__name__)
    bucket_and_path = url[5:]  # strip the 5-character scheme prefix ("s3://" or "gs://")
bucket_end_index = bucket_and_path.find("/")
bucket = bucket_and_path[:bucket_end_index]
# we need to remove the leading "/"
bucket_path = bucket_and_path[bucket_end_index + 1 :]
logger.info("Downloading from [%s] bucket [%s] and path [%s] to [%s].", blobstore, bucket, bucket_path, local_path)
blob_downloader[blobstore](bucket, bucket_path, local_path, expected_size_in_bytes, progress_indicator)
return expected_size_in_bytes
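
# Illustrative: for url = "s3://my-bucket/path/to/file.json", download_from_bucket
# extracts bucket = "my-bucket" and bucket_path = "path/to/file.json".
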
HTTP_DOWNLOAD_RETRIES = 10


def _download_http(url, local_path, expected_size_in_bytes=None, progress_indicator=None):
with (
_request(
"GET", url, preload_content=False, enforce_content_length=True, retries=10, timeout=urllib3.Timeout(connect=45, read=240)
) as r,
open(local_path, "wb") as out_file,
):
if r.status > 299:
raise urllib.error.HTTPError(
url,
r.status,
"",
None, # type: ignore[arg-type] # TODO remove the below ignore when introducing type hints
None,
)
try:
size_from_content_header = int(r.getheader("Content-Length", ""))
if expected_size_in_bytes is None:
expected_size_in_bytes = size_from_content_header
except ValueError:
size_from_content_header = None
chunk_size = 2**16
bytes_read = 0
for chunk in r.stream(chunk_size):
out_file.write(chunk)
bytes_read += len(chunk)
if progress_indicator and size_from_content_header:
progress_indicator(bytes_read, size_from_content_header)
return expected_size_in_bytes


def download_http(url, local_path, expected_size_in_bytes=None, progress_indicator=None, *, sleep=time.sleep):
logger = logging.getLogger(__name__)
for i in range(HTTP_DOWNLOAD_RETRIES + 1):
try:
return _download_http(url, local_path, expected_size_in_bytes, progress_indicator)
except (urllib3.exceptions.ProtocolError, urllib3.exceptions.ReadTimeoutError) as exc:
if i == HTTP_DOWNLOAD_RETRIES:
raise
logger.warning("Retrying after %s", exc)
sleep(5)
continue
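
# Illustrative: tests can inject a no-op sleep so retries don't wait for real time:
#
#   download_http(url, local_path, sleep=lambda _: None)
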
def add_url_param_elastic_no_kpi(url):
scheme = urllib3.util.parse_url(url).scheme
if scheme.startswith("http"):
return _add_url_param(url, {"x-elastic-no-kpi": "true"})
else:
return url


def _add_url_param(url, params):
url_parsed = urlparse(url)
query = parse_qs(url_parsed.query)
query.update(params)
return urlunparse(
(url_parsed.scheme, url_parsed.netloc, url_parsed.path, url_parsed.params, urlencode(query, doseq=True), url_parsed.fragment)
)
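
# Illustrative (URL is hypothetical):
#
#   add_url_param_elastic_no_kpi("https://example.org/path?a=1")
#   # -> "https://example.org/path?a=1&x-elastic-no-kpi=true"
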
def download(url, local_path, expected_size_in_bytes=None, progress_indicator=None):
"""
Downloads a single file from a URL to the provided local path.
    :param url: The remote URL specifying one file that should be downloaded. May be either an HTTP, HTTPS, S3 or GS URL.
    :param local_path: The local file name of the file that should be downloaded.
    :param expected_size_in_bytes: The expected file size in bytes if known. It will be used to verify that all data have been downloaded.
    :param progress_indicator: A callable that can be used to report progress to the user. It is expected to take two parameters
        ``bytes_read`` and ``total_bytes``. If not provided, no progress is shown. Note that ``total_bytes`` is derived from
        the ``Content-Length`` header and not from the parameter ``expected_size_in_bytes`` for downloads via HTTP(S).
"""
tmp_data_set_path = local_path + ".tmp"
try:
scheme = urllib3.util.parse_url(url).scheme
if scheme in ["s3", "gs"]:
expected_size_in_bytes = download_from_bucket(scheme, url, tmp_data_set_path, expected_size_in_bytes, progress_indicator)
else:
expected_size_in_bytes = download_http(url, tmp_data_set_path, expected_size_in_bytes, progress_indicator)
except BaseException:
if os.path.isfile(tmp_data_set_path):
os.remove(tmp_data_set_path)
raise
download_size = os.path.getsize(tmp_data_set_path)
if expected_size_in_bytes is not None and download_size != expected_size_in_bytes:
if os.path.isfile(tmp_data_set_path):
os.remove(tmp_data_set_path)
raise exceptions.DataError(
"Download of [%s] is corrupt. Downloaded [%d] bytes but [%d] bytes are expected. Please retry."
% (local_path, download_size, expected_size_in_bytes)
)
os.rename(tmp_data_set_path, local_path)
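
# Illustrative usage (URL and paths are hypothetical):
#
#   download(
#       "https://example.org/corpora/data.json.bz2",
#       "/tmp/data.json.bz2",
#       progress_indicator=Progress("Downloading data", accuracy=1),
#   )
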
def retrieve_content_as_string(url):
with _request("GET", url, timeout=urllib3.Timeout(connect=45, read=240)) as response:
return response.read().decode("utf-8")


def _request(method, url, **kwargs):
if not _HTTP or not _HTTPS:
init()
parsed_url = urllib3.util.parse_url(url)
manager = _HTTPS if parsed_url.scheme == "https" else _HTTP
return manager.request(method, url, **kwargs)


def resolve(hostname_or_ip):
if hostname_or_ip and hostname_or_ip.startswith("127"):
return hostname_or_ip
addrinfo = socket.getaddrinfo(hostname_or_ip, 22, 0, 0, socket.IPPROTO_TCP)
for family, _, _, _, sockaddr in addrinfo:
# we're interested in the IPv4 address
if family == socket.AddressFamily.AF_INET:
ip, _ = sockaddr
if ip[:3] != "127":
return ip
return None
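
# Illustrative behavior (actual results depend on the local resolver):
#
#   resolve("127.0.0.1")    # -> "127.0.0.1" (loopback addresses short-circuit)
#   resolve("example.org")  # -> the first non-loopback IPv4 address that DNS returns
#
# resolve() returns None if only loopback or IPv6 addresses are found.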