# odps/accounts.py (455 lines of code) (raw):
# Copyright 1999-2025 Alibaba Group Holding Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A couple of authentication types in ODPS."""
import base64
import calendar
import hashlib
import hmac
import json
import logging
import os
import threading
import time
from collections import OrderedDict
from datetime import datetime
import requests
from . import options, utils
from .compat import cgi, datetime_utcnow, parse_qsl, six, unquote, urlparse
# Module-level logger used by the account classes below for debug/info records.
logger = logging.getLogger(__name__)
# Default lifetime, in hours, used for temporary credentials when the caller
# does not pass ``expired_hours``.
DEFAULT_TEMP_ACCOUNT_HOURS = 5
class BaseAccount(object):
    """Base class of account types; builds the canonical string to sign."""

    def _build_canonical_str(self, url_components, req):
        """
        Assemble the newline-joined string that gets signed for ``req``.

        The string consists of the request method, the values of the
        ``content-md5``, ``content-type`` and ``date`` entries plus
        ``name:value`` lines for ``x-odps-`` entries (all in sorted key
        order), and finally the resource path with its sorted query string.

        :param url_components: parsed URL providing ``path`` and ``query``
        :param req: request providing ``method`` and ``headers``; a ``Date``
            header is added to it when none is present
        :return: the canonical string
        """
        lines = [req.method]

        resource = url_components.path
        query_params = {}
        if url_components.query:
            pairs = parse_qsl(url_components.query, True)
            pairs.sort(key=lambda pair: pair[0])
            # every query parameter may appear only once
            assert len(pairs) == len(set(pair[0] for pair in pairs))
            query_params = dict(pairs)
            rendered = []
            for name, value in pairs:
                # parameters with blank values are rendered without "="
                rendered.append(name if value == "" else "=".join((name, value)))
            resource = "%s?%s" % (resource, "&".join(rendered))

        headers = req.headers
        logger.debug("headers before signing: %s", headers)

        collected = {}
        for name, value in six.iteritems(headers):
            lowered = name.lower()
            if lowered.startswith("x-odps") or lowered in (
                "content-type",
                "content-md5",
            ):
                collected[lowered] = value
        # both content headers always take part, blank when absent
        collected.setdefault("content-type", "")
        collected.setdefault("content-md5", "")

        date_str = headers.get("Date")
        if not date_str:
            date_str = utils.formatdate(usegmt=True)
            headers["Date"] = date_str
        collected["date"] = date_str

        # x-odps- entries passed via the query string are signed like headers
        for name, value in six.iteritems(query_params):
            if name.startswith("x-odps-"):
                collected[name] = value

        headers_to_sign = OrderedDict(
            (name, collected[name]) for name in sorted(collected)
        )
        logger.debug("headers to sign: %s", headers_to_sign)

        for name, value in six.iteritems(headers_to_sign):
            if name.startswith("x-odps-"):
                lines.append("%s:%s" % (name, value))
            else:
                lines.append(value)

        lines.append(resource)
        return "\n".join(lines)

    def sign_request(self, req, endpoint, region_name=None):
        """
        Sign ``req`` in place. Must be implemented by subclasses.

        :raises NotImplementedError: always, in this base class
        """
        raise NotImplementedError
