pai/pipeline/types/parameter.py (283 lines of code) (raw):
# Copyright 2023 Alibaba, Inc. or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import absolute_import
import re
from decimal import Decimal
from enum import Enum
import six
from pai.common.utils import is_iterable
from pai.pipeline.types.variable import PipelineVariable
# Types accepted for a "Double" value: every integer type plus ``float``.
_int_float_types = tuple(six.integer_types) + (float,)

# Python type(s) accepted by ``isinstance`` for each primitive type name.
# NOTE(review): the keys are plain strings and the string entry is spelled
# "Str", whereas ParameterType.String has the value "String" -- confirm which
# spelling lookups into this table are expected to use.
_PRIMITIVE_TYPE_MAP = dict(
    Int=six.integer_types,
    Double=_int_float_types,
    Bool=bool,
    Str=six.string_types,
)

# Decimal infinities, compared against interval bounds when rendering them.
_NEGATIVE_INFINITY = Decimal("-infinity")
_POSITIVE_INFINITY = Decimal("infinity")
class Variable(object):
    """Empty placeholder class; it defines no attributes or behaviour."""
class LoopItemPlaceholder(object):
    """Placeholder exposing the fixed names ``item`` / ``{{item}}``."""

    @property
    def fullname(self):
        """Bare name of the placeholder."""
        return "item"

    @property
    def enclosed_fullname(self):
        """Name of the placeholder wrapped in double curly braces."""
        return "{{item}}"
class ConditionExpr(object):
    """Represent a condition which is used in ConditionStep.

    Args:
        op: Operator token placed between the two operands (e.g. ``"=="``).
        left: Left operand; a PipelineParameter or any value renderable
            with ``str``.
        right: Right operand; same rules as ``left``.
    """

    def __init__(self, op, left, right):
        self.op = op
        self.left = left
        self.right = right

    @staticmethod
    def _operand_str(operand):
        """Render one operand of the expression.

        A PipelineParameter is rendered through ``enclosed_fullname``; any
        other value is rendered with ``str``.
        """
        if isinstance(operand, PipelineParameter):
            return operand.enclosed_fullname
        return str(operand)

    def to_expr(self):
        """Return the condition as ``"<left> <op> <right>"``.

        Both operands go through the same rendering so that a parameter on
        the right-hand side is referenced exactly like one on the left-hand
        side (previously the right-hand parameter object itself was handed
        to ``str.format``).
        """
        return "{} {} {}".format(
            self._operand_str(self.left),
            self.op,
            self._operand_str(self.right),
        )

    def get_depends_steps(self):
        """Return the parent steps of the operands.

        Returns:
            list: ``parent`` of each operand that is a PipelineParameter
            whose parent is a PipelineStep, in left-then-right order.
        """
        # Imported here rather than at module level, as in the original code.
        from pai.pipeline import PipelineStep

        steps = []
        for operand in (self.left, self.right):
            if (
                isinstance(operand, PipelineParameter)
                and operand.parent
                and isinstance(operand.parent, PipelineStep)
            ):
                steps.append(operand.parent)
        return steps
