pai/pipeline/component/_base.py

# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import uuid
from abc import ABCMeta, abstractmethod

import six

from ...common.logging import get_logger
from ...common.utils import random_str
from ...common.yaml_utils import dump as yaml_dump
from ...common.yaml_utils import dump_all as yaml_dump_all
from ...session import get_default_session
from ..types import IO_TYPE_INPUTS, IO_TYPE_OUTPUTS, InputsSpec, OutputsSpec

logger = get_logger(__name__)

DEFAULT_PIPELINE_API_VERSION = "core/v1"


class ComponentBase(six.with_metaclass(ABCMeta, object)):
    def __init__(
        self,
        inputs,
        outputs,
    ):
        self._inputs = (
            inputs if isinstance(inputs, InputsSpec) else InputsSpec(inputs or [])
        )
        for input in self._inputs:
            input.bind(self, IO_TYPE_INPUTS)
        self._outputs = (
            outputs if isinstance(outputs, OutputsSpec) else OutputsSpec(outputs or [])
        )
        for output in self._outputs:
            output.bind(self, IO_TYPE_OUTPUTS)

    @property
    def inputs(self):
        """Inputs Spec of the operator.

        Returns:
            pai.pipeline.types.spec.InputsSpec: Inputs of the operator.
        """
        return self._inputs

    @property
    def outputs(self):
        """Outputs Spec of the operator.

        Returns:
            pai.pipeline.types.spec.OutputsSpec: Outputs of the operator.
        """
        return self._outputs

    def translate_arguments(self, args):
        parameters, artifacts = [], []
        if not args:
            return parameters, artifacts
        requires = set([af.name for af in self.inputs.artifacts if af.required])
        not_supply = requires - set(args.keys())
        if len(not_supply) > 0:
            raise ValueError(
                "Required arguments are not supplied: %s" % ",".join(not_supply)
            )

        name_var_mapping = {
            item.name: item for item in itertools.chain(self.inputs, self.outputs)
        }
        for name, arg in args.items():
            if name not in name_var_mapping:
                logger.error(
                    "Provided useless argument: %s, it is not required by the"
                    " pipeline manifest spec." % name
                )
                raise ValueError("Provided argument is not required: %s" % name)
            variable = name_var_mapping[name]
            value = variable.translate_argument(arg)

            if variable.variable_category == "artifacts":
                artifacts.append(value)
            else:
                parameters.append(value)

        repeated_artifacts = [
            af
            for af in itertools.chain(self.inputs.artifacts, self.outputs.artifacts)
            if af.repeated and af.count
        ]
        for af in repeated_artifacts:
            if af.name not in args:
                artifacts.append({"name": af.name, "value": [None] * af.count})

        return parameters, artifacts

    def set_artifact_count(self, artifact_name, count):
        """Set the count of a repeated artifact in the operator run.

        Args:
            artifact_name (str): Name of the repeated output artifact.
            count (int): Count of the repeated artifact.
        """
        artifacts = {
            item.name: item
            for item in itertools.chain(self.outputs.artifacts, self.inputs.artifacts)
        }
        artifact = artifacts.get(artifact_name)
        if not artifact:
            raise ValueError("artifact does not exist: %s" % artifact_name)
        if not artifact.repeated:
            raise ValueError("artifact is not repeated: %s" % artifact_name)
        artifact.count = count
        return self

    def spec_to_dict(self):
        spec = {"inputs": self.inputs.to_dict(), "outputs": self.outputs.to_dict()}
        return spec

    def to_dict(self):
        data = {
            "apiVersion": DEFAULT_PIPELINE_API_VERSION,
            "metadata": self.metadata_to_dict(),
            "spec": self.spec_to_dict(),
        }
        return data

    def run(
        self, job_name=None, wait=True, arguments=None, show_outputs=True, **kwargs
    ):
        """Run the operator using the definition in SavedOperator and the given arguments.

        Args:
            job_name (str): Name of the submitted pipeline run job.
            arguments (dict): Input arguments used in the run workflow.
            wait (bool): Wait until the job stops (succeeded, failed or terminated).
            show_outputs (bool): Show the outputs of the job.

        Returns:
            pai.pipeline.run.PipelineRun: PipelineRun instance of the submitted job.
        """
        from pai.pipeline import PipelineRun

        parameters, artifacts = self.translate_arguments(arguments)
        pipeline_args = {
            "parameters": parameters,
            "artifacts": artifacts,
        }

        run_id = self._submit(job_name=job_name, args=pipeline_args)
        run_instance = PipelineRun.get(run_id=run_id)
        if not wait:
            return run_instance
        run_instance.wait_for_completion(show_outputs=show_outputs)
        return run_instance

    def as_step(self, name=None, inputs=None, depends=None):
        """Create a PipelineStep instance using the operator."""
        from pai.pipeline import PipelineStep

        return PipelineStep(
            inputs=inputs,
            component=self,
            name=name,
            depends=depends,
        )

    def as_loop_step(self, name, items, parallelism=None, inputs=None, depends=None):
        """Create a LoopStep instance using the operator."""
        from pai.pipeline.step import LoopStep

        return LoopStep(
            component=self,
            name=name,
            parallelism=parallelism,
            items=items,
            inputs=inputs,
            depends=depends,
        )

    def as_condition_step(self, name, condition, inputs=None, depends=None):
        """Create a conditional step using the operator."""
        from pai.pipeline.step import ConditionStep

        return ConditionStep(
            component=self,
            name=name,
            condition=condition,
            inputs=inputs,
            depends=depends,
        )