class AliyunAccount(BaseAccount):
    """
    Account of aliyun.com

    Signs requests with an access id / secret access key pair. Without a
    region name the legacy HMAC-SHA1 signature is produced; with a region
    name the region-scoped "v4" signature is produced.
    """

    def __init__(self, access_id, secret_access_key):
        """
        :param access_id: access id written into the ``Authorization`` header
        :param secret_access_key: secret used as HMAC key material
        """
        self.access_id = access_id
        self.secret_access_key = secret_access_key
        # Cache of the last derived v4 signing key and the inputs it was
        # derived from (date, plus region and secret kept in ``_scope``).
        self._last_signature_date = None
        self._last_signature_scope = None
        self._last_signature_key = None

    def _get_v4_signature_key(self, date_str, region_name):
        """
        Return the v4 signing key for the given date and region.

        The key is a chain of HMAC-SHA256 digests over the date, the region,
        ``odps`` and ``aliyun_v4_request``, seeded with the secret access
        key. The last result is cached and reused only while the date, the
        region and the secret access key are all unchanged; subclasses may
        replace ``secret_access_key`` at any time, so the secret has to be
        part of the cache key as well.

        :param date_str: signing date formatted as ``%Y%m%d``
        :param region_name: region name the key is scoped to
        :return: signing key as bytes
        """
        secret = self.secret_access_key
        scope = (region_name, secret)
        if (
            self._last_signature_key is not None
            and date_str == self._last_signature_date
            # getattr keeps instances created without the scope attribute
            # (e.g. restored from an older serialized state) working
            and scope == getattr(self, "_last_signature_scope", None)
        ):
            return self._last_signature_key

        k_secret = utils.to_binary("aliyun_v4" + secret)
        k_date = hmac.new(k_secret, utils.to_binary(date_str), hashlib.sha256).digest()
        k_region = hmac.new(
            k_date, utils.to_binary(region_name), hashlib.sha256
        ).digest()
        k_service = hmac.new(k_region, b"odps", hashlib.sha256).digest()
        sign_key = hmac.new(
            k_service, b"aliyun_v4_request", hashlib.sha256
        ).digest()

        self._last_signature_date = date_str
        self._last_signature_scope = scope
        self._last_signature_key = sign_key
        return sign_key

    def calc_auth_str(self, canonical_str, region_name=None):
        """
        Compute the value of the ``Authorization`` header.

        :param canonical_str: canonical string of the request to sign
        :param region_name: when None the legacy signature keyed directly
            with the secret access key is used, otherwise the v4 signature
        :return: string of the form ``ODPS <credential>:<signature>``
        """
        if region_name is None:
            # use legacy v2 sign
            signature = base64.b64encode(
                hmac.new(
                    utils.to_binary(self.secret_access_key),
                    utils.to_binary(canonical_str),
                    hashlib.sha1,
                ).digest()
            )
            return "ODPS %s:%s" % (self.access_id, utils.to_str(signature))
        else:
            # use v4 sign
            date_str = datetime.strftime(datetime_utcnow(), "%Y%m%d")
            credential = "/".join(
                [self.access_id, date_str, region_name, "odps/aliyun_v4_request"]
            )
            sign_key = self._get_v4_signature_key(date_str, region_name)
            signature = base64.b64encode(
                hmac.new(
                    sign_key, utils.to_binary(canonical_str), hashlib.sha1
                ).digest()
            )
            return "ODPS %s:%s" % (credential, utils.to_str(signature))

    def sign_request(self, req, endpoint, region_name=None):
        """
        Sign ``req`` in place by setting its ``Authorization`` header.

        :param req: request to sign; its headers are modified
        :param endpoint: endpoint prefix stripped from ``req.url`` before
            the canonical string is built
        :param region_name: optional region name selecting the v4 signature
        """
        url = req.url[len(endpoint) :]
        url_components = urlparse(unquote(url), allow_fragments=False)

        canonical_str = self._build_canonical_str(url_components, req)
        logger.debug("canonical string: %s", canonical_str)

        req.headers["Authorization"] = self.calc_auth_str(canonical_str, region_name)
        logger.debug("headers after signing: %r", req.headers)
class AppAccount(BaseAccount):
    """
    Account for applications.

    Adds an ``application-authentication`` header computed from the
    ``Authorization`` header that is already present on the request.
    """

    def __init__(self, access_id, secret_access_key):
        """
        :param access_id: access id of the application
        :param secret_access_key: secret used as the HMAC-SHA1 key
        """
        self.access_id = access_id
        self.secret_access_key = secret_access_key

    def sign_request(self, req, endpoint, region_name=None):
        """
        Sign the existing ``Authorization`` header value of ``req``.

        :param req: request whose headers are read and extended
        :param endpoint: unused by this account type
        :param region_name: unused by this account type
        :raises KeyError: if ``req`` carries no ``Authorization`` header
        """
        request_auth = req.headers["Authorization"]
        digest = hmac.new(
            utils.to_binary(self.secret_access_key),
            utils.to_binary(request_auth),
            hashlib.sha1,
        ).digest()
        encoded = base64.b64encode(digest)

        template = "account_provider:%s,signature_method:%s,access_id:%s,signature:%s"
        values = ("aliyun", "hmac-sha1", self.access_id, utils.to_str(encoded))
        req.headers["application-authentication"] = template % values
        logger.debug("headers after app signing: %r", req.headers)
