pai/common/utils.py (178 lines of code) (raw):
# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import functools
import importlib.util
import random
import re
import socket
import string
import sys
import time
import warnings
from datetime import datetime
from functools import lru_cache
from typing import Callable, Dict, List, Optional, Union
from semantic_version import Version
from pai.common.consts import (
INSTANCE_TYPE_LOCAL,
INSTANCE_TYPE_LOCAL_GPU,
FileSystemInputScheme,
)
from pai.version import VERSION
# Characters kept unchanged by ``to_plain_text`` by default: ASCII letters,
# decimal digits and the underscore.
DEFAULT_PLAIN_TEXT_ALLOW_CHARACTERS = "".join(
    (string.ascii_letters, string.digits, "_")
)
def is_iterable(arg):
    """Return True if ``iter(arg)`` succeeds, False if it raises TypeError.

    Args:
        arg: Any object.

    Returns:
        bool: Whether ``arg`` supports the iteration protocol.
    """
    try:
        iter(arg)
    except TypeError:
        return False
    return True
def random_str(n):
    """Build a random string of lowercase ASCII letters and digits.

    Uses the module-level ``random`` generator (not cryptographically secure).

    Args:
        n: Number of characters to generate.

    Returns:
        str: The generated string (empty when ``n`` <= 0).
    """
    alphabet = string.ascii_lowercase + string.digits
    picked = [random.choice(alphabet) for _ in range(n)]
    return "".join(picked)
def camel_to_snake(name):
    """Convert a CamelCase / camelCase name to snake_case.

    Two passes insert underscores: the first before each capitalized word,
    the second between a lowercase letter/digit and a following capital
    (this splits acronym runs such as ``HTTPResponse``).
    """
    with_word_breaks = re.sub(r"(.)([A-Z][a-z]+)", r"\1_\2", name)
    with_all_breaks = re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", with_word_breaks)
    return with_all_breaks.lower()
def snake_to_camel(name):
    """Convert a snake_case name to CamelCase by title-casing each part."""
    return "".join(map(str.title, name.split("_")))
def make_list_resource_iterator(method: Callable, **kwargs):
    """Iterate over every item exposed by a paginated list method.

    ``method`` is called repeatedly with ``page_number`` / ``page_size``
    injected into ``kwargs`` (starting from the caller-supplied values, or
    1 and 10). Iteration stops when a page is empty, when a page has fewer
    than ``page_size`` items, or, if ``method`` returns a ``PaginatedResult``,
    once ``page_number * page_size`` reaches its ``total_count``.

    Args:
        method: Resource list method.
        **kwargs: Arguments forwarded to ``method`` on every call.

    Yields:
        Items from each page in order.
    """
    from pai.api.base import PaginatedResult

    current_page = kwargs.get("page_number", 1)
    size = kwargs.get("page_size", 10)
    while True:
        kwargs.update(page_number=current_page, page_size=size)
        page = method(**kwargs)

        known_total = None
        if isinstance(page, PaginatedResult):
            known_total = page.total_count
            page = page.items

        for entry in page:
            yield entry

        fetched = len(page)
        # A short (or empty) page means there is nothing further to fetch.
        if fetched == 0 or fetched < size:
            return
        if known_total and current_page * size >= known_total:
            return
        current_page += 1
def to_plain_text(
    input_str: str, allowed_characters=DEFAULT_PLAIN_TEXT_ALLOW_CHARACTERS, repl_ch="_"
):
    """Replace every character of ``input_str`` not in ``allowed_characters``.

    Args:
        input_str: Text to sanitize.
        allowed_characters: Collection of characters kept unchanged.
        repl_ch: Replacement for each disallowed character.

    Returns:
        str: The sanitized text.
    """

    def _keep_or_replace(ch):
        return ch if ch in allowed_characters else repl_ch

    return "".join(map(_keep_or_replace, input_str))
