odps/rest.py (313 lines of code) (raw):

# Copyright 1999-2025 Alibaba Group Holding Ltd. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Restful client enhanced by URL building and request signing facilities. """ from __future__ import absolute_import import json import logging import os import platform import re import threading from string import Template import requests try: from requests import ConnectTimeout except ImportError: from requests import Timeout as ConnectTimeout try: import requests_unixsocket except ImportError: requests_unixsocket = None from . import __version__, errors, utils from .compat import six, urlparse from .config import options from .utils import clear_survey_calls, get_package_version, get_survey_calls try: import requests.packages.urllib3.util.ssl_ requests.packages.urllib3.util.ssl_.DEFAULT_CIPHERS = "ALL" requests.packages.urllib3.disable_warnings() except ImportError: pass try: import urllib3.util.ssl_ urllib3.util.ssl_.DEFAULT_CIPHERS = "ALL" urllib3.disable_warnings() except ImportError: pass try: from urllib3.util import Retry except ImportError: try: from requests.packages.urllib3.util import Retry except ImportError: Retry = None logger = logging.getLogger(__name__) _default_user_agent = None _v4_sign_fallback_msgs = [ "need ak v3 support", "accesskey acl denied", ] def default_user_agent(): global _default_user_agent if _default_user_agent is not None: return _default_user_agent py_implementation = platform.python_implementation() py_version = platform.python_version() try: py_system = platform.system() py_release = platform.release() except IOError: py_system = "Unknown" py_release = "Unknown" ua_template = Template( options.user_agent_pattern or os.getenv("PYODPS_USER_AGENT_PATTERN") or "$pyodps_version $mars_version $maxframe_version $python_version $os_version" ) substitutes = dict( pyodps_version="%s/%s" % ("pyodps", __version__), python_version="%s/%s" % (py_implementation, py_version), os_version="%s/%s" % (py_system, py_release), mars_version="", maxframe_version="", ) try: from mars import __version__ as mars_version except: mars_version = None if mars_version: substitutes["mars_version"] = "%s/%s" % ("mars", mars_version) try: maxframe_version = get_package_version("maxframe") except: maxframe_version = None if maxframe_version: substitutes["maxframe_version"] = "%s/%s" % ("maxframe", maxframe_version) _default_user_agent = ua_template.safe_substitute(**substitutes) _default_user_agent = re.sub(" +", " ", _default_user_agent).strip() try: from .internal.rest import get_internal_user_agent_suffix _default_user_agent += " " + get_internal_user_agent_suffix() except: pass return _default_user_agent class RestClient(object): _session_local = threading.local() _endpoints_without_v4_sign = set() def __init__( self, account, endpoint, project=None, schema=None, user_agent=None, region_name=None, namespace=None, **kwargs ): if endpoint.endswith("/"): endpoint = endpoint[:-1] self._account = account self._endpoint = endpoint self._region_name = region_name self._user_agent = user_agent or default_user_agent() self.project = project self.schema = schema self.namespace = namespace self._proxy = kwargs.get("proxy") self._app_account = kwargs.get("app_account") self._tag = kwargs.get("tag") if isinstance(self._proxy, six.string_types): self._proxy = dict(http=self._proxy, https=self._proxy) @property def endpoint(self): return self._endpoint @property def account(self): return self._account @property def app_account(self): return self._app_account @property def region_name(self): return self._region_name @property def session(self): try: session_cache = type(self)._session_local.session_cache except AttributeError: session_cache = type(self)._session_local.session_cache = dict() try: return session_cache[self._endpoint] except KeyError: pass try: retries = Retry( total=options.retry_times, backoff_factor=0.1, allowed_methods={"DELETE", "GET", "HEAD"}, status_forcelist=[502, 503, 504], ) except: retries = options.retry_times parsed_url = urlparse(self._endpoint) adapter_options = dict( pool_connections=options.pool_connections, pool_maxsize=options.pool_maxsize, max_retries=retries, ) if parsed_url.scheme == "http+unix": session = requests_unixsocket.Session() session.mount( "http+unix://", requests_unixsocket.adapters.UnixAdapter(**adapter_options), ) else: session = requests.Session() # mount adapters with retry times session.mount("http://", requests.adapters.HTTPAdapter(**adapter_options)) session.mount("https://", requests.adapters.HTTPAdapter(**adapter_options)) session_cache[self._endpoint] = session return session def request(self, url, method, stream=False, **kwargs): sign_region_name = kwargs.get("region_name") or self._region_name if ( self._endpoint in self._endpoints_without_v4_sign or not options.enable_v4_sign ): sign_region_name = None auth_expire_retried = False while True: kwargs["region_name"] = sign_region_name try: return self._request(url, method, stream=stream, **kwargs) except errors.InternalServerError as ex: ex_msg = str(ex).lower() if sign_region_name is None or all( msg not in ex_msg for msg in _v4_sign_fallback_msgs ): raise logger.info( "Fallback of V4 signature for %s. Error message: %s", url, ex ) self._endpoints_without_v4_sign.add(self._endpoint) sign_region_name = None except errors.InvalidParameter as ex: if sign_region_name is None or "ODPS-0410051" not in str(ex): # Invalid credentials error not received from server raise logger.info( "Fallback of V4 signature for %s. Error message: %s", url, ex ) self._endpoints_without_v4_sign.add(self._endpoint) sign_region_name = None except errors.AuthorizationRequired as ex: if sign_region_name is None or "invalid or missing" not in str(ex): raise logger.info( "Fallback of V4 signature for %s. Error message: %s", url, ex ) self._endpoints_without_v4_sign.add(self._endpoint) sign_region_name = None except errors.AuthenticationRequestExpired: if not hasattr(self.account, "reload") or auth_expire_retried: raise logger.info( "AuthenticationRequestExpired encountered with %r. " "Will retry with reloaded account.", self.account, ) self.account.reload(True) auth_expire_retried = True def _request(self, url, method, stream=False, **kwargs): self.upload_survey_log() region_name = kwargs.pop("region_name", None) logger.debug("Start request.") logger.debug("%s: %s", method.upper(), url) if logger.getEffectiveLevel() <= logging.DEBUG: for k, v in kwargs.items(): logger.debug("%s: %s", k, v) # Construct user agent without handling the letter case. headers = kwargs.get("headers", {}) headers = {k: str(v) for k, v in six.iteritems(headers)} headers["User-Agent"] = self._user_agent if self.namespace: headers["x-odps-namespace-id"] = self.namespace kwargs["headers"] = headers params = kwargs.setdefault("params", {}) actions = kwargs.pop("actions", None) or kwargs.pop("action", None) or [] if isinstance(actions, six.string_types): actions = [actions] if actions: separator = "?" if "?" not in url else "&" url += separator + "&".join(actions) curr_project = kwargs.pop("curr_project", None) or self.project if "curr_project" not in params and curr_project is not None: params["curr_project"] = curr_project curr_schema = kwargs.pop("curr_schema", None) or self.schema if "curr_schema" not in params and curr_schema is not None: params["curr_schema"] = curr_schema timeout = kwargs.pop("timeout", None) req = requests.Request(method, url, **kwargs) prepared_req = req.prepare() logger.debug("request url + params %s", prepared_req.path_url) prepared_req.headers.pop("Authorization", None) prepared_req.headers.pop("application-authentication", None) self._account.sign_request( prepared_req, self._endpoint, region_name=region_name ) if getattr(self, "_app_account", None) is not None: self._app_account.sign_request( prepared_req, self._endpoint, region_name=region_name ) if any(v is None for v in prepared_req.headers.values()): none_headers = [k for k, v in prepared_req.headers.items() if v is None] raise TypeError( "Value of headers %s cannot be None" % ", ".join(none_headers) ) try: res = self.session.send( prepared_req, stream=stream, timeout=timeout or (options.connect_timeout, options.read_timeout), verify=options.verify_ssl, proxies=self._proxy, ) except ConnectTimeout: raise errors.ConnectTimeout( "Connecting to endpoint %s timeout." % self._endpoint ) logger.debug("response.status_code %d", res.status_code) logger.debug("response.headers: \n%s", res.headers) if not stream: logger.debug("response.content: %s\n", res.content) # Automatically detect error if not self.is_ok(res): errors.throw_if_parsable(res, self._endpoint, self._tag) return res def get(self, url, stream=False, **kwargs): return self.request(url, "get", stream=stream, **kwargs) def post(self, url, data=None, **kwargs): data = ( utils.to_binary(data, encoding="utf-8") if isinstance(data, six.string_types) else data ) return self.request(url, "post", data=data, **kwargs) def put(self, url, data=None, **kwargs): data = utils.to_binary(data) if isinstance(data, six.string_types) else data return self.request(url, "put", data=data, **kwargs) def head(self, url, **kwargs): return self.request(url, "head", **kwargs) def delete(self, url, **kwargs): return self.request(url, "delete", **kwargs) def upload_survey_log(self): try: from .models.core import RestModel survey = get_survey_calls() clear_survey_calls() if not survey: return if self.project is None: return url = "/".join( [self.endpoint, "projects", RestModel._encode(self.project), "logs"] ) self.put(url, json.dumps(survey)) except: pass # Misc helper methods def is_ok(self, resp): return resp.ok