pai/image.py (226 lines of code) (raw):

# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Any, Dict, List, Optional

from .api.image import SUPPORTED_IMAGE_FRAMEWORKS, ImageLabel
from .common.logging import get_logger
from .common.utils import make_list_resource_iterator, to_semantic_version
from .session import Session, get_default_session

logger = get_logger(__name__)

# Maps the lower-cased framework name to its canonical spelling.
_NORMALIZED_FRAMEWORK_NAMES = {
    name.lower(): name for name in SUPPORTED_IMAGE_FRAMEWORKS
}

# Regex expression pattern for PAI Docker Image Tag.
# Training tag groups:  (fw_version, accelerator, py_version, cuda_version, os)
_PAI_IMAGE_TAG_PATTERN_TRAINING = re.compile(
    r"([\w._-]+)-(gpu|cpu|mkl-cpu)-(py\d+)(?:-(cu\d+))?-([\S]+)"
)
# Inference tag groups: (fw_version, py_version, accelerator, cuda_version, os)
_PAI_IMAGE_TAG_PATTERN_INFERENCE = re.compile(
    r"([\w._-]+)-(py\d+)(?:-(gpu|cpu|mkl-cpu))?(?:-(cu\d+))?-([\S]+)"
)

# Regex expression pattern for PAI Docker Image URI.
# Groups: (host, namespace, repo_name, tag)
_PAI_IMAGE_URI_PATTERN = re.compile(r"([\S]+)/([\S]+)/([\S]+):([\S]+)")


class ImageInfo(object):
    """This class represents information for an image provided by PAI.

    Args:
        image_name (str): The name of the image.
        image_uri (str): The URI of the image.
        framework_name (str): The name of the framework installed in the image.
        framework_version (str, optional): The version of the framework (Default
            None).
        image_scope (str): The scope of the image, could be 'training', 'inference'
            or 'develop'.
        accelerator_type (str, optional): The type of accelerator. Defaults to None.
        python_version (str, optional): The version of Python. Defaults to None.
    """

    def __repr__(self):
        return (
            "{}(framework_name={}: framework_version={}: image_scope={}: "
            "accelerator_type={}: py_version={})".format(
                self.__class__.__name__,
                self.framework_name,
                self.framework_version,
                self.image_scope,
                self.accelerator_type,
                self.python_version,
            )
        )

    def __init__(
        self,
        image_name: str,
        image_uri: str,
        framework_name: str,
        image_scope: str,
        framework_version: Optional[str] = None,
        accelerator_type: Optional[str] = None,
        python_version: Optional[str] = None,
    ):
        self.image_name = image_name
        self.image_uri = image_uri
        self.framework_name = framework_name
        self.framework_version = framework_version
        self.accelerator_type = accelerator_type
        self.python_version = python_version
        self.image_scope = image_scope


class ImageScope(object):
    """Class containing constants that indicate the purpose of an image."""

    TRAINING = "training"
    """Indicates the image is used for submitting a training job."""
    INFERENCE = "inference"
    """Indicates the image is used for creating a prediction service."""
    DEVELOP = "develop"
    """Indicates the image is used for running in DSW."""

    _SCOPE_IMAGE_LABEL_MAPPING = {
        TRAINING: ImageLabel.DLC_LABEL,
        INFERENCE: ImageLabel.EAS_LABEL,
        DEVELOP: ImageLabel.DSW_LABEL,
    }

    @classmethod
    def to_image_label(cls, scope: str):
        """Return the image label associated with the given scope.

        Args:
            scope (str): One of the scope constants of this class; matching is
                case-insensitive.

        Returns:
            The label mapped to the scope in ``_SCOPE_IMAGE_LABEL_MAPPING``.

        Raises:
            ValueError: If the scope is not a supported image scope.
        """
        cls._validate(scope)
        return cls._SCOPE_IMAGE_LABEL_MAPPING[scope.lower()]

    @classmethod
    def _validate(cls, scope: str):
        """Raise ``ValueError`` unless ``scope`` is a supported image scope."""
        # A non-string scope (e.g. None) is reported as an unsupported scope
        # instead of failing with an AttributeError on ``.lower()``.
        if (
            not isinstance(scope, str)
            or scope.lower() not in cls._SCOPE_IMAGE_LABEL_MAPPING
        ):
            raise ValueError(f"Not supported image scope: {scope}")


