pai/api/image.py (125 lines of code) (raw):

# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, Union

from ..libs.alibabacloud_aiworkspace20210204.models import (
    ListImageLabelsRequest,
    ListImageLabelsResponseBody,
    ListImagesRequest,
    ListImagesResponseBody,
)
from .base import PaginatedResult, ServiceName, WorkspaceScopedResourceAPI

# Framework names accepted by ImageLabel.framework_version (case sensitive).
SUPPORTED_IMAGE_FRAMEWORKS = [
    "DeepRec",
    "DeepSpeed",
    "Megatron-LM",
    "ModelScope",
    "Nemo",
    "OneFlow",
    "PyTorch",
    "TensorFlow",
    "Transformers",
    "XGBoost",
]

# Language names accepted by ImageLabel.language_version (case sensitive).
SUPPORTED_IMAGE_LANGUAGES = [
    "python",
    # TODO: "Python"
]


class ImageLabel(object):
    """Image Label Class."""

    # Unofficial Image Label
    UNOFFICIAL_LABEL = "system.official=false"
    # Official Image Label
    OFFICIAL_LABEL = "system.official=true"

    # PAI Image Label
    PROVIDER_PAI_LABEL = "system.origin=PAI"
    # Community Image Label
    PROVIDER_COMMUNITY_LABEL = "system.origin=Community"

    # DLC Image Label: for training
    DLC_LABEL = "system.supported.dlc=true"
    # EAS Image Label: for inference
    EAS_LABEL = "system.supported.eas=true"
    # DSW Image Label: for develop
    DSW_LABEL = "system.supported.dsw=true"

    # Accelerator: Use GPU
    DEVICE_TYPE_GPU = "system.chipType=GPU"
    DEVICE_TYPE_CPU = "system.chipType=CPU"

    # Python Version
    # TODO: delete this label key
    PYTHON_VERSION = "system.pythonVersion"

    @staticmethod
    def framework_version(
        framework: str,
        version: str,
    ):
        """Create a label for filtering images that support specific framework version.

        Args:
            framework (str): framework name, which is case sensitive.
            version (str): framework version. If version is '*', it will match all
                versions.

        Returns:
            str: framework version label string.

        Raises:
            ValueError: If the framework is not supported.
        """
        if framework not in SUPPORTED_IMAGE_FRAMEWORKS:
            raise ValueError(
                f"Unsupported framework: {framework}. Current supported frameworks are:"
                f" {SUPPORTED_IMAGE_FRAMEWORKS}"
            )
        return f"system.framework.{framework}={version}"

    @staticmethod
    def language_version(
        language: str,
        version: str,
    ):
        """Create a label for filtering images that support specific language version.

        Args:
            language (str): language name, which is case sensitive.
            version (str): language version. If version is '*', it will match all
                versions.

        Returns:
            str: language version label string.

        Raises:
            ValueError: If the language is not supported.
        """
        if language not in SUPPORTED_IMAGE_LANGUAGES:
            raise ValueError(
                f"Unsupported language: {language}. Current supported languages are:"
                f" {SUPPORTED_IMAGE_LANGUAGES}"
            )
        # TODO: "system.language.{language}={version}"
        return f"system.{language}Version={version}"


def _format_labels(labels):
    """Serialize a label filter into a single comma-separated string.

    Args:
        labels: A dict mapping label keys to values (each rendered as
            "key=value"), a list or tuple of pre-formatted label strings
            (e.g. built with ImageLabel), or None.

    Returns:
        The comma-joined label string. Any value that is not a dict, list or
        tuple (including None or an already-joined string) is returned
        unchanged.
    """
    if isinstance(labels, dict):
        return ",".join(f"{key}={value}" for key, value in labels.items())
    if isinstance(labels, (list, tuple)):
        return ",".join(labels)
    return labels


