odps/rest.py (313 lines of code) (raw):
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Restful client enhanced by URL building and request signing facilities.
"""
from __future__ import absolute_import
import json
import logging
import os
import platform
import re
import threading
from string import Template
import requests
try:
from requests import ConnectTimeout
except ImportError:
from requests import Timeout as ConnectTimeout
try:
import requests_unixsocket
except ImportError:
requests_unixsocket = None
from . import __version__, errors, utils
from .compat import six, urlparse
from .config import options
from .utils import clear_survey_calls, get_package_version, get_survey_calls
# Patch urllib3 both through ``requests.packages`` and as a standalone
# import: accept every cipher ("ALL") and silence urllib3's HTTP warnings.
# Each patch is skipped when the corresponding module cannot be imported.
# NOTE(review): "ALL" also enables weak ciphers, and newer urllib3 releases
# may no longer read DEFAULT_CIPHERS -- confirm this is still intended.
try:
    import requests.packages.urllib3.util.ssl_
    requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS = "ALL"
    requests.packages.urllib3.disable_warnings()
except ImportError:
    pass
try:
    import urllib3.util.ssl_
    urllib3.util.ssl_.DEFAULT_CIPHERS = "ALL"
    urllib3.disable_warnings()
except ImportError:
    pass
# Retry is optional: when neither import works it is None, and
# RestClient.session falls back to a plain integer retry count.
try:
    from urllib3.util import Retry
except ImportError:
    try:
        from requests.packages.urllib3.util import Retry
    except ImportError:
        Retry = None
logger = logging.getLogger(__name__)
# Cache for default_user_agent(); None until the first call builds it.
_default_user_agent = None
# Lower-case fragments of InternalServerError messages that make
# RestClient.request resend the request without a V4 signing region.
_v4_sign_fallback_msgs = [
    "need ak v3 support",
    "accesskey acl denied",
]
def default_user_agent():
    """Return the default ``User-Agent`` header value used by PyODPS.

    The pattern is taken from ``options.user_agent_pattern``, then the
    ``PYODPS_USER_AGENT_PATTERN`` environment variable, then a built-in
    default. It may reference ``$pyodps_version``, ``$mars_version``,
    ``$maxframe_version``, ``$python_version`` and ``$os_version``.
    Version fields of packages that cannot be detected expand to empty
    strings, runs of spaces are collapsed, and an optional internal suffix
    is appended when available.

    The result is computed once and cached in the module-level
    ``_default_user_agent``.

    :return: the user agent string
    """
    global _default_user_agent
    if _default_user_agent is not None:
        return _default_user_agent

    py_implementation = platform.python_implementation()
    py_version = platform.python_version()
    try:
        py_system = platform.system()
        py_release = platform.release()
    except IOError:
        py_system = "Unknown"
        py_release = "Unknown"

    ua_template = Template(
        options.user_agent_pattern
        or os.getenv("PYODPS_USER_AGENT_PATTERN")
        or "$pyodps_version $mars_version $maxframe_version $python_version $os_version"
    )
    substitutes = dict(
        pyodps_version="%s/%s" % ("pyodps", __version__),
        python_version="%s/%s" % (py_implementation, py_version),
        os_version="%s/%s" % (py_system, py_release),
        mars_version="",
        maxframe_version="",
    )

    # Optional packages: any failure just leaves the field empty. Catch
    # Exception rather than everything so Ctrl-C / SystemExit propagate.
    try:
        from mars import __version__ as mars_version
    except Exception:
        mars_version = None
    if mars_version:
        substitutes["mars_version"] = "%s/%s" % ("mars", mars_version)

    try:
        maxframe_version = get_package_version("maxframe")
    except Exception:
        maxframe_version = None
    if maxframe_version:
        substitutes["maxframe_version"] = "%s/%s" % ("maxframe", maxframe_version)

    # Build the whole value in a local first; the cache is published only
    # once complete, so other callers never see a value lacking the suffix.
    user_agent = ua_template.safe_substitute(**substitutes)
    user_agent = re.sub(" +", " ", user_agent).strip()

    try:
        from .internal.rest import get_internal_user_agent_suffix

        suffix = get_internal_user_agent_suffix()
        if suffix:  # skip empty suffixes to avoid a trailing space
            user_agent += " " + suffix
    except Exception:
        pass

    _default_user_agent = user_agent
    return _default_user_agent
class RestClient(object):
    """REST client that builds request URLs, signs requests and sends them.

    Each request is signed by ``account`` (and additionally by
    ``app_account`` when one is given) and sent through a
    :class:`requests.Session` cached per thread and per endpoint. Non-OK
    responses are handed to :func:`errors.throw_if_parsable`.
    """

    # per-thread cache of sessions, keyed by endpoint (see ``session``)
    _session_local = threading.local()
    # endpoints that rejected V4 signatures; shared by all instances
    _endpoints_without_v4_sign = set()

    def __init__(
        self,
        account,
        endpoint,
        project=None,
        schema=None,
        user_agent=None,
        region_name=None,
        namespace=None,
        **kwargs
    ):
        """
        :param account: account object whose ``sign_request`` signs requests
        :param endpoint: service endpoint; one trailing ``/`` is removed
        :param project: default value of the ``curr_project`` parameter
        :param schema: default value of the ``curr_schema`` parameter
        :param user_agent: ``User-Agent`` header value, defaults to
            :func:`default_user_agent`
        :param region_name: region passed to ``sign_request``
        :param namespace: sent as the ``x-odps-namespace-id`` header when set
        :param kwargs: optional ``proxy`` (a URL string used for both http
            and https, or a dict as accepted by requests), ``app_account``
            (a second signer) and ``tag`` (passed to error parsing)
        """
        if endpoint.endswith("/"):
            endpoint = endpoint[:-1]

        self._account = account
        self._endpoint = endpoint
        self._region_name = region_name
        self._user_agent = user_agent or default_user_agent()
        self.project = project
        self.schema = schema
        self.namespace = namespace
        self._proxy = kwargs.get("proxy")
        self._app_account = kwargs.get("app_account")
        self._tag = kwargs.get("tag")
        if isinstance(self._proxy, six.string_types):
            self._proxy = dict(http=self._proxy, https=self._proxy)

    @property
    def endpoint(self):
        return self._endpoint

    @property
    def account(self):
        return self._account

    @property
    def app_account(self):
        return self._app_account

    @property
    def region_name(self):
        return self._region_name

    @property
    def session(self):
        """Return the :class:`requests.Session` used for this endpoint.

        Sessions are cached per thread and per endpoint. A new session gets
        adapters configured with ``options.pool_connections``,
        ``options.pool_maxsize`` and up to ``options.retry_times`` retries
        (DELETE / GET / HEAD requests answered with 502, 503 or 504 when
        ``Retry`` is usable).

        :raises ImportError: for ``http+unix`` endpoints when
            ``requests_unixsocket`` is not installed
        """
        local = type(self)._session_local
        try:
            session_cache = local.session_cache
        except AttributeError:
            session_cache = local.session_cache = dict()
        try:
            return session_cache[self._endpoint]
        except KeyError:
            pass

        try:
            retries = Retry(
                total=options.retry_times,
                backoff_factor=0.1,
                allowed_methods={"DELETE", "GET", "HEAD"},
                status_forcelist=[502, 503, 504],
            )
        except Exception:
            # Retry is None, or (e.g. older urllib3) does not accept these
            # arguments: fall back to a plain retry count.
            retries = options.retry_times

        adapter_options = dict(
            pool_connections=options.pool_connections,
            pool_maxsize=options.pool_maxsize,
            max_retries=retries,
        )
        parsed_url = urlparse(self._endpoint)
        if parsed_url.scheme == "http+unix":
            if requests_unixsocket is None:
                raise ImportError(
                    "Package requests_unixsocket is required to connect to %s"
                    % self._endpoint
                )
            session = requests_unixsocket.Session()
            session.mount(
                "http+unix://",
                requests_unixsocket.adapters.UnixAdapter(**adapter_options),
            )
        else:
            session = requests.Session()
            # mount adapters with retry times
            session.mount("http://", requests.adapters.HTTPAdapter(**adapter_options))
            session.mount("https://", requests.adapters.HTTPAdapter(**adapter_options))
        session_cache[self._endpoint] = session
        return session

    def request(self, url, method, stream=False, **kwargs):
        """Send a signed request, retrying on recoverable authentication errors.

        * If the server's error indicates that the V4 signature was rejected,
          the endpoint is recorded in ``_endpoints_without_v4_sign`` and the
          request is resent with ``region_name=None``.
        * On :class:`errors.AuthenticationRequestExpired`, an account that has
          a ``reload`` method is reloaded and the request is resent once.

        :param url: full request URL
        :param method: HTTP method name
        :param stream: whether to stream the response body
        :param kwargs: forwarded to :meth:`_request`
        :return: the :class:`requests.Response`
        """
        sign_region_name = kwargs.get("region_name") or self._region_name
        if (
            self._endpoint in self._endpoints_without_v4_sign
            or not options.enable_v4_sign
        ):
            sign_region_name = None

        auth_expire_retried = False
        while True:
            kwargs["region_name"] = sign_region_name
            try:
                return self._request(url, method, stream=stream, **kwargs)
            except errors.InternalServerError as ex:
                ex_msg = str(ex).lower()
                if sign_region_name is None or all(
                    msg not in ex_msg for msg in _v4_sign_fallback_msgs
                ):
                    raise
                self._fallback_v4_sign(url, ex)
                sign_region_name = None
            except errors.InvalidParameter as ex:
                if sign_region_name is None or "ODPS-0410051" not in str(ex):
                    # Invalid credentials error not received from server
                    raise
                self._fallback_v4_sign(url, ex)
                sign_region_name = None
            except errors.AuthorizationRequired as ex:
                if sign_region_name is None or "invalid or missing" not in str(ex):
                    raise
                self._fallback_v4_sign(url, ex)
                sign_region_name = None
            except errors.AuthenticationRequestExpired:
                if not hasattr(self.account, "reload") or auth_expire_retried:
                    raise
                logger.info(
                    "AuthenticationRequestExpired encountered with %r. "
                    "Will retry with reloaded account.",
                    self.account,
                )
                self.account.reload(True)
                auth_expire_retried = True

    def _fallback_v4_sign(self, url, ex):
        """Log a V4 signature rejection and stop V4 signing for this endpoint."""
        logger.info("Fallback of V4 signature for %s. Error message: %s", url, ex)
        self._endpoints_without_v4_sign.add(self._endpoint)

    def _request(self, url, method, stream=False, **kwargs):
        """Build, sign and send one request.

        Keyword arguments consumed here (the rest go to
        :class:`requests.Request`): ``region_name`` (for signing),
        ``actions`` / ``action`` (a string or list of strings appended to the
        query string), ``curr_project``, ``curr_schema`` and ``timeout``.
        The caller's ``params`` dict is copied, never modified.

        :raises TypeError: if a signed header has a ``None`` value
        :raises errors.ConnectTimeout: if connecting to the endpoint times out
        :return: the :class:`requests.Response`; non-OK responses are first
            handed to :func:`errors.throw_if_parsable`
        """
        self.upload_survey_log()

        region_name = kwargs.pop("region_name", None)
        logger.debug("Start request.")
        logger.debug("%s: %s", method.upper(), url)
        if logger.getEffectiveLevel() <= logging.DEBUG:
            for k, v in kwargs.items():
                logger.debug("%s: %s", k, v)

        # Construct user agent without handling the letter case.
        headers = kwargs.get("headers") or {}
        headers = {k: str(v) for k, v in six.iteritems(headers)}
        headers["User-Agent"] = self._user_agent
        if self.namespace:
            headers["x-odps-namespace-id"] = self.namespace
        kwargs["headers"] = headers

        # Work on a copy: writing curr_project / curr_schema into the
        # caller's dict would leak them into later calls reusing that dict.
        params = kwargs.get("params")
        if params is None:
            params = {}
        elif isinstance(params, dict):
            params = params.copy()
        kwargs["params"] = params

        # Pop both keys so that neither is forwarded to requests.Request.
        actions = kwargs.pop("actions", None)
        action = kwargs.pop("action", None)
        actions = actions or action or []
        if isinstance(actions, six.string_types):
            actions = [actions]
        if actions:
            separator = "?" if "?" not in url else "&"
            url += separator + "&".join(actions)

        curr_project = kwargs.pop("curr_project", None) or self.project
        if "curr_project" not in params and curr_project is not None:
            params["curr_project"] = curr_project
        curr_schema = kwargs.pop("curr_schema", None) or self.schema
        if "curr_schema" not in params and curr_schema is not None:
            params["curr_schema"] = curr_schema

        timeout = kwargs.pop("timeout", None)

        req = requests.Request(method, url, **kwargs)
        prepared_req = req.prepare()
        logger.debug("request url + params %s", prepared_req.path_url)

        # Remove any existing auth headers so only the signers below set them.
        prepared_req.headers.pop("Authorization", None)
        prepared_req.headers.pop("application-authentication", None)
        self._account.sign_request(
            prepared_req, self._endpoint, region_name=region_name
        )
        if getattr(self, "_app_account", None) is not None:
            self._app_account.sign_request(
                prepared_req, self._endpoint, region_name=region_name
            )
        if any(v is None for v in prepared_req.headers.values()):
            none_headers = [k for k, v in prepared_req.headers.items() if v is None]
            raise TypeError(
                "Value of headers %s cannot be None" % ", ".join(none_headers)
            )

        try:
            res = self.session.send(
                prepared_req,
                stream=stream,
                timeout=timeout or (options.connect_timeout, options.read_timeout),
                verify=options.verify_ssl,
                proxies=self._proxy,
            )
        except ConnectTimeout:
            raise errors.ConnectTimeout(
                "Connecting to endpoint %s timeout." % self._endpoint
            )

        logger.debug("response.status_code %d", res.status_code)
        logger.debug("response.headers: \n%s", res.headers)
        if not stream:
            logger.debug("response.content: %s\n", res.content)
        # Automatically detect error
        if not self.is_ok(res):
            errors.throw_if_parsable(res, self._endpoint, self._tag)
        return res

    def get(self, url, stream=False, **kwargs):
        return self.request(url, "get", stream=stream, **kwargs)

    def post(self, url, data=None, **kwargs):
        """POST ``data``; text is encoded to UTF-8 bytes first."""
        data = (
            utils.to_binary(data, encoding="utf-8")
            if isinstance(data, six.string_types)
            else data
        )
        return self.request(url, "post", data=data, **kwargs)

    def put(self, url, data=None, **kwargs):
        """PUT ``data``; text is encoded to UTF-8 bytes first (same as ``post``)."""
        data = (
            utils.to_binary(data, encoding="utf-8")
            if isinstance(data, six.string_types)
            else data
        )
        return self.request(url, "put", data=data, **kwargs)

    def head(self, url, **kwargs):
        return self.request(url, "head", **kwargs)

    def delete(self, url, **kwargs):
        return self.request(url, "delete", **kwargs)

    def upload_survey_log(self):
        """Best-effort upload of pending survey calls to the project's ``logs``.

        Pending calls are always cleared first, so they are dropped when
        there is no project or when the upload fails. Failures are logged at
        debug level and never propagated to the caller.
        """
        try:
            from .models.core import RestModel

            survey = get_survey_calls()
            clear_survey_calls()
            if not survey:
                return
            if self.project is None:
                return
            url = "/".join(
                [self.endpoint, "projects", RestModel._encode(self.project), "logs"]
            )
            self.put(url, json.dumps(survey))
        except Exception:
            # Survey upload must never break the real request, but it must
            # not swallow KeyboardInterrupt / SystemExit either.
            logger.debug("Failed to upload survey log.", exc_info=True)

    # Misc helper methods
    def is_ok(self, resp):
        return resp.ok