def _make_image_info(
    image_obj: Dict[str, Any],
    image_scope: str,
) -> Optional[ImageInfo]:
    """Make an ImageInfo object by parsing the image URI and the image labels.

    Args:
        image_obj (Dict[str, Any]): An image description that provides the keys
            ``ImageUri``, ``Name`` and ``Labels`` (a list of ``Key``/``Value``
            dicts).
        image_scope (str): The scope recorded on the returned ImageInfo.

    Returns:
        Optional[ImageInfo]: The parsed image information, or None if the image
            URI does not match the expected PAI image URI pattern.
    """
    # Tolerate a missing or null label list.
    labels = {lb["Key"]: lb["Value"] for lb in image_obj.get("Labels") or []}
    image_uri = image_obj["ImageUri"]

    match = _PAI_IMAGE_URI_PATTERN.match(image_uri)
    if not match:
        # ignore if image uri is not recognized
        logger.debug(
            "Could not recognize the given image uri, ignore the image:"
            f" image_uri={image_uri}"
        )
        return None
    _, _, repo_name, tag = match.groups()

    fw_version = None
    cpu_or_gpu = None
    training_match = _PAI_IMAGE_TAG_PATTERN_TRAINING.match(tag)
    if training_match:
        fw_version, cpu_or_gpu, _, _, _ = training_match.groups()
    else:
        inference_match = _PAI_IMAGE_TAG_PATTERN_INFERENCE.match(tag)
        if inference_match:
            # The inference tag puts the Python version before the accelerator,
            # so the groups must not be unpacked in the training order.
            fw_version, _, cpu_or_gpu, _, _ = inference_match.groups()
        else:
            # image tag is not recognized: the version and the accelerator
            # stay unknown unless the labels below provide them.
            logger.debug(
                f"Could not recognize the given image tag, ignore the image:"
                f" image_uri={image_uri}."
            )

    # use image label as ground truth to set the image property, python version, etc.
    chip_type = labels.get("system.chipType")
    if chip_type in ("GPU", "CPU"):
        cpu_or_gpu = chip_type
    py_version = labels.get(ImageLabel.PYTHON_VERSION)

    # TODO: get the framework name from Image Label
    # extract framework name from image repo
    if repo_name.endswith("-inference"):
        framework_name = repo_name[: -len("-inference")]
    elif repo_name.endswith("-training"):
        framework_name = repo_name[: -len("-training")]
    else:
        framework_name = repo_name
    framework_name = _NORMALIZED_FRAMEWORK_NAMES.get(framework_name, framework_name)

    return ImageInfo(
        image_name=image_obj["Name"],
        image_uri=image_uri,
        framework_name=framework_name,
        framework_version=fw_version,
        accelerator_type=cpu_or_gpu,
        python_version=py_version,
        image_scope=image_scope,
    )


def _list_images(
    labels: List[str],
    session: Session,
    name: Optional[str] = None,
    page_number=1,
    page_size=50,
):
    """Build an iterator over the images returned by ``session.image_api.list``.

    Args:
        labels (List[str]): Label filters forwarded to the list API.
        session (Session): The session whose image API is queried.
        name (str, optional): Image name filter forwarded to the list API.
        page_number (int): The page number forwarded to the list API.
        page_size (int): The page size forwarded to the list API.

    Returns:
        The iterator produced by ``make_list_resource_iterator``.
    """
    gen = make_list_resource_iterator(
        session.image_api.list,
        name=name,
        labels=labels,
        verbose=True,
        # set the workspace_id manually, prevent using the default workspace of the
        # session.
        workspace_id=0,
        order="DESC",
        sort_by="GmtCreateTime",
        page_number=page_number,
        page_size=page_size,
    )
    return gen


