pai/pipeline/types/variable.py (111 lines of code) (raw):

# Copyright 2023 Alibaba, Inc. or its affiliates. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # https://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import absolute_import from abc import ABCMeta, abstractmethod from six import with_metaclass class PipelineVariable(with_metaclass(ABCMeta, object)): """Base class of Artifact and PipelineParameter.""" variable_category = None def __init__( self, name, desc=None, io_type="inputs", value=None, from_=None, required=None, parent=None, validator=None, ): """ Args: name: name of parameter. desc: parameter description. io_type: usage of PipelineParameter in pipeline, either "input" or "output" value: default value of parameter from_: required: validator: parameter value validator. parent: """ self.name = name self.io_type = io_type self.desc = desc self.value = value self.from_ = from_ self.required = required self.parent = parent self.validator = validator def __hash__(self): return id(self) # TODO: validate if pipeline variable attribute is legal def _validate_spec(self): pass @abstractmethod def validate_value(self, val): pass @abstractmethod def validate_from(self, val): pass def assign(self, arg): from .artifact import PipelineArtifactElement from .parameter import LoopItemPlaceholder if isinstance(arg, LoopItemPlaceholder): self.value = arg.enclosed_fullname elif not isinstance(arg, (PipelineVariable, PipelineArtifactElement)): if not self.validate_value(arg): raise ValueError("Arg:%s is invalid value for %s" % (arg, self)) self.value = arg else: if not self.validate_from(arg): raise ValueError( "invalid assignment. %s left: %s, right: %s" % (self.fullname, self.typ, arg.typ) ) self.from_ = arg @property def fullname(self): """Unique identifier in pipeline manifest for PipelineVariable""" from pai.pipeline.component._base import ComponentBase if self.parent and not isinstance(self.parent, ComponentBase): return ".".join( [self.parent.ref_name, self.io_type, self.variable_category, self.name] ) else: return ".".join([self.io_type, self.variable_category, self.name]) def bind(self, parent, io_type): if self.parent and parent != self.parent: raise ValueError( "Pipeline variable has bound to another operator instance." ) self.parent = parent self.io_type = io_type @property def enclosed_fullname(self): return "{{%s}}" % self.fullname def __str__(self): return self.__repr__() def __repr__(self): return "%s:{Name:%s, Kind:%s, Required:%s, Value:%s, Desc:%s}" % ( type(self).__name__, self.name, self.io_type, self.required, self.value, self.desc, ) def translate_argument(self, value): arguments = { "name": self.name, "value": value, } return arguments def to_argument(self): argument = {"name": self.name} if self.from_: argument["from"] = self.from_.enclosed_fullname else: argument["value"] = self.value return argument def depend_steps(self): from pai.pipeline import PipelineStep if self.from_ and self.from_.parent and isinstance(self.from_, PipelineStep): return self.from_ def to_dict(self): d = { "name": self.name, } if self.validator: d["feasible"] = self.validator.to_dict() if self.desc: d["desc"] = self.desc elif self.from_ is not None: if isinstance(self.from_, PipelineVariable): d["from"] = "{{%s}}" % self.from_.fullname else: d["from"] = self.from_ return d