class UnRegisteredComponent(six.with_metaclass(ABCMeta, ComponentBase)):
    def __init__(self, inputs, outputs):
        super(UnRegisteredComponent, self).__init__(inputs=inputs, outputs=outputs)
        self._guid = uuid.uuid4().hex
        self._name = "tmp-{}".format(random_str(16))

    @property
    def guid(self):
        return self._guid

    def metadata_to_dict(self):
        # Hack: the PAIFlow service requires a "name" field in the metadata dict.
        return {
            "name": self._name,
            "guid": self._guid,
        }

    @property
    def name(self):
        return self._name

    def save(self, identifier, version):
        """Save the pipeline in the PAI service for reuse, or to share it with others.

        By specifying the identifier and version and uploading the manifest, the
        pipeline is stored in the remote service, which returns the pipeline_id of
        the saved pipeline. The Alibaba Cloud account UID is used as the provider of
        the saved operator by default. A saved pipeline can be fetched by its
        pipeline_id or by the specific identifier-provider-version.

        Args:
            identifier (str): The identifier of the saved pipeline.
            version (str): Version of the saved pipeline.

        Returns:
            pai.pipeline.component.RegisteredComponent: The saved component (with the
                pipeline_id generated by the remote service).
        """
        from pai.pipeline.component import RegisteredComponent

        if not identifier or not version:
            raise ValueError(
                "Please provide the identifier and version for the operator."
            )
        manifest = self.to_manifest(identifier=identifier, version=version)
        session = get_default_session()
        pipeline_id = session.pipeline_api.create(manifest)
        return RegisteredComponent.get(pipeline_id)

    @classmethod
    def _patch_metadata(cls, pipeline_spec):
        if isinstance(pipeline_spec, dict):
            pipeline_spec["metadata"]["identifier"] = "tmp-%s" % random_str(16)
            pipeline_spec["metadata"]["version"] = "v0"
        elif isinstance(pipeline_spec, list):
            pipeline_spec[-1]["metadata"]["identifier"] = "tmp-%s" % random_str(16)
            pipeline_spec[-1]["metadata"]["version"] = "v0"
        return pipeline_spec

    def _submit(self, job_name, args):
        from pai.pipeline.run import PipelineRun

        session = get_default_session()
        pipeline_spec = self._patch_metadata(self.to_dict())
        if isinstance(pipeline_spec, dict):
            manifest = yaml_dump(pipeline_spec)
        # A pipeline spec may contain unregistered operators; such a pipeline spec is
        # in list format.
        elif isinstance(pipeline_spec, list):
            manifest = yaml_dump_all(pipeline_spec)
        else:
            raise ValueError(
                "Unsupported pipeline spec value type: %s" % type(pipeline_spec)
            )

        run_id = PipelineRun.run(
            job_name,
            args,
            no_confirm_required=True,
            manifest=manifest,
            session=session,
        )
        return run_id

    def io_spec_to_dict(self):
        return {
            "inputs": self.inputs.to_dict(),
            "outputs": self.outputs.to_dict(),
        }

    @abstractmethod
    def to_manifest(self, identifier, version):
        pass

    def export_manifest(self, file_path, identifier, version):
        manifest = self.to_manifest(identifier=identifier, version=version)
        with open(file_path, "w") as f:
            f.write(manifest)