def retrieve(
    framework_name: str,
    framework_version: str,
    accelerator_type: str = "CPU",
    image_scope: Optional[str] = ImageScope.TRAINING,
    session: Optional[Session] = None,
) -> ImageInfo:
    """Get a container image URI that satisfies the specified requirements.

    Examples::

        # get a TensorFlow image with specific version for training.
        retrieve(framework_name="TensorFlow", framework_version="2.3")

        # get the latest PyTorch image that supports GPU for inference.
        retrieve(
            framework_name="PyTorch",
            framework_version="latest",
            accelerator_type="GPU",
            image_scope=ImageScope.INFERENCE,
        )

    Args:
        framework_name (str): The name of the framework. Possible values include
            TensorFlow, XGBoost, PyTorch, OneFlow, and others.
        framework_version (str): The version of the framework to use. Get the latest
            version supported in PAI by set the parameters as 'latest'.
        image_scope (str, optional): The scope of the image to use. Possible values
            include 'training', 'inference', and 'develop'.
        accelerator_type (str, optional): The name of the accelerator to use. Possible
            values including 'CPU', and 'GPU', (Default CPU).
        session (:class:`pai.session.Session`, optional): A session object to interact
            with the PAI Service. If not provided, a default session will be used.

    Returns:
        ImageInfo: A object contains information of the image that satisfy the
            requirements.

    Raises:
        ValueError: If the framework, the image scope or the accelerator type is
            not supported.
        RuntimeError: If no image satisfies the requirements.
    """
    session = session or get_default_session()
    framework_name = framework_name.lower()
    if framework_name not in _NORMALIZED_FRAMEWORK_NAMES:
        raise ValueError(
            f"The framework ({framework_name}) is not supported by the"
            f" retrieve method: supported frameworks"
            f" {', '.join(SUPPORTED_IMAGE_FRAMEWORKS)}",
        )

    # label filter used to list official images of specific scope.
    labels = [ImageLabel.OFFICIAL_LABEL, ImageScope.to_image_label(image_scope)]

    # if accelerator_type is not specified, use CPU image by default.
    if not accelerator_type or accelerator_type.lower() == "cpu":
        labels.append(ImageLabel.DEVICE_TYPE_CPU)
    elif accelerator_type.lower() == "gpu":
        labels.append(ImageLabel.DEVICE_TYPE_GPU)
    else:
        raise ValueError(
            f"Given accelerator type ({accelerator_type}) is not supported, only"
            f" CPU and GPU is supported."
        )

    resp = _list_images(name=framework_name, labels=labels, session=session)

    # extract image properties, such as framework version, py_version, os_version, etc,
    # from image tag.
    candidates = []
    for image_item in resp:
        image_info = _make_image_info(
            image_obj=image_item,
            image_scope=image_scope,
        )
        # _make_image_info returns None for an unrecognized image URI.
        if image_info is None:
            continue
        if image_info.framework_name.lower() != framework_name:
            continue
        candidates.append(image_info)

    if not candidates:
        raise RuntimeError(
            f"Not found any image that satisfy the requirements: framework_name="
            f"{framework_name}, accelerator={accelerator_type}"
        )

    if framework_version.lower() == "latest":
        # select the latest framework version; max() returns the first of equal
        # maxima, which matches the head of a stable descending sort.
        return max(
            candidates,
            key=lambda img: to_semantic_version(img.framework_version),
        )

    # find the image with the specific framework version.
    img = next(
        (img for img in candidates if img.framework_version == framework_version),
        None,
    )
    if img is None:
        # Images whose tag could not be parsed have no version to report.
        supported_versions = [
            img.framework_version for img in candidates if img.framework_version
        ]
        raise RuntimeError(
            f"Not found the specific framework: framework_name={framework_name}, "
            f"framework_version={framework_version}, supported versions for the"
            f" framework are {','.join(supported_versions)} "
        )
    return img


def list_images(
    framework_name: str,
    session: Optional[Session] = None,
    image_scope: Optional[str] = ImageScope.TRAINING,
) -> List[ImageInfo]:
    """List available images provided by PAI.

    Args:
        framework_name (str): The name of the framework. Possible values include
            TensorFlow, XGBoost, PyTorch, OneFlow, and others. If empty or blank,
            images of all frameworks are returned.
        image_scope (str, optional): The scope of the image to use. Possible values
            include 'training', 'inference', and 'develop'.
        session (:class:`pai.session.Session`): A session object used to interact
            with the PAI Service. If not provided, a default session is used.

    Returns:
        List[ImageInfo]: A list of objects describing the available images.

    Raises:
        ValueError: If the image scope is not supported.
    """
    session = session or get_default_session()
    if not framework_name or not framework_name.strip():
        framework_name = None
    else:
        framework_name = framework_name.strip().lower()

    labels = [
        ImageScope.to_image_label(image_scope),
        ImageLabel.OFFICIAL_LABEL,
    ]
    parsed = (
        _make_image_info(item, image_scope=image_scope)
        for item in _list_images(labels=labels, session=session)
    )
    # Drop images whose URI could not be recognized (reported as None).
    images = [info for info in parsed if info is not None]

    if framework_name:
        return [img for img in images if img.framework_name.lower() == framework_name]
    return images