class SignServer(object):
    """
    Small HTTP server computing ``Authorization`` strings for clients.

    Secret access keys are looked up by access id in :attr:`accounts`, so
    that clients only need to know the access id (and the optional token).
    """

    class SignServerHandler(six.moves.BaseHTTPServer.BaseHTTPRequestHandler):
        """Request handler answering signing requests sent via POST."""

        def do_GET(self):
            """Reply with a fixed plain-text banner."""
            self.send_response(200)
            self.send_header("Content-type", "text/plain")
            self.end_headers()
            self.wfile.write(b"PyODPS Account Server")

        def do_POST(self):
            """Handle a signing request; unexpected failures yield HTTP 500."""
            try:
                self._do_POST()
            except Exception:
                logger.exception("Failed to sign request on SignServer.")
                self.send_response(500)
                self.end_headers()

        def _reject(self, status_code):
            """Send a body-less response with the given status code."""
            self.send_response(status_code)
            self.end_headers()

        def _do_POST(self):
            """Parse the posted form and sign it; other content types get 400."""
            ctype, pdict = cgi.parse_header(self.headers.get("content-type"))
            if ctype == "multipart/form-data":
                postvars = cgi.parse_multipart(self.rfile, pdict)
            elif ctype == "application/x-www-form-urlencoded":
                length = int(self.headers.get("content-length"))
                postvars = six.moves.urllib.parse.parse_qs(
                    self.rfile.read(length), keep_blank_values=1
                )
            else:
                self._reject(400)
                return
            self._sign(postvars)

        def _is_authorized(self):
            """
            Check the ``Authorization: token <value>`` header of the request.

            Always succeeds when the server has no token configured. A
            missing header, a scheme other than ``token`` or a header
            without a value is treated as unauthorized.
            """
            expected = self.server._token
            if expected is None:
                return True
            auth = self.headers.get("Authorization")
            if not auth:
                return False
            method, _, content = auth.partition(" ")
            if method.lower() != "token":
                return False
            # constant-time comparison avoids leaking the token via timing
            return hmac.compare_digest(
                utils.to_binary(content), utils.to_binary(expected)
            )

        def _sign(self, postvars):
            """
            Write the authorization string for the posted form fields.

            :param postvars: parsed form; ``access_id`` and ``canonical``
                must occur exactly once, ``region_name`` is optional
            :raises ValueError: if a mandatory field occurs more than once
            :raises KeyError: if a mandatory field or the account is missing
            """
            if not self._is_authorized():
                self._reject(401)
                return

            if len(postvars[b"access_id"]) != 1 or len(postvars[b"canonical"]) != 1:
                raise ValueError(
                    "access_id and canonical must be specified exactly once"
                )
            access_id = utils.to_str(postvars[b"access_id"][0])
            canonical = utils.to_str(postvars[b"canonical"][0])
            if b"region_name" not in postvars:
                region_name = None
            else:
                region_name = utils.to_str(postvars[b"region_name"][0])
            secret_access_key = self.server._accounts[access_id]

            account = AliyunAccount(access_id, secret_access_key)
            auth_str = account.calc_auth_str(canonical, region_name)

            self.send_response(200)
            self.send_header("Content-Type", "text/json")
            self.end_headers()
            self.wfile.write(utils.to_binary(auth_str))

        def log_message(self, *args):
            """Suppress the per-request log output of the base handler."""
            return

    class SignServerCore(
        six.moves.socketserver.ThreadingMixIn, six.moves.BaseHTTPServer.HTTPServer
    ):
        """Threaded HTTP server carrying the account mapping and token."""

        def __init__(self, *args, **kwargs):
            """
            :param accounts: keyword-only mapping from access id to secret
            :param token: keyword-only token clients must present, or None
            """
            self._accounts = kwargs.pop("accounts", {})
            self._token = kwargs.pop("token", None)
            self._ready = False
            six.moves.BaseHTTPServer.HTTPServer.__init__(self, *args, **kwargs)
            # set once the base constructor has returned
            self._ready = True

        def stop(self):
            """Stop the serving loop and close the server socket."""
            self.shutdown()
            self.server_close()

    def __init__(self, token=None):
        """
        :param token: optional token clients must send in the
            ``Authorization`` header
        """
        self._server = None
        self._accounts = dict()
        self._token = token

    @property
    def server(self):
        """Underlying server object, or None before :meth:`start`."""
        return self._server

    @property
    def accounts(self):
        """Mutable mapping from access id to secret access key."""
        return self._accounts

    @property
    def token(self):
        """Token required from clients, or None."""
        return self._token

    def start(self, endpoint):
        """
        Create the server for ``endpoint`` and serve in a daemon thread.

        The server object is constructed in the calling thread, so a
        construction failure is raised to the caller instead of leaving it
        waiting for a server that never becomes ready.

        :param endpoint: server address handed to the HTTP server constructor
        """
        self._server = self.SignServerCore(
            endpoint,
            self.SignServerHandler,
            accounts=self.accounts,
            token=self.token,
        )
        thread = threading.Thread(target=self._server.serve_forever)
        thread.daemon = True
        thread.start()

    def stop(self):
        """Stop the server; does nothing when it was never started."""
        if self._server is not None:
            self._server.stop()
