pai/common/oss_utils.py
# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import glob
import os.path
import pathlib
import tarfile
import tempfile
from typing import Optional, Tuple, Union
from urllib.parse import parse_qs, urlparse
import oss2
from alibabacloud_credentials.client import Client as CredentialClient
from alibabacloud_credentials.models import Config as CredentialConfig
from oss2.credentials import Credentials, CredentialsProvider
from tqdm.autonotebook import tqdm
from .logging import get_logger
logger = get_logger(__name__)
class _ProgressCallbackTqdm(tqdm):
"""A tqdm progress bar that can be passed directly as an oss2 progress_callback."""
def __call__(self, consumed_bytes, total_bytes):
# oss2 reports the cumulative consumed bytes; convert it to an incremental update.
self.update(n=consumed_bytes - self.n)
def _upload_with_progress(
filename,
object_key,
oss_bucket: oss2.Bucket,
):
local_file_size = os.path.getsize(filename)
with _ProgressCallbackTqdm(
total=local_file_size,
unit="B",
unit_scale=True,
desc=f"Uploading file: {filename}",
) as pbar:
oss2.resumable_upload(
bucket=oss_bucket,
key=object_key,
filename=filename,
progress_callback=pbar,
num_threads=os.cpu_count() * 2,
)
# Mark the progress as completed.
pbar.update(n=pbar.total - pbar.n)
def _download_with_progress(
filename,
object_key,
oss_bucket: oss2.Bucket,
):
total = oss_bucket.get_object_meta(object_key).content_length
with _ProgressCallbackTqdm(
total=total,
unit="B",
unit_scale=True,
desc=f"Downloading file: {filename}",
) as pbar:
oss2.resumable_download(
bucket=oss_bucket,
key=object_key,
filename=filename,
progress_callback=pbar,
num_threads=os.cpu_count(),
)
# Mark the progress as completed.
pbar.update(n=pbar.total - pbar.n)
def is_oss_uri(uri: Union[str, bytes]) -> bool:
"""Determines whether the given uri is an OSS uri.
Args:
uri (Union[str, bytes]): A string in OSS URI schema:
oss://<bucket_name>[.endpoint]/<path/to/file>,
Returns:
bool: True if the given uri is an OSS uri, else False.
"""
return bool(uri and isinstance(uri, (str, bytes)) and str(uri).startswith("oss://"))
class OssUriObj(object):
"""A class that represents an OSS URI and provides some convenient methods."""
def __init__(self, uri: str):
"""Constructor for class OssUriObj.
Args:
uri (str): A string in OSS URI schema: oss://<bucket_name>[.endpoint]/<path/to/file>;
the endpoint in the URI is optional.
"""
if not uri.startswith("oss://"):
raise ValueError(
"Invalid OSS URI schema, please provide a string starts with 'oss://'"
)
bucket_name, object_key, endpoint, role_arn = self.parse(uri)
self.bucket_name = bucket_name
self.object_key = object_key
self.endpoint = endpoint
self.role_arn = role_arn
@classmethod
def from_bucket_key_endpoint(
cls, bucket_name: str, object_key: str, endpoint: Optional[str] = None
) -> "OssUriObj":
"""Initialize an OSSUri object from bucket_name, object_key and endpoint.
Args:
bucket_name (str): The name of the OSS bucket.
object_key (str): OSS object key/path.
endpoint (str, optional): Endpoint for the OSS bucket.
Returns:
OssUriObj: An OssUriObj instance representing the specified OSS object.
"""
# OSS object keys must not contain leading slashes.
# Document: https://help.aliyun.com/document_detail/273129.html
if object_key.startswith("/"):
logger.warning(
"OSS object key should not contain leading slashes, the leading"
" slashes will be removed."
)
object_key = object_key.lstrip("/")
if endpoint:
if endpoint.startswith("http://"):
endpoint = endpoint.lstrip("http://")
elif endpoint.startswith("https://"):
endpoint = endpoint.lstrip("https://")
uri = f"oss://{bucket_name}.{endpoint}/{object_key}"
else:
uri = f"oss://{bucket_name}/{object_key}"
return OssUriObj(uri=uri)
@classmethod
def parse(cls, oss_uri: str) -> Tuple[str, str, str, str]:
"""Parse OSS uri string and returns a tuple of (bucket_name, object_key,
endpoint, role_arn).
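Examples::
# illustrative values; the bucket name, endpoint and object key are placeholders
>>> OssUriObj.parse("oss://my-bucket/path/to/file")
('my-bucket', 'path/to/file', None, None)
>>> OssUriObj.parse("oss://my-bucket.oss-cn-hangzhou.aliyuncs.com/path/to/file")
('my-bucket', 'path/to/file', 'oss-cn-hangzhou.aliyuncs.com', None)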
Args:
oss_uri (str): A string in OSS Uri schema: oss://{bucket_name}.{endpoint}/{object_key}.
Returns:
Tuple: A tuple of (bucket_name, object_key, endpoint, role_arn).
"""
parsed_result = urlparse(oss_uri)
if parsed_result.scheme != "oss":
raise ValueError(
"require OSS uri('oss://[bucket_name]/[object_key]') but "
"given '{}'".format(oss_uri)
)
object_key = parsed_result.path
if object_key.startswith("/"):
object_key = object_key[1:]
query = parse_qs(parsed_result.query)
if "." in parsed_result.hostname:
bucket_name, endpoint = parsed_result.hostname.split(".", 1)
else:
bucket_name = parsed_result.hostname
# try to get OSS endpoint from url query.
if "endpoint" in query:
endpoint = query.get("endpoint")[0]
elif "host" in query:
endpoint = query.get("host")[0]
else:
endpoint = None
role_arn = query.get("role_arn")[0] if "role_arn" in query else None
return bucket_name, object_key, endpoint, role_arn
def get_uri_with_endpoint(self, endpoint: Optional[str] = None) -> str:
"""Get an OSS URI string that contains the endpoint.
Args:
endpoint (str): Endpoint of the OSS bucket.
Returns:
str: A string in OSS URI schema that contains the endpoint.
"""
if not endpoint and not self.endpoint:
raise ValueError("Unknown endpoint for the OSS bucket.")
return "oss://{bucket_name}.{endpoint}/{object_key}".format(
bucket_name=self.bucket_name,
endpoint=endpoint or self.endpoint,
object_key=self.object_key,
)
def get_dir_uri(self):
"""Returns directory in OSS uri string format of the original object."""
_, dirname, _ = self.parse_object_key()
dir_uri = f"oss://{self.bucket_name}{dirname}"
return dir_uri
@property
def uri(self) -> str:
"""Returns OSS uri in string format."""
return "oss://{bucket_name}/{object_key}".format(
bucket_name=self.bucket_name,
object_key=self.object_key,
)
def parse_object_key(self) -> Tuple[bool, str, str]:
"""Parse the OSS URI object key, returns a tuple of (is_dir, dir_path, file_name).
Returns:
Tuple: A tuple of (is_dir, dir_path, file_name).
"""
object_key = self.object_key.strip()
if object_key.endswith("/"):
is_dir, dir_path, file_name = True, os.path.join("/", object_key), None
else:
idx = object_key.rfind("/")
if idx < 0:
is_dir, dir_path, file_name = False, "/", object_key
else:
is_dir, dir_path, file_name = (
False,
os.path.join("/", object_key[: idx + 1]),
object_key[idx + 1 :],
)
return is_dir, dir_path, file_name
def _tar_file(source_file, target=None):
"""Create a gzip-compressed tarball from a local file or directory."""
source_file = (
source_file if os.path.isabs(source_file) else os.path.abspath(source_file)
)
if not os.path.exists(source_file):
raise ValueError("Source file does not exist: %s" % source_file)
if os.path.isdir(source_file):
arcname = ""
else:
arcname = os.path.basename(source_file)
if not target:
target = tempfile.mktemp()
with tarfile.open(target, "w:gz") as tar:
tar.add(name=source_file, arcname=arcname)
return target
def _get_bucket_and_path(
bucket: Optional[oss2.Bucket],
oss_path: Union[str, OssUriObj],
) -> Tuple[oss2.Bucket, str]:
from pai.session import get_default_session
sess = get_default_session()
if isinstance(oss_path, OssUriObj) or is_oss_uri(oss_path):
# If the parameter oss_path is an OssUriObj object, we need to use the
# corresponding bucket that OssUriObj instance belongs to for
# uploading/downloading the data.
if is_oss_uri(oss_path):
oss_path = OssUriObj(oss_path)
if sess.oss_bucket.bucket_name == oss_path.bucket_name:
bucket = sess.oss_bucket
else:
bucket = sess.get_oss_bucket(
oss_path.bucket_name, endpoint=oss_path.endpoint
)
oss_path = oss_path.object_key
elif not bucket:
bucket = sess.oss_bucket
return bucket, oss_path
def upload(
source_path: str,
oss_path: Union[str, OssUriObj],
bucket: Optional[oss2.Bucket] = None,
is_tar: Optional[bool] = False,
) -> str:
"""Upload local source file/directory to OSS.
Examples::
# compress and upload local directory `./src/` to OSS
>>> upload(source_path="./src/", oss_path="path/to/file",
... bucket=session.oss_bucket, is_tar=True)
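# upload a single local file without compression (illustrative paths); the file
# is stored as "path/to/data/train.csv" because the OSS path ends with "/"
>>> upload(source_path="./data/train.csv", oss_path="path/to/data/")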
Args:
source_path (str): Local path of the source to be uploaded; it can be
a single file or a directory.
oss_path (Union[str, OssUriObj]): Destination OSS path.
bucket (oss2.Bucket): OSS bucket used to store the uploaded data. If it is
not provided, the OSS bucket of the default session will be used.
is_tar (bool): Whether to compress the file before uploading (default: False).
Returns:
str: A string in OSS URI format. If the source_path is a directory, the
returned URI represents the directory for the uploaded data; otherwise
it points to the uploaded file.
"""
bucket, oss_path = _get_bucket_and_path(bucket, oss_path)
source_path_obj = pathlib.Path(source_path)
if not source_path_obj.exists():
raise RuntimeError("Source path is not exist: {}".format(source_path))
if is_tar:
# compress the local data and upload the compressed source data.
with tempfile.TemporaryDirectory() as dir_name:
temp_tar_path = _tar_file(
source_path, os.path.join(dir_name, "source.tar.gz")
)
dest_path = (
os.path.join(oss_path, os.path.basename(temp_tar_path))
if oss_path.endswith("/")
else oss_path
)
_upload_with_progress(
filename=temp_tar_path, object_key=dest_path, oss_bucket=bucket
)
return "oss://{}/{}".format(bucket.bucket_name, dest_path)
elif not source_path_obj.is_dir():
# If the source path is a single file, upload it directly.
# If oss_path ends with a slash, the file will be uploaded to
# "{oss_path}{filename}"; otherwise it will be uploaded to "{oss_path}".
dest_path = (
os.path.join(oss_path, os.path.basename(source_path))
if oss_path.endswith("/")
else oss_path
)
_upload_with_progress(
filename=source_path, object_key=dest_path, oss_bucket=bucket
)
return "oss://{}/{}".format(bucket.bucket_name, dest_path)
else:
# if the source path is a directory, upload all the files under the directory.
source_files = glob.glob(
pathname=str(source_path_obj / "**"),
recursive=True,
)
if not oss_path.endswith("/"):
oss_path += "/"
files = [f for f in source_files if not os.path.isdir(f)]
for file_path in files:
file_path_obj = pathlib.Path(file_path)
file_relative_path = file_path_obj.relative_to(source_path_obj).as_posix()
object_key = oss_path + file_relative_path
_upload_with_progress(
filename=file_path, object_key=object_key, oss_bucket=bucket
)
return "oss://{}/{}".format(bucket.bucket_name, oss_path)
def download(
oss_path: Union[str, OssUriObj],
local_path: str,
bucket: Optional[oss2.Bucket] = None,
un_tar=False,
):
"""Download OSS objects to local path.
Args:
oss_path (Union[str, OssUriObj]): Source OSS path, which could be a single
OSS object or an OSS directory.
local_path (str): Local path used to store the data from OSS.
bucket (oss2.Bucket, optional): OSS bucket that stores the data to be
downloaded. If it is not provided, the OSS bucket of the default session
will be used.
un_tar (bool, optional): Whether to decompress the downloaded data. It only
takes effect when `oss_path` points to a single file with the suffix ".tar.gz".
Returns:
str: A local file path for the downloaded data.
"""
bucket, oss_path = _get_bucket_and_path(bucket, oss_path)
if not bucket.object_exists(oss_path) or oss_path.endswith("/"):
# The `oss_path` represents a "directory" in the OSS bucket, download the
# objects which object key is prefixed with `oss_path`.
# Note: `un_tar` does not take effect when `oss_path` is a directory.
oss_path += "/" if not oss_path.endswith("/") else ""
iterator = oss2.ObjectIteratorV2(
bucket=bucket,
prefix=oss_path,
)
keys = [obj.key for obj in iterator if not obj.key.endswith("/")]
for key in tqdm(keys, desc=f"Downloading: {oss_path}"):
rel_path = os.path.relpath(key, oss_path)
dest = os.path.join(local_path, rel_path)
os.makedirs(os.path.dirname(dest), exist_ok=True)
_download_with_progress(
dest,
object_key=key,
oss_bucket=bucket,
)
return local_path
else:
# The `oss_path` represents a single file in OSS bucket.
if oss_path.endswith(".tar.gz") and un_tar:
# currently, only tar.gz format is supported for un_tar after downloading.
with tempfile.TemporaryDirectory() as temp_dir:
target_path = os.path.join(temp_dir, os.path.basename(oss_path))
_download_with_progress(
target_path,
object_key=oss_path,
oss_bucket=bucket,
)
with tarfile.open(name=target_path, mode="r") as t:
t.extractall(path=local_path)
return local_path
else:
os.makedirs(local_path, exist_ok=True)
dest = os.path.join(local_path, os.path.basename(oss_path))
_download_with_progress(
dest,
object_key=oss_path,
oss_bucket=bucket,
)
return dest
class CredentialProviderWrapper(CredentialsProvider):
"""A wrapper class for the credential provider of OSS."""
def __init__(self, config: Optional[CredentialConfig] = None):
self.client = CredentialClient(config)
def get_credentials(self) -> Credentials:
return Credentials(
access_key_id=self.client.get_access_key_id(),
access_key_secret=self.client.get_access_key_secret(),
security_token=self.client.get_security_token(),
)