class LoopItems(object):
    """Items a loop iterates over, with their dictionary representation.

    Three kinds of source are supported:

    * a ``range`` with step 1 (``LOOP_RANGE``), kept as the range object;
    * a PipelineParameter (``LOOP_PARAMETER``), kept as is;
    * any other iterable, including a ``range`` with a step other than 1
      (``LOOP_ITEM_LIST``), materialised into a list.
    """

    LOOP_ITEM_LIST = 0
    LOOP_RANGE = 1
    LOOP_PARAMETER = 2

    def __init__(self, items):
        """Classify ``items`` and store them.

        Args:
            items: ``range``, PipelineParameter or any other iterable.

        Raises:
            ValueError: If ``items`` is none of the supported kinds.
        """
        if isinstance(items, range):
            if items.step == 1:
                self.type = self.LOOP_RANGE
                self.items = items
            else:
                # to_dict() only emits start/stop for LOOP_RANGE, so a stepped
                # range is expanded and must be tagged as an explicit list;
                # tagging it LOOP_RANGE made to_dict() fail on ``list.start``.
                self.type = self.LOOP_ITEM_LIST
                self.items = list(items)
        elif isinstance(items, PipelineParameter):
            self.type = self.LOOP_PARAMETER
            self.items = items
        elif is_iterable(items):
            self.type = self.LOOP_ITEM_LIST
            self.items = list(items)
        else:
            raise ValueError("Not supported loop item type: %s" % type(items))

    def to_dict(self):
        """Return the dictionary form of the loop items.

        Returns:
            dict: ``{"withSequence": {"start", "end"}}`` for a range (taken
            from the range's ``start`` and ``stop``), ``{"withParam": ...}``
            for a parameter (its ``enclosed_fullname``), otherwise
            ``{"withItems": [...]}``.
        """
        if self.type == self.LOOP_RANGE:
            return {
                "withSequence": {
                    "start": self.items.start,
                    "end": self.items.stop,
                }
            }
        if self.type == self.LOOP_PARAMETER:
            return {"withParam": self.items.enclosed_fullname}
        return {"withItems": self.items}
class PipelineParameter(PipelineVariable):
    """Definition of an input/output parameter used in a pipeline.

    The rich comparison operators do not return booleans: each one builds a
    ConditionExpr with this parameter as the left operand. Because ``__eq__``
    is overloaded that way, hashing is based on object identity.
    """

    variable_category = "parameters"

    def __init__(
        self,
        name,
        typ=str,
        default=None,
        desc=None,
        io_type="inputs",
        from_=None,
        parent=None,
        feasible=None,
        path=None,
    ):
        """Create a parameter.

        Args:
            name: Forwarded to PipelineVariable.
            typ: Anything accepted by ``ParameterType.normalize_typ``.
            default: Default value, forwarded to PipelineVariable as
                ``value``. The parameter is marked required exactly when
                this is None.
            desc: Forwarded to PipelineVariable.
            io_type: Forwarded to PipelineVariable.
            from_: Forwarded to PipelineVariable.
            parent: Forwarded to PipelineVariable.
            feasible: When truthy, passed to ``ParameterValidator.load`` to
                build the validator.
            path: Stored on the instance; emitted by ``to_dict`` when set.
        """
        normalized_typ = ParameterType.normalize_typ(typ)
        validator = ParameterValidator.load(feasible) if feasible else None
        is_required = default is None

        super(PipelineParameter, self).__init__(
            name=name,
            value=default,
            desc=desc,
            io_type=io_type,
            from_=from_,
            required=is_required,
            parent=parent,
            validator=validator,
        )
        self.typ = normalized_typ
        self.path = path

    @property
    def default(self):
        """Alias of ``value``."""
        return self.value

    def validate_value(self, val):
        """Return True when ``val`` passes the type and validator checks."""
        # NOTE(review): ``self.typ`` is a ParameterType member while
        # ``_PRIMITIVE_TYPE_MAP`` is keyed by plain strings, so this membership
        # test looks like it can never succeed -- confirm whether
        # ``self.typ.value`` was intended before relying on the type check.
        #
        # IDEs may flag the isinstance call below because of a PyCharm bug:
        # https://stackoverflow.com/questions/56493140/parameterized-generics-cannot-be-used-with-class-or-instance-checks
        type_ok = self.typ not in _PRIMITIVE_TYPE_MAP or isinstance(
            val, _PRIMITIVE_TYPE_MAP[self.typ]
        )
        if not type_ok:
            return False
        return not (self.validator and not self.validator.validate(val))

    def validate_from(self, arg):
        """Check that ``arg`` may be used as the source of this parameter.

        Returns:
            bool: False only when both parameters have a type and the types
            differ; True otherwise.

        Raises:
            ValueError: If ``arg`` is not a PipelineParameter.
        """
        if not isinstance(arg, PipelineParameter):
            raise ValueError(
                "arg is expected to be type of 'PipelineParameter' "
                "but was actually of type '%s'" % type(arg)
            )
        both_typed = arg.typ is not None and self.typ is not None
        return not (both_typed and arg.typ != self.typ)

    def to_dict(self):
        """Extend the base dictionary with ``type``, ``value`` and ``path``.

        ``value`` and ``path`` are only added when they are not None.
        """
        result = super(PipelineParameter, self).to_dict()
        result["type"] = self.typ.value
        if self.value is not None:
            result["value"] = self.value
        if self.path is not None:
            result["path"] = self.path
        # NOTE: this override does not add a ``required`` key.
        return result

    def translate_argument(self, value):
        """Return ``{"name": <parameter name>, "value": value}``."""
        return {"name": self.name, "value": value}

    def __hash__(self):
        return id(self)

    # Comparison operators build condition expressions instead of booleans.

    def __eq__(self, other):
        return ConditionExpr("==", self, other)

    def __ne__(self, other):
        return ConditionExpr("!=", self, other)

    def __lt__(self, other):
        return ConditionExpr("<", self, other)

    def __le__(self, other):
        return ConditionExpr("<=", self, other)

    def __gt__(self, other):
        return ConditionExpr(">", self, other)

    def __ge__(self, other):
        return ConditionExpr(">=", self, other)