class SignServerError(Exception):
    """
    Error carrying the status code and body of a failed signing response.
    """

    def __init__(self, msg, code, content):
        """
        :param msg: message passed on to :class:`Exception`
        :param code: status code stored as ``code``
        :param content: response body stored as ``content``
        """
        Exception.__init__(self, msg)
        self.content = content
        self.code = code
class SignServerAccount(BaseAccount):
    """
    Account that asks a sign server over HTTP for the authorization string.
    """

    # holder of the per-thread HTTP session
    _session_local = threading.local()

    def __init__(
        self, access_id, sign_endpoint=None, server=None, port=None, token=None
    ):
        """
        :param access_id: access id sent to the sign server
        :param sign_endpoint: ``(server, port)`` pair of the sign server;
            when falsy the pair is built from ``server`` and ``port``
        :param server: host of the sign server
        :param port: port of the sign server
        :param token: optional token sent in the ``Authorization`` header
        """
        self.access_id = access_id
        self.sign_endpoint = sign_endpoint or (server, port)
        self.token = token

    @property
    def session(self):
        """HTTP session of the current thread, created on first access."""
        if hasattr(type(self)._session_local, "_session"):
            return self._session_local._session

        adapter_options = dict(
            pool_connections=options.pool_connections,
            pool_maxsize=options.pool_maxsize,
            max_retries=options.retry_times,
        )
        new_session = requests.Session()
        # mount adapters with retry times
        for scheme in ("http://", "https://"):
            new_session.mount(scheme, requests.adapters.HTTPAdapter(**adapter_options))
        self._session_local._session = new_session
        return self._session_local._session

    def sign_request(self, req, endpoint, region_name=None):
        """
        Sign ``req`` in place with the string returned by the sign server.

        :param req: request to sign; its headers are modified
        :param endpoint: endpoint prefix stripped from ``req.url``
        :param region_name: optional region name forwarded to the server
        :raises SignServerError: if the server answers with status >= 400
        """
        relative_url = req.url[len(endpoint) :]
        parsed_url = urlparse(unquote(relative_url), allow_fragments=False)

        canonical_str = self._build_canonical_str(parsed_url, req)
        logger.debug("canonical string: %s", canonical_str)

        headers = {}
        if self.token:
            headers["Authorization"] = "token " + self.token

        sign_content = {"access_id": self.access_id, "canonical": canonical_str}
        if region_name is not None:
            sign_content["region_name"] = region_name

        resp = self.session.request(
            "post",
            "http://%s:%s" % self.sign_endpoint,
            headers=headers,
            data=sign_content,
        )
        if resp.status_code >= 400:
            try:
                resp_err = resp.text
                err_msg = resp_err
            except:  # fall back to the raw body when it cannot be decoded
                resp_err = resp.content
                err_msg = repr(resp_err)
            raise SignServerError(
                "Sign server returned error code: %d\n%s" % (resp.status_code, err_msg),
                resp.status_code,
                resp_err,
            )

        req.headers["Authorization"] = resp.text
        logger.debug("headers after signing: %r", req.headers)
