pai/api/base.py (113 lines of code) (raw):
# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABCMeta
from typing import Any, Dict, List, Optional, Union
import six
from alibabacloud_tea_openapi.client import Client
from alibabacloud_tea_util.models import RuntimeOptions
from six import with_metaclass
from Tea.model import TeaModel
from ..common.logging import get_logger
logger = get_logger(__name__)
class ServiceName(object):
# Service provided by PAI.
PAI_DLC = "pai-dlc"
PAI_EAS = "pai-eas"
PAI_WORKSPACE = "aiworkspace"
PAI_STUDIO = "pai"
PAIFLOW = "paiflow"
# Other services provided by Alibaba Cloud.
STS = "sts"
PAI_DSW = "pai-dsw"
class PAIRestResourceTypes(object):
"""Resource types provided by PAI REST API."""
Dataset = "Dataset"
DlcJob = "DlcJob"
CodeSource = "CodeSource"
Image = "Image"
Service = "Service"
Model = "Model"
Workspace = "Workspace"
Algorithm = "Algorithm"
TrainingJob = "TrainingJob"
Pipeline = "Pipeline"
PipelineRun = "PipelineRun"
TensorBoard = "TensorBoard"
Experiment = "Experiment"
class ResourceAPI(with_metaclass(ABCMeta, object)):
"""Class that provide APIs to operate the resource."""
BACKEND_SERVICE_NAME = None
def __init__(
self,
acs_client: Client,
header: Optional[Dict[str, str]] = None,
runtime: Optional[RuntimeOptions] = None,
):
"""Initialize a ResourceAPI object.
Args:
acs_client (Client): A basic client used to communicate with a specific PAI
service.
header (Dict[str, str], optional): Header set in the HTTP request. Defaults
to None.
runtime (RuntimeOptions, optional): Options configured for the client
runtime behavior, such as read_timeout, connection_timeout, etc.
Defaults to None.
"""
self.acs_client = acs_client
self.header = header
self.runtime = runtime
def _make_extra_request_options(self):
"""Returns headers and runtime for client."""
return self.header or dict(), self.runtime or RuntimeOptions()
def _do_request(self, method_: str, *args, **kwargs):
headers, runtime = self._make_extra_request_options()
if "headers" not in kwargs:
kwargs["headers"] = headers
if "runtime" not in kwargs:
kwargs["runtime"] = runtime
request_method = getattr(self.acs_client, method_)
return request_method(*args, **kwargs).body
def get_api_object_by_resource_id(self, resource_id):
raise NotImplementedError
def refresh_entity(self, id_, entity):
"""Refresh entity using API object from service."""
if not isinstance(id_, six.string_types) and not isinstance(
id_, six.integer_types
):
raise ValueError(
"Expected integer type or string type for id, but given type %s"
% type(id_)
)
api_obj = self.get_api_object_by_resource_id(resource_id=id_)
return entity.patch_from_api_object(api_obj)
@classmethod
def make_paginated_result(
cls,
data: Union[Dict[str, Any], TeaModel],
item_key=None,
) -> "PaginatedResult":
"""Make a paginated result from response.
Args:
data: Response data.
item_key:
Returns:
"""
if isinstance(data, TeaModel):
data = data.to_map()
total_count = data.pop("TotalCount")
if item_key:
items = data[item_key]
else:
values = list([val for val in data.values() if isinstance(val, list)])
if len(values) != 1:
raise ValueError("Requires item key to make paginated result.")
items = values[0]
return PaginatedResult(items=items, total_count=total_count)
class WorkspaceScopedResourceAPI(with_metaclass(ABCMeta, ResourceAPI)):
"""Workspace Scoped Resource API."""
# A workspace_id placeholder indicate the workspace_id field of
# the request should not be replaced.
workspace_id_none_placeholder = "WORKSPACE_ID_NONE_PLACEHOLDER"
# Default parameter name for request object.
default_param_name_for_request = "request"
def __init__(self, workspace_id, acs_client, **kwargs):
super(WorkspaceScopedResourceAPI, self).__init__(
acs_client=acs_client, **kwargs
)
self.workspace_id = workspace_id
def _do_request(self, method_, **kwargs):
request = kwargs.get(self.default_param_name_for_request)
if not request:
# Sometimes, request object is not named as "request", we need to find it.
for param_name, param_value in kwargs.items():
if isinstance(param_value, TeaModel) and type(
param_value
).__name__.endswith("Request"):
request = param_value
break
# Automatically configure the workspace ID for the request
if request and hasattr(request, "workspace_id"):
if request.workspace_id is None:
request.workspace_id = self.workspace_id
elif (
request.workspace_id == self.workspace_id_none_placeholder
or not request.workspace_id
):
# request.workspace_id is 0 or request.workspace_id is empty string,
# we do not inject workspace_id of the scope.
request.workspace_id = None
return super(WorkspaceScopedResourceAPI, self)._do_request(method_, **kwargs)
class PaginatedResult(object):
"""A class represent response of a pagination call to PAI service."""
items: List[Union[Dict[str, Any], str]] = None
total_count: int = None
def __init__(self, items: List[Union[Dict[str, Any], str]], total_count: int):
self.items = items
self.total_count = total_count