pai/pipeline/component/_registered.py (159 lines of code) (raw):
# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import six
from pai.common.yaml_utils import dump as yaml_dump
from pai.common.yaml_utils import safe_load as yaml_safe_load
from pai.pipeline.component._base import ComponentBase, UnRegisteredComponent
from pai.pipeline.types.spec import load_input_output_spec
from pai.session import get_default_session
class RegisteredComponent(ComponentBase):
    """A pipeline/component that has been registered in the PAI pipeline service.

    The object wraps the "Workflow" manifest used by the PAI pipeline service.
    It can be built from a manifest supplied by the caller or, when no manifest
    is given, from the manifest fetched through the pipeline API of the default
    session. A registered component is identified by its ``pipeline_id``.
    """

    def __init__(self, pipeline_id, manifest=None, workspace_id=None):
        """Initialize the component from a pipeline ID and an optional manifest.

        Args:
            pipeline_id: Unique ID for pipeline in PAI service.
            manifest: "Workflow" definition of the pipeline, as a dict or as a
                YAML string. When falsy, the manifest is fetched through the
                pipeline API of the default session using ``pipeline_id``.
            workspace_id: ID of the workspace which the pipeline belongs to.
        """
        if not manifest:
            session = get_default_session()
            # The schema must be looked up by pipeline ID, as the other call
            # sites in this class do; calling get_schema() without it cannot
            # select the pipeline this instance represents.
            schema = session.pipeline_api.get_schema(pipeline_id=pipeline_id)
            manifest = schema["Manifest"]

        if isinstance(manifest, six.string_types):
            manifest = yaml_safe_load(manifest)

        self._manifest = manifest
        self._pipeline_id = pipeline_id
        self._workspace_id = workspace_id

        # NOTE: the input/output spec is loaded before the base-class
        # initializer runs; this ordering is kept from the original code.
        inputs, outputs = load_input_output_spec(self, manifest["spec"])

        metadata = manifest["metadata"]
        self._identifier = metadata["identifier"]
        self._provider = metadata["provider"]
        self._version = metadata["version"]

        super(RegisteredComponent, self).__init__(
            inputs=inputs,
            outputs=outputs,
        )

    def __repr__(self):
        return "%s:Id=%s,Identifier=%s,Provider=%s,Version=%s" % (
            type(self).__name__,
            self._pipeline_id,
            self.identifier,
            self.provider,
            self.version,
        )

    def __eq__(self, other):
        """Two components are equal when they share the same pipeline ID."""
        return isinstance(other, type(self)) and other.pipeline_id == self.pipeline_id

    def __hash__(self):
        # Defining __eq__ without __hash__ makes instances unhashable in
        # Python 3; hash on the same key that __eq__ compares.
        return hash(self._pipeline_id)

    @property
    def identifier(self):
        """Identifier read from the manifest metadata."""
        return self._identifier

    @property
    def provider(self):
        """Provider read from the manifest metadata."""
        return self._provider

    @property
    def version(self):
        """Version read from the manifest metadata."""
        return self._version

    @property
    def pipeline_id(self):
        """Unique ID of the pipeline in PAI pipeline service.

        Returns:
            str: Unique pipeline ID of the component instance.
        """
        return self._pipeline_id

    @property
    def manifest(self):
        """Pipeline manifest schema.

        Returns:
            dict: Pipeline manifest schema in dict.
        """
        return self._manifest

    @property
    def raw_manifest(self):
        """Pipeline manifest in YAML format.

        Returns:
            str: Pipeline manifest dumped with ``yaml_dump``.
        """
        return yaml_dump(self._manifest)

    @classmethod
    def get_by_identifier(cls, identifier, provider=None, version="v1"):
        """Get a registered component by its identifier-provider-version tuple.

        Args:
            identifier (str): Pipeline identifier.
            provider (str): Provider of the Pipeline; ``session.provider`` of
                the default session is used when not given.
            version (str): Version of the pipeline.

        Returns:
            RegisteredComponent: Instance built from the fetched schema.

        Raises:
            ValueError: If the identifier lookup or the schema lookup returns
                a falsy result.
        """
        session = get_default_session()
        provider = provider or session.provider

        res = session.pipeline_api.get_by_identifier(
            identifier=identifier,
            provider=provider,
            version=version,
        )
        if not res:
            raise ValueError(
                f"Not found the specific pipeline/component: "
                f"identifier={identifier} provider={provider} version={version}"
            )

        pipeline_info = session.pipeline_api.get_schema(pipeline_id=res["PipelineId"])
        if not pipeline_info:
            raise ValueError(
                "Not found pipeline with specific information: identifier={0}, provider={1}, version={2}".format(
                    identifier, provider, version
                )
            )

        return cls(
            manifest=pipeline_info["Manifest"],
            pipeline_id=pipeline_info["PipelineId"],
        )

    @classmethod
    def list(
        cls,
        identifier=None,
        provider=None,
        version=None,
        session=None,
        page_size=10,
        page_number=1,
    ):
        """List the pipelines/components registered in PAI.

        Queries the pipeline API of the session with the given filters.

        Args:
            identifier (str): Pipeline identifier filter.
            provider (str): Pipeline provider filter; ``session.provider`` is
                used when not given.
            version (str): Pipeline version.
            session: PAI session; the default session is used when not given.
            page_size (int): Forwarded to the pipeline API.
            page_number (int): Forwarded to the pipeline API.

        Returns:
            The ``items`` attribute of the result returned by the pipeline
            API, unchanged (items are not converted by this method).
        """
        session = session or get_default_session()
        if not provider:
            provider = session.provider

        result = session.pipeline_api.list(
            identifier=identifier,
            provider=provider,
            version=version,
            page_size=page_size,
            page_number=page_number,
        )
        return result.items

    def update(self, component):
        """Update this registered component/pipeline with a new manifest.

        Args:
            component (Union[UnRegisteredComponent, str, dict]): New
                pipeline/component spec. An unregistered component is turned
                into a manifest using this component's identifier and version,
                a string is sent as is, and a dict is dumped to YAML.

        Raises:
            ValueError: If ``component`` is none of the supported types.
        """
        session = get_default_session()

        if isinstance(component, UnRegisteredComponent):
            manifest = component.to_manifest(
                identifier=self.identifier, version=self.version
            )
        elif isinstance(component, six.string_types):
            manifest = component
        elif isinstance(component, dict):
            manifest = yaml_dump(component)
        else:
            raise ValueError(
                "Please provider ContainerOperator, Pipeline or Manifest in string to update the registered component."
            )

        session.pipeline_api.update(self._pipeline_id, manifest)

    def delete(self):
        """Delete this registered component/pipeline."""
        get_default_session().pipeline_api.delete(self.pipeline_id)

    @classmethod
    def deserialize(cls, obj_dict):
        """Build an instance from a dict describing a pipeline.

        Args:
            obj_dict (dict): Must contain ``"PipelineId"``; ``"Manifest"`` and
                ``"WorkspaceId"`` are optional.

        Returns:
            RegisteredComponent: The constructed instance.
        """
        return cls(
            workspace_id=obj_dict.get("WorkspaceId"),
            pipeline_id=obj_dict["PipelineId"],
            manifest=obj_dict.get("Manifest"),
        )

    @classmethod
    def _has_impl(cls, manifest):
        """Tell whether the manifest spec has a "pipelines" or "container" entry.

        Args:
            manifest: Manifest as a dict or as a YAML string.

        Returns:
            bool: False when the manifest has no "spec" or the spec has
            neither "pipelines" nor "container"; True otherwise.
        """
        if isinstance(manifest, six.string_types):
            manifest = yaml_safe_load(manifest)

        if "spec" not in manifest:
            return False

        spec = manifest["spec"]
        return "pipelines" in spec or "container" in spec

    @classmethod
    def get(cls, pipeline_id, session=None):
        """Get a registered component by pipeline_id.

        Args:
            pipeline_id (str): Unique pipeline id.
            session: PAI session; the default session is used when not given.

        Returns:
            RegisteredComponent: Instance with the specific pipeline_id.
        """
        session = session or get_default_session()
        schema = session.pipeline_api.get_schema(pipeline_id=pipeline_id)
        return cls.deserialize(schema)

    def save(self, identifier=None, version=None):
        """Not supported for a registered component.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("SaveTemplate is not savable.")

    def _submit(self, job_name, args):
        """Submit a run of this pipeline and return the run ID.

        Args:
            job_name: Name passed to ``PipelineRun.run``.
            args: Arguments passed to ``PipelineRun.run``.

        Returns:
            The value returned by ``PipelineRun.run``.
        """
        # Imported locally as in the original code (presumably to avoid a
        # circular import -- confirm before moving it to module level).
        from pai.pipeline.run import PipelineRun

        return PipelineRun.run(
            job_name,
            args,
            no_confirm_required=True,
            pipeline_id=self._pipeline_id,
            session=get_default_session(),
        )

    def io_spec_to_dict(self):
        """Return the input and output specs as a dict.

        Returns:
            dict: ``{"inputs": ..., "outputs": ...}`` built from the
            ``to_dict()`` of ``self.inputs`` and ``self.outputs``.
        """
        return {
            "inputs": self.inputs.to_dict(),
            "outputs": self.outputs.to_dict(),
        }