class TempAccountMixin(object):
    """
    Mixin for accounts whose credentials expire and need to be reloaded.

    Subclasses implement :meth:`_is_account_valid` and
    :meth:`_reload_account`.
    """

    def __init__(self, expired_hours=DEFAULT_TEMP_ACCOUNT_HOURS):
        """
        :param expired_hours: hours after which credentials count as
            expired; None disables time-based expiration
        """
        self._last_refresh_time = time.time()
        if expired_hours is None:
            self._expire_seconds = None
            self._expire_time = None
        else:
            self._expire_seconds = expired_hours * 3600
            self._expire_time = self._last_refresh_time + self._expire_seconds
        self.reload()

    def _is_account_valid(self):
        """Tell whether the current credentials are usable."""
        raise NotImplementedError

    def _reload_account(self):
        """Reload credentials; may return their expiration timestamp."""
        raise NotImplementedError

    def _need_update(self):
        """Return True if credentials are invalid or past their expiration."""
        if not self._is_account_valid():
            return True
        if self._expire_time is None or self._expire_seconds is None:
            # time-based expiration is disabled
            return False
        deadline = min(
            self._expire_time, self._last_refresh_time + self._expire_seconds
        )
        return time.time() > deadline

    def reload(self, force=False):
        """
        Reload the credentials when needed.

        :param force: reload even if the credentials are still usable
        """
        now = time.time()
        if not (force or self._need_update()):
            return
        self._last_refresh_time = now
        fallback_expire = now + (
            self._expire_seconds or 3600 * DEFAULT_TEMP_ACCOUNT_HOURS
        )
        # prefer the expiration reported by the reload, if any
        self._expire_time = self._reload_account() or fallback_expire
