pai/pipeline/component/_registered.py (159 lines of code) (raw):

# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import six

from pai.common.yaml_utils import dump as yaml_dump
from pai.common.yaml_utils import safe_load as yaml_safe_load
from pai.pipeline.component._base import ComponentBase, UnRegisteredComponent
from pai.pipeline.types.spec import load_input_output_spec
from pai.session import get_default_session


class RegisteredComponent(ComponentBase):
    """RegisteredComponent represents a pipeline/component registered in PAI.

    A RegisteredComponent object holds the "Workflow" manifest used by the PAI
    pipeline service. It can be fetched from the remote pipeline service or
    constructed from a local pipeline/component manifest.

    A registered pipeline/component has a unique `pipeline_id` generated by the
    pipeline service.
    """

    def __init__(self, pipeline_id, manifest=None, workspace_id=None):
        """Construct a RegisteredComponent.

        Args:
            pipeline_id (str): Unique ID of the pipeline in the PAI service.
            manifest (Union[str, dict], optional): "Workflow" definition of the
                pipeline, either as a YAML string or an already-parsed dict.
                If empty, the manifest is fetched from the pipeline service
                using ``pipeline_id``.
            workspace_id (str, optional): ID of the workspace which the
                pipeline belongs to.
        """
        if not manifest:
            # The schema lookup must be keyed by this component's pipeline id,
            # the same way `get` and `get_by_identifier` call it.
            session = get_default_session()
            manifest = session.pipeline_api.get_schema(pipeline_id=pipeline_id)[
                "Manifest"
            ]
        if isinstance(manifest, six.string_types):
            manifest = yaml_safe_load(manifest)

        self._manifest = manifest
        self._pipeline_id = pipeline_id
        self._workspace_id = workspace_id

        inputs, outputs = load_input_output_spec(self, manifest["spec"])
        metadata = manifest["metadata"]
        self._identifier = metadata["identifier"]
        self._provider = metadata["provider"]
        self._version = metadata["version"]
        super(RegisteredComponent, self).__init__(
            inputs=inputs,
            outputs=outputs,
        )

    def __repr__(self):
        return "%s:Id=%s,Identifier=%s,Provider=%s,Version=%s" % (
            type(self).__name__,
            self._pipeline_id,
            self.identifier,
            self.provider,
            self.version,
        )

    def __eq__(self, other):
        return (
            isinstance(other, type(self)) and other.pipeline_id == self.pipeline_id
        )

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None (instances become
        # unhashable); hash on the same key that equality compares.
        return hash(self._pipeline_id)

    @property
    def identifier(self):
        """str: Identifier from the manifest's ``metadata.identifier``."""
        return self._identifier

    @property
    def provider(self):
        """str: Provider from the manifest's ``metadata.provider``."""
        return self._provider

    @property
    def version(self):
        """str: Version from the manifest's ``metadata.version``."""
        return self._version

    @property
    def pipeline_id(self):
        """Unique ID of the pipeline in PAI pipeline service.

        Returns:
            str: Unique pipeline ID of the component instance.
        """
        return self._pipeline_id

    @property
    def manifest(self):
        """Pipeline manifest schema.

        Returns:
            dict: Pipeline manifest schema in dict.
        """
        return self._manifest

    @property
    def raw_manifest(self):
        """Pipeline manifest in YAML format.

        Returns:
            str: Pipeline manifest.
        """
        return yaml_dump(self._manifest)

    @classmethod
    def get_by_identifier(cls, identifier, provider=None, version="v1"):
        """Get a registered component by its identifier-provider-version tuple.

        Args:
            identifier (str): Pipeline identifier.
            provider (str, optional): Provider of the pipeline. Defaults to the
                provider of the default session.
            version (str): Version of the pipeline.

        Returns:
            RegisteredComponent: The matching registered component.

        Raises:
            ValueError: If no pipeline matches the given tuple, or its schema
                cannot be fetched.
        """
        session = get_default_session()
        provider = provider or session.provider
        res = session.pipeline_api.get_by_identifier(
            identifier=identifier,
            provider=provider,
            version=version,
        )
        if not res:
            raise ValueError(
                f"Not found the specific pipeline/component: "
                f"identifier={identifier} provider={provider} version={version}"
            )
        pipeline_info = session.pipeline_api.get_schema(pipeline_id=res["PipelineId"])
        if not pipeline_info:
            raise ValueError(
                "Not found pipeline with specific information: identifier={0}, provider={1}, version={2}".format(
                    identifier, provider, version
                )
            )
        return cls(
            manifest=pipeline_info["Manifest"],
            pipeline_id=pipeline_info["PipelineId"],
            # Keep the workspace, as `deserialize` does for the same payload.
            workspace_id=pipeline_info.get("WorkspaceId", None),
        )

    @classmethod
    def list(
        cls,
        identifier=None,
        provider=None,
        version=None,
        session=None,
        page_size=10,
        page_number=1,
    ):
        """List pipelines/components registered in PAI.

        Searches the pipeline components available in the remote PAI service
        that match the query conditions, one page at a time.

        Args:
            identifier (str, optional): Pipeline identifier filter.
            provider (str, optional): Pipeline provider filter. Defaults to the
                provider of the session.
            version (str, optional): Pipeline version filter.
            session (optional): PAI session. Defaults to the default session.
            page_size (int): Number of results per page.
            page_number (int): Page number to fetch (defaults to 1).

        Returns:
            The ``items`` of the page returned by ``pipeline_api.list``. These
            are not converted to RegisteredComponent instances (presumably raw
            pipeline records -- confirm against the pipeline API).
        """
        session = session or get_default_session()
        if not provider:
            provider = session.provider
        result = session.pipeline_api.list(
            identifier=identifier,
            provider=provider,
            version=version,
            page_size=page_size,
            page_number=page_number,
        )
        return result.items

    def update(self, component):
        """Update this registered component/pipeline with a new manifest.

        Args:
            component (Union[UnRegisteredComponent, str, dict]): New
                pipeline/component spec: an unregistered component, a manifest
                dict, or a manifest in YAML string.

        Raises:
            ValueError: If ``component`` is of an unsupported type.
        """
        session = get_default_session()
        if isinstance(component, UnRegisteredComponent):
            # Keep the identifier/version of this registered component.
            manifest = component.to_manifest(
                identifier=self.identifier, version=self.version
            )
        elif isinstance(component, str):
            manifest = component
        elif isinstance(component, dict):
            manifest = yaml_dump(component)
        else:
            raise ValueError(
                "Please provide an UnRegisteredComponent, or a manifest in dict or"
                " YAML string, to update the registered component."
            )
        session.pipeline_api.update(self._pipeline_id, manifest)

    def delete(self):
        """Delete this registered component/pipeline."""
        get_default_session().pipeline_api.delete(self.pipeline_id)

    @classmethod
    def deserialize(cls, obj_dict):
        """Build a RegisteredComponent from a pipeline info dict.

        Args:
            obj_dict (dict): Pipeline info with a required ``PipelineId`` and
                optional ``Manifest`` / ``WorkspaceId`` keys.

        Returns:
            RegisteredComponent: The constructed instance.
        """
        manifest, pipeline_id, workspace_id = (
            obj_dict.get("Manifest"),
            obj_dict["PipelineId"],
            obj_dict.get("WorkspaceId", None),
        )
        return cls(
            workspace_id=workspace_id,
            pipeline_id=pipeline_id,
            manifest=manifest,
        )

    @classmethod
    def _has_impl(cls, manifest):
        """Return True if the manifest spec defines `pipelines` or `container`.

        Args:
            manifest (Union[str, dict]): Manifest as YAML string or dict.
        """
        if isinstance(manifest, six.string_types):
            manifest = yaml_safe_load(manifest)
        # An empty/None manifest or spec has no implementation.
        if not manifest or "spec" not in manifest:
            return False
        spec = manifest["spec"] or {}
        return "pipelines" in spec or "container" in spec

    @classmethod
    def get(cls, pipeline_id, session=None):
        """Get a registered component by its pipeline_id.

        Args:
            pipeline_id (str): Unique pipeline id.
            session (optional): PAI session. Defaults to the default session.

        Returns:
            RegisteredComponent: Instance with the specific pipeline_id.
        """
        session = session or get_default_session()
        return cls.deserialize(session.pipeline_api.get_schema(pipeline_id=pipeline_id))

    def save(self, identifier=None, version=None):
        """Always raises: a registered component cannot be saved again."""
        raise NotImplementedError("%s is not savable." % type(self).__name__)

    def _submit(self, job_name, args):
        """Submit a run of this registered pipeline.

        Returns:
            The value returned by ``PipelineRun.run`` (the run id).
        """
        # Imported locally, presumably to avoid a circular import.
        from pai.pipeline.run import PipelineRun

        session = get_default_session()
        run_id = PipelineRun.run(
            job_name,
            args,
            no_confirm_required=True,
            pipeline_id=self._pipeline_id,
            session=session,
        )
        return run_id

    def io_spec_to_dict(self):
        """Return the inputs/outputs spec as ``{"inputs": ..., "outputs": ...}``."""
        return {
            "inputs": self.inputs.to_dict(),
            "outputs": self.outputs.to_dict(),
        }