class ParameterValidator(object):
    """Validate parameter values against an optional interval.

    Args:
        interval: Object providing ``validate(value)`` and a ``str`` form
            (an Interval), or None for "no restriction".
    """

    def __init__(self, interval=None):
        self.interval = interval

    @classmethod
    def load(cls, feasible):
        """Build a validator from a feasible specification.

        Args:
            feasible: Mapping that may contain a ``"range"`` entry, which is
                parsed with ``Interval.load``. An empty or None value yields
                a validator without restriction.

        Returns:
            ParameterValidator: The new validator.
        """
        validator = cls()
        if feasible and "range" in feasible:
            validator.interval = Interval.load(feasible["range"])
        return validator

    def validate(self, value):
        """Return True when no interval is set or the interval accepts ``value``."""
        if not self.interval:
            return True
        return bool(self.interval.validate(value))

    def to_dict(self):
        """Return the dictionary form of the validator.

        Returns:
            dict: ``{"range": "<interval>"}`` when an interval is set,
            otherwise an empty dict (instead of the meaningless
            ``{"range": "None"}``).
        """
        if not self.interval:
            return {}
        return {"range": str(self.interval)}
class Interval(object):
    """Range validator of pipeline parameter.

    The textual form is ``<left>min, max<right>`` where ``<left>`` is ``[``
    (inclusive) or ``(`` (exclusive), ``<right>`` is ``]`` or ``)``, and each
    bound is an optionally negated decimal number or ``INF``.

    Args:
        min_: Lower bound.
        max_: Upper bound.
        min_inclusive: Whether ``min_`` itself is accepted.
        max_inclusive: Whether ``max_`` itself is accepted.
    """

    _NUMBER_REGEXP = r"(-?(?:(?:[\d]+(\.[\d]*)?)|INF))"
    INTERVAL_PATTERN = re.compile(
        r"^([(\[])\s*{number_pattern}\s*,\s*{number_pattern}([)\]])$".format(
            number_pattern=_NUMBER_REGEXP
        )
    )

    def __init__(self, min_, max_, min_inclusive, max_inclusive):
        self.min_ = min_
        self.max_ = max_
        self.min_inclusive = min_inclusive
        self.max_inclusive = max_inclusive

    def __str__(self):
        return "{left}{min_}, {max_}{right}".format(
            left="[" if self.min_inclusive else "(",
            min_=self.value_str(self.min_),
            max_=self.value_str(self.max_),
            right="]" if self.max_inclusive else ")",
        )

    @staticmethod
    def value_str(val):
        """Render a bound, writing infinities as ``INF`` / ``-INF``."""
        if val == _POSITIVE_INFINITY:
            return "INF"
        if val == _NEGATIVE_INFINITY:
            return "-INF"
        return str(val)

    @staticmethod
    def _parse_bound(text, fraction):
        """Convert one matched bound to a number.

        Bounds with a fractional part and the infinities become Decimal;
        plain integers become int.
        """
        if fraction or text in ("-INF", "INF"):
            return Decimal(text)
        return int(text)

    @classmethod
    def load(cls, feasible):
        """Parse the textual form of an interval.

        Args:
            feasible (str): Text such as ``"[0, 10)"`` or ``"(-INF, 1.5]"``.

        Returns:
            Interval: The parsed interval.

        Raises:
            ValueError: If the text does not match the interval pattern or
                the bounds are rejected by ``_validate_bound``.
        """
        m = cls.INTERVAL_PATTERN.match(feasible)
        if not m:
            raise ValueError("parameter feasible %s not match pattern" % feasible)
        left, min_, min_fraction, max_, max_fraction, right = m.groups()

        interval = cls(
            cls._parse_bound(min_, min_fraction),
            cls._parse_bound(max_, max_fraction),
            left == "[",
            right == "]",
        )
        if not interval._validate_bound():
            raise ValueError(
                "invalid range: lower bound greater than upper bound is not allowed"
            )
        return interval

    def _validate_bound(self):
        """Return False for an inverted interval or an open ``(a, a)`` one."""
        if Decimal(self.min_) > Decimal(self.max_):
            return False
        # NOTE(review): equal bounds are accepted as soon as one side is
        # inclusive, so "[a, a)" passes -- confirm that this is intended.
        if self.min_ == self.max_ and not (self.min_inclusive or self.max_inclusive):
            return False
        return True

    def validate(self, val):
        """Return True when ``val`` lies inside the interval.

        A value that cannot be ordered against the bounds (for example a
        string or None) is reported as invalid instead of raising TypeError,
        so callers can rely on a boolean answer.
        """
        try:
            if self.min_ < val < self.max_:
                return True
        except TypeError:
            return False
        if self.min_ == val and self.min_inclusive:
            return True
        if self.max_ == val and self.max_inclusive:
            return True
        return False