def http_user_agent(user_agent: Optional[Union[Dict, str]] = None) -> str:
    """Build the HTTP User-Agent string identifying this SDK client.

    The base value is ``pai-python-sdk/<VERSION>; python/<interpreter version>``.
    A dict ``user_agent`` appends ``; key/value`` tokens; a str is appended
    after ``"; "`` verbatim. Any other value leaves the base unchanged.

    Args:
        user_agent: Optional extra User-Agent tokens.

    Returns:
        str: The User-Agent header value.
    """
    base = f"pai-python-sdk/{VERSION}; python/{sys.version.split()[0]}"
    if isinstance(user_agent, dict):
        extra = "; ".join(f"{k}/{v}" for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        extra = user_agent
    else:
        return base
    return base + "; " + extra
def is_notebook() -> bool:
    """Return True if running inside an IPython ``ZMQInteractiveShell``.

    The class of the object returned by ``IPython.get_ipython()`` is inspected,
    including its base classes, for one named ``ZMQInteractiveShell``.

    Returns:
        bool: True for a ZMQ-based interactive shell (e.g. a notebook kernel);
        False when IPython is unavailable or no such shell is active.
    """
    try:
        from IPython import get_ipython

        # Inspect the shell *class* (not its name string) so that ``__mro__``
        # exists. If get_ipython() returns None, this is NoneType and the
        # check below simply yields False.
        shell_cls = get_ipython().__class__
        return any(
            parent_cls.__name__ == "ZMQInteractiveShell"
            for parent_cls in shell_cls.__mro__
        )
    except (NameError, ImportError):
        return False
def is_local_run_instance_type(instance_type: str) -> bool:
    """Return True if instance_type is a local-run instance type.

    Args:
        instance_type: Instance type string. Surrounding whitespace is ignored;
            ``None`` or an empty string is treated as not local.

    Returns:
        bool: True if the stripped value equals ``INSTANCE_TYPE_LOCAL`` or
        ``INSTANCE_TYPE_LOCAL_GPU``, otherwise False.
    """
    # Explicit guard so falsy inputs yield False instead of leaking the
    # falsy input itself (None / "") through an ``and`` expression.
    if not instance_type:
        return False
    return instance_type.strip() in (INSTANCE_TYPE_LOCAL_GPU, INSTANCE_TYPE_LOCAL)
def generate_repr(repr_obj, *attr_names: str, **kwargs) -> str:
    """Build a ``ClassName(key=value, ...)`` string for ``repr_obj``.

    Attribute values are read with ``getattr`` in the given order. Keyword
    arguments are then merged in; a keyword with the same name as an attribute
    overrides its value but keeps the attribute's position.

    Args:
        repr_obj: The object to describe.
        attr_names: Attribute names to include.
        **kwargs: Extra key/value pairs to include.

    Returns:
        str: The formatted representation.
    """
    fields = {}
    for name in attr_names:
        fields[name] = getattr(repr_obj, name)
    fields.update(kwargs)

    rendered = ", ".join("{}={}".format(key, value) for key, value in fields.items())
    class_name = repr_obj.__class__.__name__
    return f"{class_name}({rendered})"
def to_semantic_version(version_str: str) -> Version:
    """Parse ``version_str`` into a semantic version, tolerating bad input.

    ``Version.coerce`` is used for parsing. If it raises ``ValueError``, the
    result falls back to ``0.0.0``.

    Args:
        version_str[str]: Version string, such as '1.0.0', '1.0.0-rc1', '1.0.0+build.1'.

    Returns:
        :class:`semantic_version.Version`: The parsed (or fallback) version.
    """
    try:
        parsed = Version.coerce(version_str)
    except ValueError:
        # Not a usable version string: fall back to the zero version.
        parsed = Version.coerce("0.0.0")
    return parsed
def is_odps_table_uri(uri: str) -> bool:
    """Return True if uri is an ODPS table input URI.

    A valid URI has the form ``odps://<project_name>/tables/<table_name>``,
    where both the project name and the table name are non-empty. Anything
    after the table name (further ``/``-separated segments) is kept as part
    of the table component and does not affect the result.

    Args:
        uri (str): URI of input table, such as 'odps://<project_name>/tables/<table_name>'.

    Returns:
        bool: True if ``uri`` matches the pattern above.

    Examples:
        >>> is_odps_table_uri('odps://my_project/tables/my_table')
        True
        >>> is_odps_table_uri('odps://my_project/tables/')
        False
    """
    scheme = "odps://"
    if not uri.startswith(scheme):
        return False
    # maxsplit=2 keeps any trailing path segments inside the table component.
    parts = uri[len(scheme):].split("/", 2)
    if len(parts) != 3:
        return False
    project, section, table = parts
    return bool(project) and section == "tables" and bool(table)
def is_filesystem_uri(uri: str) -> bool:
    """Return True if uri starts with ``<scheme>://`` for a known file-system scheme.

    The known schemes are the values of the non-dunder attributes of
    ``FileSystemInputScheme``.

    Args:
        uri (str): URI of input NAS, such as 'nas://<FileSystemId>/path/to/data/directory/'.

    Examples:
        # Standard or Extreme file system type
        >>> is_filesystem_uri('nas://<FileSystemId>/path/to/data/directory/')
        True
        # CPFS file system type
        >>> is_filesystem_uri('cpfs://<FileSystemId>/<ProtocolServiceId>/<ExportId>')
        True
    """
    known_schemes = {
        value
        for key, value in FileSystemInputScheme.__dict__.items()
        if not key.startswith("__")
    }
    for scheme in known_schemes:
        if uri.startswith(f"{scheme}://"):
            return True
    return False
def is_dataset_id(item: str) -> bool:
    """Return True if ``item`` looks like a dataset ID (starts with ``d-``).

    A dataset ID may optionally be followed by a slash and a version.

    Args:
        item (str): User input dataset ID, or dataset ID and version separated by a slash.

    Examples:
        >>> is_dataset_id('d-ybko3rap60c4gs9flc')
        True
        >>> is_dataset_id('d-ybko3rap60c4gs9flc/v1')
        True
    """
    dataset_id_prefix = "d-"
    return item.startswith(dataset_id_prefix)
def is_nas_uri(uri: Union[str, bytes]) -> bool:
    """Determine whether the given uri is a NAS uri.

    Args:
        uri (Union[str, bytes]): A string (or bytes) in NAS URI schema:
            nas://29**d-b12****446.cn-hangzhou.nas.aliyuncs.com/data/path/
            nas://29****123-y**r.cn-hangzhou.extreme.nas.aliyuncs.com/data/path/

    Returns:
        bool: True if the given uri starts with ``nas://``, else False. Empty
        values and values that are neither str nor bytes yield False.
    """
    if not uri:
        return False
    if isinstance(uri, bytes):
        # str(b"nas://...") is "b'nas://...'", so compare as bytes instead.
        return uri.startswith(b"nas://")
    if isinstance(uri, str):
        return uri.startswith("nas://")
    return False
@lru_cache()
def is_domain_connectable(domain: str, port: int = 80, timeout: int = 1) -> bool:
    """Check whether a TCP connection to ``domain:port`` can be established.

    Results are memoized with ``functools.lru_cache`` (default size), so both
    positive and negative outcomes for a given argument combination are reused
    until evicted.

    Args:
        domain: Host name or IP address to connect to.
        port: TCP port to connect to (default 80).
        timeout: Timeout in seconds applied to the connection attempts. It does
            not bound host-name resolution.

    Returns:
        bool: True if a connection succeeded, False on any socket-level error.
    """
    try:
        # create_connection resolves via getaddrinfo (IPv4 and IPv6), tries each
        # resolved address in turn, and the ``with`` closes the socket on exit.
        with socket.create_connection((domain, port), timeout=timeout):
            return True
    except OSError:
        # socket.timeout, socket.gaierror and ConnectionRefusedError are all
        # OSError subclasses.
        return False
def experimental(callable_entity):
    """Mark a function or class as experimental.

    Every call to the decorated object emits a ``FutureWarning`` attributed to
    the caller (``stacklevel=2``) and then delegates to the original object.

    NOTE(review): the wrapper is a plain function, so decorating a class makes
    the decorated name a function (``isinstance`` checks against it will fail).
    Confirm this is acceptable for class usage.

    Args:
        callable_entity: The function or class being decorated.

    Returns:
        A ``functools.wraps``-decorated wrapper forwarding all arguments.
    """

    @functools.wraps(callable_entity)
    def _experimental_wrapper(*args, **kwargs):
        warnings.warn(
            f"{callable_entity.__name__} is experimental and may change or be removed in future releases.",
            category=FutureWarning,
            stacklevel=2,
        )
        return callable_entity(*args, **kwargs)

    return _experimental_wrapper
def retry(max_attempts=3, wait_secs=1, exceptions=(Exception,), report_retries=True):
    """Decorator factory that retries a function when it raises.

    Args:
        max_attempts: Total number of calls to make, including the first one,
            before giving up. Must be at least 1.
        wait_secs: Seconds to sleep between consecutive attempts.
        exceptions: Exception class or tuple of classes that trigger a retry.
            Any other exception propagates immediately.
        report_retries: If True, call ``warnings.warn`` for every failed attempt
            that will be retried.

    Returns:
        A decorator that wraps the target function with the retry logic.

    Raises:
        ValueError: If ``max_attempts`` is less than 1. Such a value would
            otherwise make the wrapper return None without ever calling the function.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be at least 1, got {max_attempts}")

    def decorator_retry(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            attempts = 0
            while True:
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    attempts += 1
                    # ``>=`` guarantees termination with a re-raise even for a
                    # non-integral max_attempts, instead of falling through.
                    if attempts >= max_attempts:
                        raise  # Re-raise the last exception once all attempts failed.
                    if report_retries:
                        warnings.warn(f"Retry {attempts}/{max_attempts} failed: {e}")
                    time.sleep(wait_secs)

        return wrapper

    return decorator_retry
def print_table(headers: List[str], rows: List[List[str]]):
    """Print ``headers`` and ``rows`` to stdout as a left-aligned text table.

    Each column is as wide as its longest ``str()`` value (header included).
    Cells are joined with ``" | "``, and a dashed rule the width of the header
    line separates the header from the rows.

    Raises:
        ValueError: If any row's length differs from the number of headers.
    """
    expected = len(headers)
    if any(len(r) != expected for r in rows):
        raise ValueError("Unable to print table, headers length mismatch with rows")

    widths = [max(len(str(cell)) for cell in col) for col in zip(headers, *rows)]

    header_line = " | ".join(f"{h:<{w}}" for h, w in zip(headers, widths))
    print(header_line)
    print("-" * len(header_line))
    for r in rows:
        print(" | ".join(f"{str(cell):<{w}}" for cell, w in zip(r, widths)))
def is_package_available(package_name: str) -> bool:
    """Check if the package is available in the current environment.

    Args:
        package_name: Importable module/package name, optionally dotted. For a
            dotted name, ``importlib.util.find_spec`` imports the parent package.

    Returns:
        bool: True if a module spec can be found (or the module is already
        loaded), otherwise False.
    """
    try:
        return importlib.util.find_spec(package_name) is not None
    except ImportError:
        # A dotted name whose parent package is missing raises
        # ModuleNotFoundError (an ImportError subclass) instead of returning None.
        return False
    except ValueError:
        # Raised when the module is in sys.modules but its __spec__ is None;
        # it is loaded, so report it as available in that case.
        return package_name in sys.modules
def timestamp(sep: str = "-", utc: bool = False) -> str:
    """Return a timestamp with millisecond precision.

    The default layout is ``YYYYMMDD-HHMMSS-mmm``.

    Args:
        sep: Separator placed between the date, time and millisecond parts.
        utc: Whether to use UTC time instead of local time.

    Returns:
        str: A timestamp with millisecond precision.
    """
    # Local import: the module only imports ``datetime`` from this package.
    from datetime import timezone

    # datetime.utcnow() is deprecated since Python 3.12; an aware UTC "now"
    # formats identically with the pattern below.
    now = datetime.now(timezone.utc) if utc else datetime.now()
    # %f renders microseconds; dropping the last 3 digits keeps milliseconds.
    res = now.strftime("%Y%m%d-%H%M%S-%f")[:-3]
    if sep != "-":
        res = res.replace("-", sep)
    return res
def name_from_base(base_name: str, sep: str = "-") -> str:
    """Append a local-time, millisecond-precision timestamp to ``base_name``.

    The same ``sep`` joins the base name to the timestamp and separates the
    timestamp's own parts.

    Args:
        base_name: Prefix of the generated name.
        sep: Separator between base_name and timestamp (and inside the timestamp).

    Returns:
        str: A name with base_name and timestamp.
    """
    stamp = timestamp(sep=sep, utc=False)
    template = "{base_name}{sep}{timestamp}"
    return template.format(base_name=base_name, sep=sep, timestamp=stamp)