class StsAccount(TempAccountMixin, AliyunAccount):
    """
    Account of sts

    Signs like :class:`AliyunAccount` and additionally sends the STS token.
    Credentials can be reloaded from the JSON file named by
    ``ODPS_STS_ACCOUNT_FILE`` or from the ``ODPS_STS_ACCESS_KEY_ID``,
    ``ODPS_STS_ACCESS_KEY_SECRET`` and ``ODPS_STS_TOKEN`` variables.
    """

    def __init__(
        self,
        access_id,
        secret_access_key,
        sts_token,
        expired_hours=DEFAULT_TEMP_ACCOUNT_HOURS,
    ):
        """
        :param access_id: access id of the temporary credential
        :param secret_access_key: secret access key of the credential
        :param sts_token: STS token; None triggers a reload on construction
        :param expired_hours: hours until a reload is attempted; None
            disables time-based expiration
        """
        self.sts_token = sts_token
        AliyunAccount.__init__(self, access_id, secret_access_key)
        TempAccountMixin.__init__(self, expired_hours=expired_hours)

    @classmethod
    def from_environments(cls):
        """
        Build an account from environment variables.

        :return: an account if ``ODPS_STS_ACCOUNT_FILE`` or
            ``ODPS_STS_TOKEN`` is set, otherwise None
        """
        expired_hours = int(
            os.getenv("ODPS_STS_TOKEN_HOURS", str(DEFAULT_TEMP_ACCOUNT_HOURS))
        )
        if "ODPS_STS_ACCOUNT_FILE" in os.environ or "ODPS_STS_TOKEN" in os.environ:
            if "ODPS_STS_ACCOUNT_FILE" not in os.environ:
                # without a token file, time-based expiration is disabled
                expired_hours = None
            return cls(None, None, None, expired_hours=expired_hours)
        return None

    def sign_request(self, req, endpoint, region_name=None):
        """
        Reload credentials if needed, then sign ``req`` in place.

        Besides the ``Authorization`` header, the STS token and the time
        of the last refresh are added as headers when they are set.
        """
        self.reload()
        super(StsAccount, self).sign_request(req, endpoint, region_name=region_name)
        if self.sts_token:
            req.headers["authorization-sts-token"] = self.sts_token
        if self._last_refresh_time:
            req.headers["x-pyodps-token-timestamp"] = str(self._last_refresh_time)

    def _is_account_valid(self):
        """The account is usable once an STS token is present."""
        return self.sts_token is not None

    def _resolve_expiration(self, exp_data):
        """
        Convert an expiration string into a POSIX timestamp.

        :param exp_data: string formatted as ``%Y-%m-%dT%H:%M:%SZ``
        :return: timestamp, or None if ``exp_data`` is None, cannot be
            parsed, or time-based expiration is disabled
        """
        if exp_data is None or self._expire_seconds is None:
            return None
        try:
            return calendar.timegm(time.strptime(exp_data, "%Y-%m-%dT%H:%M:%SZ"))
        except (TypeError, ValueError):
            # best effort: fall back to the default expiration
            return None

    def _log_token_reloaded(self):
        """Record a reload without writing the token itself to the log."""
        logger.info(
            "STS token reloaded: %s", "<hidden>" if self.sts_token else self.sts_token
        )

    def _reload_account(self):
        """
        Reload credentials from the token file or environment variables.

        :return: expiration timestamp read from the token file, or None
        """
        ts = None
        if "ODPS_STS_ACCOUNT_FILE" in os.environ:
            token_file_name = os.getenv("ODPS_STS_ACCOUNT_FILE")
            if token_file_name and os.path.exists(token_file_name):
                with open(token_file_name, "r") as token_file:
                    token_json = json.load(token_file)
                self.access_id = token_json["accessKeyId"]
                self.secret_access_key = token_json["accessKeySecret"]
                self.sts_token = token_json["securityToken"]
                ts = self._resolve_expiration(token_json.get("expiration"))
                self._log_token_reloaded()
        elif "ODPS_STS_ACCESS_KEY_ID" in os.environ:
            self.access_id = os.getenv("ODPS_STS_ACCESS_KEY_ID")
            self.secret_access_key = os.getenv("ODPS_STS_ACCESS_KEY_SECRET")
            self.sts_token = os.getenv("ODPS_STS_TOKEN")
            self._log_token_reloaded()
        return ts
