pai/api/model.py (223 lines of code) (raw):
# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import typing
from typing import Any, Dict
from ..libs.alibabacloud_aiworkspace20210204.models import (
CreateModelRequest,
CreateModelVersionRequest,
Label,
ListModelsRequest,
ListModelsResponseBody,
ListModelVersionsRequest,
ListModelVersionsResponseBody,
UpdateModelVersionRequest,
)
from .base import PaginatedResult, ServiceName, WorkspaceScopedResourceAPI
if typing.TYPE_CHECKING:
pass
class ModelAPI(WorkspaceScopedResourceAPI):
BACKEND_SERVICE_NAME = ServiceName.PAI_WORKSPACE
_create_model_method = "create_model_with_options"
_list_model_method = "list_models_with_options"
_get_model_method = "get_model_with_options"
_delete_model_method = "delete_model_with_options"
_create_model_version_method = "create_model_version_with_options"
_list_model_version_method = "list_model_versions_with_options"
_get_model_version_method = "get_model_version_with_options"
_update_model_version_method = "update_model_version_with_options"
_delete_model_version_method = "delete_model_version_with_options"
def create(
self,
accessibility: str = None,
domain: str = None,
labels: Dict[str, str] = None,
model_description: str = None,
model_doc: str = None,
model_name: str = None,
origin: str = None,
task: str = None,
workspace_id: str = None,
) -> str:
labels = [Label(key=k, value=v) for k, v in labels.items()] if labels else []
request = CreateModelRequest(
accessibility=accessibility,
domain=domain,
labels=labels,
model_description=model_description,
model_doc=model_doc,
model_name=model_name,
origin=origin,
task=task,
workspace_id=workspace_id,
)
resp = self._do_request(self._create_model_method, request=request)
return resp.model_id
def list(
self,
collections: str = None,
domain: str = None,
label: str = None,
label_string: str = None,
labels: str = None,
model_name: str = None,
order: str = None,
origin: str = None,
page_number: int = None,
page_size: int = None,
provider: str = None,
query: str = None,
sort_by: str = None,
task: str = None,
workspace_id: str = None,
) -> PaginatedResult:
request = ListModelsRequest(
collections=collections,
domain=domain,
label=label,
label_string=label_string,
labels=labels,
model_name=model_name,
order=order,
origin=origin,
page_number=page_number,
page_size=page_size,
provider=provider,
query=query,
sort_by=sort_by,
task=task,
workspace_id=workspace_id,
)
resp: ListModelsResponseBody = self._do_request(
self._list_model_method, request=request
)
return self.make_paginated_result(resp)
def get(self, model_id: str):
resp = self._do_request(method_=self._get_model_method, model_id=model_id)
return resp.to_map()
def delete(self, model_id: str):
self._do_request(method_=self._delete_model_method, model_id=model_id)
def create_version(
self,
model_id: str,
approval_status: str = None,
evaluation_spec: Dict[str, Any] = None,
format_type: str = None,
framework_type: str = None,
inference_spec: Dict[str, Any] = None,
labels: Dict[str, str] = None,
metrics: Dict[str, Any] = None,
options: str = None,
source_id: str = None,
source_type: str = None,
training_spec: Dict[str, Any] = None,
uri: str = None,
version_description: str = None,
version_name: str = None,
):
"""Create a ModeVersion resource."""
labels = [Label(key=k, value=v) for k, v in labels.items()] if labels else []
request = CreateModelVersionRequest(
approval_status=approval_status,
evaluation_spec=evaluation_spec,
format_type=format_type,
framework_type=framework_type,
inference_spec=inference_spec,
labels=labels,
metrics=metrics,
options=options,
source_id=source_id,
source_type=source_type,
training_spec=training_spec,
uri=uri,
version_description=version_description,
version_name=version_name,
)
response = self._do_request(
self._create_model_version_method, model_id=model_id, request=request
)
version_name = response.to_map()["VersionName"]
return version_name
def list_versions(
self,
model_id,
approval_status: str = None,
format_type: str = None,
framework_type: str = None,
label: str = None,
label_string: str = None,
labels: str = None,
order: str = None,
page_number: int = None,
page_size: int = None,
sort_by: str = None,
source_id: str = None,
source_type: str = None,
version_name: str = None,
) -> PaginatedResult:
request = ListModelVersionsRequest(
approval_status=approval_status,
format_type=format_type,
framework_type=framework_type,
label=label,
label_string=label_string,
labels=labels,
order=order,
page_number=page_number,
page_size=page_size,
sort_by=sort_by,
source_id=source_id,
source_type=source_type,
version_name=version_name,
)
resp: ListModelVersionsResponseBody = self._do_request(
self._list_model_version_method, model_id=model_id, request=request
)
data = resp.to_map()
for v in data["Versions"]:
v.update(
{
"ModelId": model_id,
}
)
return self.make_paginated_result(data)
def get_version(self, model_id: str, version: str):
resp = self._do_request(
self._get_model_version_method, model_id=model_id, version_name=version
)
obj = resp.to_map()
obj.update({"ModelId": model_id})
return obj
def update_version(
self,
model_id: str,
version: str,
approval_status: str = None,
evaluation_spec: Dict[str, Any] = None,
inference_spec: Dict[str, Any] = None,
metrics: Dict[str, Any] = None,
options: str = None,
source_id: str = None,
source_type: str = None,
training_spec: Dict[str, Any] = None,
version_description: str = None,
):
request = UpdateModelVersionRequest(
approval_status=approval_status,
evaluation_spec=evaluation_spec,
inference_spec=inference_spec,
metrics=metrics,
options=options,
source_id=source_id,
source_type=source_type,
training_spec=training_spec,
version_description=version_description,
)
self._do_request(
self._update_model_version_method,
model_id=model_id,
version_name=version,
request=request,
)
def delete_version(self, model_id: str, version: str):
self._do_request(
self._delete_model_version_method,
model_id=model_id,
version_name=version,
)