_ParameterTypeMapping = {
"long": "Int",
"integer": "Int",
"int": "Int",
"double": "Double",
"float": "Double",
"string": "String",
"str": "String",
"bool": "Bool",
"boolean": "Bool",
"map": "Map",
"dict": "Map",
"array": "Array",
"list": "Array",
}
class ParameterType(Enum):
    """Types a pipeline parameter can have; member values are the type names."""

    String = "String"
    Integer = "Int"
    Double = "Double"
    Bool = "Bool"
    Map = "Map"
    Array = "Array"

    @classmethod
    def normalize_typ(cls, typ_instance):
        """Convert a type specification to a ParameterType member.

        Args:
            typ_instance: A ParameterType member (returned unchanged), a
                Python class (looked up by its lower-cased ``__name__``) or
                a string (looked up lower-cased) in the alias table.

        Returns:
            ParameterType: The matching member.

        Raises:
            ValueError: If ``typ_instance`` is of an unsupported kind, or
                names a type that has no entry in the alias table
                (previously the latter leaked a KeyError).
        """
        if isinstance(typ_instance, cls):
            return typ_instance

        if isinstance(typ_instance, type):
            type_name = typ_instance.__name__.lower()
        elif isinstance(typ_instance, six.string_types):
            type_name = typ_instance.lower()
        else:
            type_name = None

        mapped = None if type_name is None else _ParameterTypeMapping.get(type_name)
        if mapped is None:
            raise ValueError(
                "Not Supported PipelineParameter Type: {typ}".format(typ=typ_instance)
            )
        return cls(mapped)