pai/pipeline/types/spec.py (157 lines of code) (raw):
# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
from abc import ABCMeta, abstractmethod
from collections import Counter
import six
from .artifact import LocationArtifactMetadata, PipelineArtifact
from .parameter import PipelineParameter
# Labels for the two sides of an I/O spec; they match the "inputs"/"outputs"
# section keys that load_input_output_spec reads from a spec dict.
IO_TYPE_INPUTS = "inputs"
IO_TYPE_OUTPUTS = "outputs"
def sort_variable_by_category(items):
    """Split variables into parameters and artifacts.

    Args:
        items: Sequence of variables (iterated more than once); each must
            expose ``name`` and ``variable_category``.

    Returns:
        tuple: ``(params, arts, params + arts)`` - the items whose category is
        "parameters", the items whose category is "artifacts", and both lists
        concatenated with parameters first. Input order is kept inside each
        group. Items of any other category appear in none of the lists.
        Empty or None input yields three empty lists.

    Raises:
        ValueError: If two or more items share the same name.
    """
    if not items:
        return [], [], []
    counter = Counter(item.name for item in items)
    # Sort for a deterministic message (set order varies between runs) and
    # stringify so that unnamed (None) items still produce a ValueError
    # instead of a TypeError from str.join.
    conflicts = sorted(str(name) for name, count in counter.items() if count > 1)
    if conflicts:
        raise ValueError("Parameter/Artifact names conflict:%s" % ",".join(conflicts))
    params = [item for item in items if item.variable_category == "parameters"]
    arts = [item for item in items if item.variable_category == "artifacts"]
    return params, arts, params + arts
class IndexedItemMixin(six.with_metaclass(ABCMeta, object)):
    """Mixin giving a list of items both positional and key-based access.

    Subclasses implement :meth:`index_key` to derive the string key under
    which each item can be looked up with ``obj[key]``.
    """

    def __init__(self, items):
        self._items = items
        # Map each item's key to its position in ``items``; if two items share
        # a key, the later one wins.
        self._indexer = {self.index_key(item): idx for idx, item in enumerate(items)}

    def __getitem__(self, key):
        """Return an item by integer position or string key, or a slice of items.

        Raises:
            IndexError: If an integer position is out of range.
            KeyError: If a string key is not indexed.
            TypeError: If ``key`` is not an integer, string or slice.
        """
        if isinstance(key, six.integer_types):
            return self._items[key]
        if isinstance(key, six.string_types):
            return self._items[self._indexer[key]]
        if isinstance(key, slice):
            return self._items[key]
        # Previously an unsupported key type fell through and returned None,
        # hiding caller bugs; follow the sequence protocol and raise instead.
        raise TypeError(
            "%s indices must be integers, strings or slices, not %s"
            % (type(self).__name__, type(key).__name__)
        )

    def __iter__(self):
        return iter(self._items)

    def __len__(self):
        return len(self._items)

    @abstractmethod
    def index_key(self, item):
        """Return the key under which ``item`` is indexed for string lookup."""
        pass
class IOSpecBase(IndexedItemMixin):
    """Common base for inputs/outputs specs.

    The given items are split by category: the combined item list holds the
    parameters ahead of the artifacts, and each group is also available on
    its own through :attr:`parameters` and :attr:`artifacts`.
    """

    def __init__(self, items):
        params, artifacts, ordered = sort_variable_by_category(items)
        self._parameters = Parameters(params)
        self._artifacts = Artifacts(artifacts)
        super(IOSpecBase, self).__init__(ordered)

    @staticmethod
    def sort_items(items):
        """Return ``items`` reordered so PipelineParameter instances come first.

        ``sorted`` is stable, so the relative order inside each group is kept.
        """

        def _rank(entry):
            if isinstance(entry, PipelineParameter):
                return 0
            return 1

        return sorted(items, key=_rank)

    @property
    def items(self):
        """The combined item list, parameters before artifacts."""
        return self._items

    @property
    def artifacts(self):
        """Indexed container of the artifact items."""
        return self._artifacts

    @property
    def parameters(self):
        """Indexed container of the parameter items."""
        return self._parameters

    def __repr__(self):
        rendered = ["\t" + str(entry) for entry in self._items]
        return "%s:\n%s" % (type(self).__name__, "\n".join(rendered))

    def to_dict(self):
        """Serialize as ``{"parameters": [...], "artifacts": [...]}``.

        Items before the first one whose ``variable_category`` is "artifacts"
        are serialized as parameters; that item and everything after it are
        serialized as artifacts.
        """
        split = len(self._items)
        for pos, entry in enumerate(self._items):
            if entry.variable_category == "artifacts":
                split = pos
                break
        leading = self._items[:split]
        trailing = self._items[split:]
        return {
            "parameters": [entry.to_dict() for entry in leading],
            "artifacts": [entry.to_dict() for entry in trailing],
        }

    def index_key(self, item):
        """Items are looked up by their ``name``."""
        return item.name