class ImageAPI(WorkspaceScopedResourceAPI):
    """Class which provides API to operate Image resources."""

    BACKEND_SERVICE_NAME = ServiceName.PAI_WORKSPACE

    _list_method = "list_images_with_options"
    # NOTE(review): not referenced in this class; presumably used by the base
    # class or reserved for a future create API -- confirm before removing.
    _create_method = "create_image_with_options"
    _list_labels_method = "list_image_labels_with_options"

    def list(
        self,
        labels: Optional[Union[Dict[str, Any], List[str]]] = None,
        name: Optional[str] = None,
        order: str = "DESC",
        page_number: int = 1,
        page_size: int = 50,
        parent_user_id: Optional[str] = None,
        query: Optional[str] = None,
        sort_by: Optional[str] = None,
        user_id: Optional[str] = None,
        verbose: bool = False,
        **kwargs,
    ) -> PaginatedResult:
        """List image resources.

        Args:
            labels: Label filter. A dict is rendered as comma-separated
                "key=value" pairs; a list or tuple of label strings (for
                example from ImageLabel) is comma-joined.
            name: Forwarded as ListImagesRequest.name.
            order: Sort order, forwarded as ListImagesRequest.order.
            page_number: Page number to fetch (defaults to 1).
            page_size: Number of images per page (defaults to 50).
            parent_user_id: Forwarded as ListImagesRequest.parent_user_id.
            query: Forwarded as ListImagesRequest.query.
            sort_by: Forwarded as ListImagesRequest.sort_by.
            user_id: Forwarded as ListImagesRequest.user_id.
            verbose: Forwarded as ListImagesRequest.verbose.
            **kwargs: Only ``workspace_id`` is read; other keys are ignored.

        Returns:
            PaginatedResult: The page of images built from the response body.
        """
        # Only workspace_id is consumed from kwargs; any other extra keyword
        # arguments are silently ignored.
        workspace_id = kwargs.pop("workspace_id", None)

        req = ListImagesRequest(
            labels=_format_labels(labels),
            name=name,
            order=order,
            page_number=page_number,
            page_size=page_size,
            parent_user_id=parent_user_id,
            query=query,
            sort_by=sort_by,
            user_id=user_id,
            verbose=verbose,
            workspace_id=workspace_id,
        )
        return self._list(request=req)

    def _list(self, request) -> PaginatedResult:
        """Send a ListImages request and wrap the response as a PaginatedResult."""
        resp: ListImagesResponseBody = self._do_request(
            self._list_method, request=request
        )
        return self.make_paginated_result(resp)

    def list_labels(
        self,
        image_id: Optional[str] = None,
        label_filter: Optional[Union[Dict[str, Any], List[str]]] = None,
        label_keys: Optional[str] = None,
        region: Optional[str] = None,
        **kwargs,
    ) -> dict:
        """List image labels.

        Args:
            image_id: Forwarded as ListImageLabelsRequest.image_id.
            label_filter: Label filter, serialized the same way as ``labels``
                in :meth:`list` (dict -> "key=value" pairs, list/tuple ->
                comma-joined).
            label_keys: Forwarded as ListImageLabelsRequest.label_keys.
            region: Forwarded as ListImageLabelsRequest.region.
            **kwargs: Only ``workspace_id`` is read; other keys are ignored.

        Returns:
            The value stored under the ``"Labels"`` key of the response body's
            ``to_map()`` output. NOTE(review): annotated as ``dict``, but this
            is likely a list of label entries -- confirm against
            ListImageLabelsResponseBody.
        """
        workspace_id = kwargs.pop("workspace_id", None)

        request = ListImageLabelsRequest(
            image_id=image_id,
            label_filter=_format_labels(label_filter),
            label_keys=label_keys,
            region=region,
            workspace_id=workspace_id,
        )
        resp: ListImageLabelsResponseBody = self._do_request(
            method_=self._list_labels_method, request=request
        )
        return resp.to_map()["Labels"]