class BearerTokenAccount(TempAccountMixin, BaseAccount):
    """
    Account authenticating with a bearer token sent as a request header.
    """

    def __init__(
        self,
        token=None,
        expired_hours=DEFAULT_TEMP_ACCOUNT_HOURS,
        get_bearer_token_fun=None,
    ):
        """
        :param token: bearer token; None triggers a reload on construction
        :param expired_hours: hours until a reload is attempted; None
            disables time-based expiration
        :param get_bearer_token_fun: optional callable without arguments
            returning a fresh token
        """
        self.token = token
        self._custom_bearer_token_func = get_bearer_token_fun
        TempAccountMixin.__init__(self, expired_hours=expired_hours)

    @classmethod
    def from_environments(cls):
        """
        Build an account from environment variables.

        :return: an account if ``ODPS_BEARER_TOKEN_FILE`` or
            ``ODPS_BEARER_TOKEN`` is set, otherwise None
        """
        expired_hours = int(
            os.getenv("ODPS_BEARER_TOKEN_HOURS", str(DEFAULT_TEMP_ACCOUNT_HOURS))
        )
        kwargs = {"expired_hours": expired_hours}
        if "ODPS_BEARER_TOKEN_FILE" in os.environ:
            return cls(**kwargs)
        elif "ODPS_BEARER_TOKEN" in os.environ:
            # a token given directly gets no time-based expiration
            kwargs["expired_hours"] = None
            return cls(os.environ["ODPS_BEARER_TOKEN"], **kwargs)
        return None

    def _get_bearer_token(self):
        """
        Obtain a fresh token.

        Sources are tried in order: the custom callable, the file named by
        ``ODPS_BEARER_TOKEN_FILE``, then the cupid runtime context.

        :return: the token, or None if the cupid context is not ready
        """
        if self._custom_bearer_token_func is not None:
            return self._custom_bearer_token_func()

        token_file_name = os.getenv("ODPS_BEARER_TOKEN_FILE")
        if token_file_name and os.path.exists(token_file_name):
            with open(token_file_name, "r") as token_file:
                return token_file.read().strip()
        else:  # pragma: no cover
            from cupid.runtime import RuntimeContext, context

            if not RuntimeContext.is_context_ready():
                return
            cupid_context = context()
            return cupid_context.get_bearer_token()

    def _is_account_valid(self):
        """The account is usable once a token is present."""
        return self.token is not None

    def _reload_account(self):
        """
        Fetch a new token and try to read its expiration.

        :return: the third comma-separated field of the base64-decoded
            token as an int, or None if the token cannot be decoded
        """
        token = self._get_bearer_token()
        # never write the token itself to the log
        logger.info("Bearer token reloaded: %s", "<hidden>" if token else token)
        self.token = token
        try:
            resolved_token_parts = base64.b64decode(token).decode().split(",")
            return int(resolved_token_parts[2])
        except (TypeError, ValueError, IndexError):
            # token is None or not in the expected format
            return None

    def sign_request(self, req, endpoint, region_name=None):
        """
        Reload the token if needed and attach it to ``req``.

        :raises TypeError: if no bearer token is available
        """
        self.reload()
        url = req.url[len(endpoint) :]
        url_components = urlparse(unquote(url), allow_fragments=False)
        # The canonical string is not used; the call is kept because it
        # adds a ``Date`` header to the request when none is present.
        self._build_canonical_str(url_components, req)
        if self.token is None:
            raise TypeError("Cannot sign request with None bearer token")
        req.headers["x-odps-bearer-token"] = self.token
        if self._last_refresh_time:
            req.headers["x-pyodps-token-timestamp"] = str(self._last_refresh_time)
        logger.debug("headers after signing: %r", req.headers)
class CredentialProviderAccount(StsAccount):
    """
    STS account whose credentials come from a credential provider object.

    Credentials are fetched from the provider before every signing.
    """

    def __init__(self, credential_provider):
        """
        :param credential_provider: object offering ``get_credential()``
            or ``get_credentials()``
        """
        self.provider = credential_provider
        super(CredentialProviderAccount, self).__init__(None, None, None)

    def _reload_account(self):
        """
        Do nothing on periodic reloads.

        The inherited implementation reads credentials from ``ODPS_STS_*``
        environment variables, which would overwrite those obtained from
        the provider.

        :return: None, so that the default expiration is applied
        """
        return None

    def _refresh_credential(self):
        """Fetch the current credential from the provider and store it."""
        try:
            credential = self.provider.get_credential()
        except Exception:
            # fall back to the alternative accessor name
            credential = self.provider.get_credentials()

        self.access_id = credential.get_access_key_id()
        self.secret_access_key = credential.get_access_key_secret()
        self.sts_token = credential.get_security_token()

    def sign_request(self, req, endpoint, region_name=None):
        """Refresh the credential (with retries), then sign ``req``."""
        utils.call_with_retry(self._refresh_credential)
        return super(CredentialProviderAccount, self).sign_request(
            req, endpoint, region_name=region_name
        )
def from_environments():
    """
    Create an account from environment variables.

    STS settings are consulted first, bearer token settings second.

    :return: the first account that could be created, or None
    """
    for candidate_cls in (StsAccount, BearerTokenAccount):
        created = candidate_cls.from_environments()
        if created is not None:
            return created
    return None