class InputsSpec(IOSpecBase):
    """Spec of the inputs: input parameters followed by input artifacts."""

    def __init__(self, inputs):
        super(InputsSpec, self).__init__(items=inputs)

    def assign(self, inputs_args):
        """Assign argument values to the input items.

        All arguments are validated before any item is touched, so a bad
        argument does not leave the spec partially assigned.

        Args:
            inputs_args (list or dict): Values to assign. A list is matched to
                the items positionally (parameters come before artifacts) and
                may be shorter than the item list; a dict is matched by name.

        Returns:
            list: The items that were assigned, in the order of ``inputs_args``.

        Raises:
            IndexError: If a list holds more values than there are inputs.
            KeyError: If a dict contains a name that is not an input.
            ValueError: If ``inputs_args`` is neither a list nor a dict.
        """
        if isinstance(inputs_args, list):
            if len(inputs_args) > len(self._items):
                raise IndexError(
                    "Too many positional input arguments: got %d, expected at most %d"
                    % (len(inputs_args), len(self._items))
                )
            targets = list(zip(self._items, inputs_args))
        elif isinstance(inputs_args, dict):
            unknown = [key for key in inputs_args if key not in self._indexer]
            if unknown:
                raise KeyError(
                    "Unknown input name(s): %s" % ", ".join(str(key) for key in unknown)
                )
            targets = [
                (self._items[self._indexer[key]], value)
                for key, value in inputs_args.items()
            ]
        else:
            raise ValueError(
                "Unexpected input_args type:%s, required list or dict"
                % type(inputs_args)
            )
        for item, value in targets:
            item.assign(value)
        return [item for item, _ in targets]
class OutputsSpec(IOSpecBase):
    """Spec of the outputs: output parameters followed by output artifacts.

    On construction every item's ``kind`` attribute is set to IO_TYPE_OUTPUTS.
    """

    def __init__(self, outputs):
        super(OutputsSpec, self).__init__(items=outputs)
        for output_item in self._items:
            setattr(output_item, "kind", IO_TYPE_OUTPUTS)
class Parameters(IndexedItemMixin):
    """Indexed container of parameter items, looked up by name."""

    def index_key(self, item):
        """Parameters are keyed by their ``name`` attribute."""
        key = item.name
        return key
class Artifacts(IndexedItemMixin):
    """Indexed container of artifact items, looked up by name."""

    def index_key(self, item):
        """Artifacts are keyed by their ``name`` attribute."""
        key = item.name
        return key
def load_input_output_spec(p, spec):
    """Build the inputs and outputs specs described by a spec dict.

    Args:
        p: Object passed as ``parent`` to every loaded parameter and artifact.
        spec (dict): Mapping with optional "inputs" and "outputs" sections,
            each holding optional "parameters" and "artifacts" lists of dicts.
            Missing or None sections/lists are treated as empty.

    Returns:
        tuple: ``(InputsSpec, OutputsSpec)``.
    """
    # Work on a deep copy: _load_parameter_spec pops keys off each parameter
    # dict, and the caller's spec must stay intact.
    spec = copy.deepcopy(spec)
    inputs_section = spec.get("inputs") or {}
    outputs_section = spec.get("outputs") or {}

    inputs = []
    for param in inputs_section.get("parameters") or []:
        inputs.append(_load_parameter_spec(p, param, "inputs"))
    for af in inputs_section.get("artifacts") or []:
        inputs.append(_load_artifact_spec(p, af, "inputs"))

    outputs = []
    for param in outputs_section.get("parameters") or []:
        outputs.append(_load_parameter_spec(p, param, "outputs"))
    for af in outputs_section.get("artifacts") or []:
        outputs.append(_load_artifact_spec(p, af, "outputs"))

    return InputsSpec(inputs), OutputsSpec(outputs)
def _load_parameter_spec(p, param_spec, io_type):
typ = param_spec.pop("type", None)
name = param_spec.pop("name")
from_ = param_spec.pop("from", None)
feasible = param_spec.pop("feasible", None)
value = param_spec.pop("value", None)
desc = param_spec.pop("desc", None)
param = PipelineParameter(
name=name,
typ=typ,
default=value,
desc=desc,
io_type=io_type,
from_=from_,
parent=p,
feasible=feasible,
)
return param
def _load_artifact_spec(p, artifact_spec, io_type):
assert io_type in ("inputs", "outputs")
metadata = LocationArtifactMetadata.from_dict(artifact_spec.get("metadata", None))
name = artifact_spec.get("name", None)
from_ = artifact_spec.get("from", None)
value = artifact_spec.get("value", None)
desc = artifact_spec.get("desc", None)
required = artifact_spec.get("required", False)
repeated = artifact_spec.get("repeated", False)
af = PipelineArtifact(
name=name,
metadata=metadata,
io_type=io_type,
parent=p,
from_=from_,
value=value,
desc=desc,
required=required,
repeated=repeated,
)
return af