pai/common/utils.py (178 lines of code) (raw):

# Copyright 2023 Alibaba, Inc. or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import import functools import importlib.util import random import re import socket import string import sys import time import warnings from datetime import datetime from functools import lru_cache from typing import Callable, Dict, List, Optional, Union from semantic_version import Version from pai.common.consts import ( INSTANCE_TYPE_LOCAL, INSTANCE_TYPE_LOCAL_GPU, FileSystemInputScheme, ) from pai.version import VERSION DEFAULT_PLAIN_TEXT_ALLOW_CHARACTERS = string.ascii_letters + string.digits + "_" def is_iterable(arg): try: _ = iter(arg) return True except TypeError: return False def random_str(n): """Random string generation with lower case letters and digits. Args: n: Size of generated random string. Returns: str: generated random string. """ return "".join( random.choice(string.ascii_lowercase + string.digits) for _ in range(n) ) def camel_to_snake(name): """Convert a name from camel case to snake case.""" name = re.sub("(.)([A-Z][a-z]+)", r"\1_\2", name) return re.sub("([a-z0-9])([A-Z])", r"\1_\2", name).lower() def snake_to_camel(name): """Convert a name from snake case to camel case.""" return "".join([w.title() for w in name.split("_")]) def make_list_resource_iterator(method: Callable, **kwargs): """Wrap resource list method as an iterator. Args: method: Resource List method. **kwargs: arguments for the method. Yields: A resource iterator. """ from pai.api.base import PaginatedResult page_number = kwargs.get("page_number", 1) page_size = kwargs.get("page_size", 10) while True: kwargs.update(page_number=page_number, page_size=page_size) result = method(**kwargs) if isinstance(result, PaginatedResult): total_count = result.total_count result = result.items else: total_count = None for item in result: yield item if len(result) == 0 or len(result) < page_size: return if total_count and page_number * page_size >= total_count: return page_number += 1 def to_plain_text( input_str: str, allowed_characters=DEFAULT_PLAIN_TEXT_ALLOW_CHARACTERS, repl_ch="_" ): """Replace characters in input_str if it is not in allowed_characters.""" return "".join([c if c in allowed_characters else repl_ch for c in input_str]) def http_user_agent(user_agent: Optional[Union[Dict, str]] = None) -> str: """Generate HTTP User-Agent that represents current client.""" ua = f"pai-python-sdk/{VERSION}; python/{sys.version.split()[0]}" if isinstance(user_agent, dict): ua += "; " + "; ".join(f"{k}/{v}" for k, v in user_agent.items()) elif isinstance(user_agent, str): ua += "; " + user_agent return ua def is_notebook() -> bool: """Return True if current environment is notebook.""" try: from IPython import get_ipython shell = get_ipython().__class__.__name__ for parent_cls in shell.__mro__: if parent_cls.__name__ == "ZMQInteractiveShell": return True return False except (NameError, ImportError): return False def is_local_run_instance_type(instance_type: str) -> bool: """Return True if instance_type is local run instance type.""" return instance_type and instance_type.strip() in [ INSTANCE_TYPE_LOCAL_GPU, INSTANCE_TYPE_LOCAL, ] def generate_repr(repr_obj, *attr_names: str, **kwargs) -> str: """Generate a string representation of the given object. Args: repr_obj: The object used to generate the string representation. attr_names: A list of attribute names to include in the string representation. Returns: str: A string representation of the object. """ attrs = {name: getattr(repr_obj, name) for name in attr_names} attrs.update(kwargs) attr_repr = ", ".join(["{}={}".format(k, v) for k, v in attrs.items()]) cls_name = repr_obj.__class__.__name__ return f"{cls_name}({attr_repr})" def to_semantic_version(version_str: str) -> Version: """Convert version_str to semantic version. Convert version_str to semantic version, if version_str is not a valid semantic version, return '0.0.0'. Args: version_str[str]: Version string, such as '1.0.0', '1.0.0-rc1', '1.0.0+build.1'. Returns: :class:`semantic_version.Version`: Semantic version. """ try: return Version.coerce(version_str) except ValueError: # if version_str is not a valid semantic version, return '0.0.0' return Version.coerce("0.0.0") def is_odps_table_uri(uri: str) -> bool: """Return True if uri is an odps table input URI. Args: uri (str): URI of input table, such as 'odps://<project_name>/tables/<table_name>'. Examples: >>> is_odps_table_uri('odps://<project_name>/tables/<table_name>') True """ if not uri.startswith("odps://"): return False info = uri[7:].split("/", 2) if len(info) != 3: return False return info[1] == "tables" def is_filesystem_uri(uri: str) -> bool: """Return True if uri is a filesystem input URI. Args: uri (str): URI of input NAS, such as 'nas://<FileSystemId>/path/to/data/directory/'. Examples: # Standard or Extreme file system type >>> is_filesystem_uri('nas://<FileSystemId>/path/to/data/directory/') True # CPFS file system type >>> is_filesystem_uri('cpfs://<FileSystemId>/<ProtocolServiceId>/<ExportId>') True """ schemas = { v for k, v in FileSystemInputScheme.__dict__.items() if not k.startswith("__") } return any(uri.startswith(f"{schema}://") for schema in schemas) def is_dataset_id(item: str) -> bool: """Return True if given input is a dataset ID. Args: item (str): user input dataset ID or Dataset ID and dataset version, separated by a slash. Examples: >>> is_dataset_id('d-ybko3rap60c4gs9flc') True >>> is_dataset_id('d-ybko3rap60c4gs9flc/v1') True """ return item.startswith("d-") def is_nas_uri(uri: Union[str, bytes]) -> bool: """Determines whether the given uri is a NAS uri. Args: uri (Union[str, bytes]): A string in NAS URI schema: nas://29**d-b12****446.cn-hangzhou.nas.aliyuncs.com/data/path/ nas://29****123-y**r.cn-hangzhou.extreme.nas.aliyuncs.com/data/path/ Returns: bool: True if the given uri is a NAS uri, else False. """ return bool(uri and isinstance(uri, (str, bytes)) and str(uri).startswith("nas://")) @lru_cache() def is_domain_connectable(domain: str, port: int = 80, timeout: int = 1) -> bool: """Check if the domain is connectable.""" sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # Set the timeout for the socket sock.settimeout(timeout) try: # Get the IP address of the domain ip = socket.gethostbyname(domain) # Try to connect to the IP address on specific port (default 80, HTTP) sock.connect((ip, port)) # If the connection is successful, return True return True except (socket.timeout, socket.gaierror, socket.error): # If there is an error connecting, return False return False finally: # Close the socket sock.close() def experimental(callable_entity): """Decorator to mark functions or classes as experimental""" @functools.wraps(callable_entity) def wrapper(*args, **kwargs): message = f"{callable_entity.__name__} is experimental and may change or be removed in future releases." warnings.warn(message, category=FutureWarning, stacklevel=2) return callable_entity(*args, **kwargs) return wrapper def retry(max_attempts=3, wait_secs=1, exceptions=(Exception,), report_retries=True): """Decorator to make functions retry by config""" def decorator_retry(func): @functools.wraps(func) def wrapper(*args, **kwargs): attempts = 0 while attempts < max_attempts: try: result = func(*args, **kwargs) return result except exceptions as e: attempts += 1 if attempts == max_attempts: raise # Re-raise the last exception when all the attempts have failed if report_retries: warnings.warn(f"Retry {attempts}/{max_attempts} failed: {e}") time.sleep(wait_secs) return wrapper return decorator_retry def print_table(headers: List[str], rows: List[List[str]]): """Give headers and rows, print as table to stdout.""" length = len(headers) for row in rows: if len(row) != length: raise ValueError("Unable to print table, headers length mismatch with rows") column_widths = [ max(len(str(value)) for value in column) for column in zip(headers, *rows) ] header_row = " | ".join( f"{header:<{column_widths[i]}}" for i, header in enumerate(headers) ) print(header_row) print("-" * len(header_row)) for row in rows: print( " | ".join( f"{str(value):<{column_widths[i]}}" for i, value in enumerate(row) ) ) def is_package_available(package_name: str) -> bool: """Check if the package is available in the current environment.""" return True if importlib.util.find_spec(package_name) is not None else False def timestamp(sep: str = "-", utc: bool = False) -> str: """Return a timestamp with millisecond precision. Args: sep: The separator between date and time. utc: Whether to use UTC time. Returns: str: A timestamp with millisecond precision. """ if utc: res = datetime.utcnow().strftime("%Y%m%d-%H%M%S-%f")[:-3] else: res = datetime.now().strftime("%Y%m%d-%H%M%S-%f")[:-3] if sep != "-": res = res.replace("-", sep) return res def name_from_base(base_name: str, sep: str = "-") -> str: """Return a name with base_name and timestamp. Args: base_name: The base name of the returned name. sep: The separator between base_name and timestamp. Returns: str: A name with base_name and timestamp. """ return "{base_name}{sep}{timestamp}".format( base_name=base_name, sep=sep, timestamp=timestamp(sep=sep, utc=False) )