in Synthesis_incorporation/value_search/operation_filtering.py [0:0]
def add_filters_to_function_operation(function_operation):
    """Adds filters to the FunctionOperation depending on its FilterGroup.

    Two kinds of filters are installed, both defined by how they are called
    in this function:

    * value filters (`add_value_filters`): a list of predicates, each taking
      a single argument value (presumably one predicate per positional
      argument of the operation -- confirm against FunctionOperation).
    * an apply filter (`set_apply_filter`): one predicate taking the sequence
      of all argument values, used for constraints that relate arguments to
      each other.

    Args:
      function_operation: The operation to configure. Its
        `function_info.filter_group` selects the branch below, and it is
        mutated in place through `add_value_filters` / `set_apply_filter`.

    Raises:
      ValueError: If the operation's filter group is not handled here.
    """
    group = function_operation.function_info.filter_group
    if group == filter_group.FilterGroup.NONE:
        # Do nothing.
        pass
    elif group == filter_group.FilterGroup.SHAPE_1:
        function_operation.add_value_filters([SHAPE_FILTER])
    elif group == filter_group.FilterGroup.TENSOR_1:
        function_operation.add_value_filters([TENSOR_FILTER])
    elif group == filter_group.FilterGroup.TENSORSEQUENCE_1:
        function_operation.add_value_filters([TENSOR_SEQUENCE_FILTER])
    elif group == filter_group.FilterGroup.FLOATTENSOR_1:
        function_operation.add_value_filters([FLOAT_TENSOR_FILTER])
    elif group == filter_group.FilterGroup.NUMERICTENSOR_1:
        function_operation.add_value_filters([NUMERIC_TENSOR_FILTER])
    elif group == filter_group.FilterGroup.PRIMITIVE_OR_TENSOR_1:
        function_operation.add_value_filters([PRIMITIVE_OR_TENSOR_FILTER])
    elif group == filter_group.FilterGroup.TENSOR_AXIS_2:
        function_operation.add_value_filters([TENSOR_FILTER, AXIS_FILTER])
        function_operation.set_apply_filter(TENSOR_AXIS_IN_RANGE_APPLY_FILTER)
    elif group == filter_group.FilterGroup.NUMERICTENSOR_AXIS_2:
        function_operation.add_value_filters(
            [NUMERIC_TENSOR_FILTER, AXIS_FILTER]
        )
        function_operation.set_apply_filter(TENSOR_AXIS_IN_RANGE_APPLY_FILTER)
    elif group == filter_group.FilterGroup.TENSORSEQUENCE_AXIS_2:
        function_operation.add_value_filters(
            [TENSOR_SEQUENCE_FILTER, AXIS_FILTER]
        )
    elif group == filter_group.FilterGroup.TENSOR_BOOLTENSOR_2:
        function_operation.add_value_filters(
            [TENSOR_FILTER, get_dtype_filter(torch.bool)]
        )
    elif group == filter_group.FilterGroup.SAME_SHAPES_NUMERICTENSOR_2:
        function_operation.add_value_filters([NUMERIC_TENSOR_FILTER] * 2)
        function_operation.set_apply_filter(SAME_SHAPES_APPLY_FILTER)
    elif group == filter_group.FilterGroup.SAME_DTYPE_NUMERIC_BROADCASTABLE_2:
        function_operation.add_value_filters([NUMERIC_TENSOR_FILTER] * 2)
        function_operation.set_apply_filter(
            SAME_DTYPES_BROADCASTABLE_APPLY_FILTER
        )
    elif group == filter_group.FilterGroup.ELEMENTWISE_COMPARISON_2:
        function_operation.add_value_filters(
            [NUMERIC_TENSOR_FILTER, PRIMITIVE_OR_TENSOR_FILTER]
        )
        function_operation.set_apply_filter(BROADCASTABLE_APPLY_FILTER)
    elif group == filter_group.FilterGroup.NE_BROADCASTABLE_2:
        function_operation.add_value_filters(
            [NUMERIC_TENSOR_FILTER, NONZERO_PRIMITIVE_OR_TENSOR_FILTER]
        )

        def _not_equal_broadcastable_filter(arg_values):
            """The two arguments must differ and pass the broadcastable filter."""
            arg1, arg2 = arg_values
            return arg1 != arg2 and BROADCASTABLE_APPLY_FILTER(arg_values)

        function_operation.set_apply_filter(_not_equal_broadcastable_filter)

    # Operations with other special handling.
    elif group == filter_group.FilterGroup.BINCOUNT_1:

        def _bincount_filter(arg_value):
            """Must be a 1-D int tensor of nonnegative values with a small max."""
            # Only int tensors are accepted; everything else is rejected
            # before min()/max() are evaluated.
            if not (arg_value.is_tensor and arg_value.has_int_dtype()):
                return False
            max_value = arg_value.max()
            min_value = arg_value.min()
            return (
                min_value >= 0
                and max_value <= limits.MAX_DIMENSION_LENGTH
                and len(arg_value.shape) == 1
            )

        function_operation.add_value_filters([_bincount_filter])
    elif group == filter_group.FilterGroup.TENSORIZABLE_1:

        def _tensorizable_filter(arg_value):
            """Accepts primitives and sequences whose elements are not tensors."""
            if arg_value.is_primitive:
                return True
            elif arg_value.is_sequence:
                return not arg_value.elem_type_is_tensor
            else:
                return False

        function_operation.add_value_filters([_tensorizable_filter])
    elif group == filter_group.FilterGroup.BMM_2:

        def _rank_3_tensor_filter(arg_value):
            """Must be a tensor of rank exactly 3."""
            return arg_value.is_tensor and len(arg_value.shape) == 3

        def _bmm_filter(arg_values):
            """Checks dtype compatibility and batched-matmul shapes.

            The third dimension of the first tensor must equal the second
            dimension of the second tensor, and the first dimensions of the
            two arguments must be equal.
            """
            return (
                SAME_DTYPES_APPLY_FILTER(arg_values)
                and arg_values[0].shape[2] == arg_values[1].shape[1]
                and arg_values[0].shape[0] == arg_values[1].shape[0]
            )

        function_operation.add_value_filters([_rank_3_tensor_filter] * 2)
        function_operation.set_apply_filter(_bmm_filter)
    elif group == filter_group.FilterGroup.CAT_TENSORSEQUENCE_AXIS_2:
        function_operation.add_value_filters(
            [TENSOR_SEQUENCE_FILTER, AXIS_FILTER]
        )

        def _axis_in_range(arg_values):
            """Ensures the axis is less than the rank of the first tensor."""
            tensor, axis = arg_values
            return axis.value < len(tensor.value[0].shape)

        function_operation.set_apply_filter(_axis_in_range)
    elif group == filter_group.FilterGroup.CDIST_2:

        def _cdist_filter(arg_value):
            """Must be a float tensor of rank at least 2."""
            return (
                arg_value.is_tensor
                and arg_value.has_float_dtype()
                and len(arg_value.shape) > 1
            )

        function_operation.add_value_filters([_cdist_filter] * 2)
        function_operation.set_apply_filter(SAME_SHAPES_APPLY_FILTER)
    elif group == filter_group.FilterGroup.EYE_1:
        function_operation.add_value_filters([SQUARE_MATRIX_SIZE_FILTER])
    elif group == filter_group.FilterGroup.RANGE_1:
        function_operation.add_value_filters([VECTOR_LENGTH_FILTER])
    elif group == filter_group.FilterGroup.EXPAND_DIMS_2:
        function_operation.add_value_filters([TENSOR_FILTER, AXIS_FILTER])

        def _axis_in_range(arg_values):
            """Ensures the axis is strictly less than the rank of the tensor."""
            tensor, axis = arg_values
            return axis.value < len(tensor.shape)

        function_operation.set_apply_filter(_axis_in_range)
    elif group == filter_group.FilterGroup.EXPAND_DIMS_ADDITIONAL_2:
        function_operation.add_value_filters([TENSOR_FILTER, AXIS_FILTER])

        def _axis_in_range(arg_values):
            """Ensures the axis is at most the rank of the tensor."""
            tensor, axis = arg_values
            # Unlike EXPAND_DIMS_2, axis == rank is allowed here.
            return axis.value <= len(tensor.shape)

        function_operation.set_apply_filter(_axis_in_range)
    elif group == filter_group.FilterGroup.EYE_ROWS_COLS_2:

        def _eye_rows_cols_apply_filter(arg_values):
            """Checks that the result will have a small number of elements."""
            num_rows, num_cols = arg_values
            return (
                int(num_rows.value) * int(num_cols.value)
                <= limits.MAX_TENSOR_ELEMENTS
            )

        function_operation.add_value_filters([VECTOR_LENGTH_FILTER] * 2)
        function_operation.set_apply_filter(_eye_rows_cols_apply_filter)
    elif group == filter_group.FilterGroup.MATMUL_2:

        def _min_rank_2_tensor_filter(arg_value):
            """Must be a tensor of rank >= 2."""
            return arg_value.is_tensor and len(arg_value.shape) >= 2

        function_operation.add_value_filters([_min_rank_2_tensor_filter] * 2)
        function_operation.set_apply_filter(SAME_DTYPES_APPLY_FILTER)
    elif group == filter_group.FilterGroup.MM_2:

        def _rank_2_tensor_filter(arg_value):
            """Must be a tensor of rank exactly 2."""
            return arg_value.is_tensor and len(arg_value.shape) == 2

        def _mm_filter(arg_values):
            """Checks dtype compatibility and matrix-multiply shapes.

            The second dimension of the first tensor must equal the first
            dimension of the second tensor.
            """
            return (
                SAME_DTYPES_APPLY_FILTER(arg_values)
                and arg_values[0].shape[1] == arg_values[1].shape[0]
            )

        function_operation.add_value_filters([_rank_2_tensor_filter] * 2)
        function_operation.set_apply_filter(_mm_filter)
    elif group == filter_group.FilterGroup.NORMALIZE_2:

        def _float_tensor_filter(arg_value):
            """Must be a tensor with a float dtype."""
            return arg_value.is_tensor and arg_value.has_float_dtype()

        function_operation.add_value_filters(
            [_float_tensor_filter, AXIS_FILTER]
        )

        def _axis_in_range(arg_values):
            """Ensures the axis is strictly less than the rank of the tensor."""
            tensor, axis = arg_values
            return axis.value < len(tensor.shape)

        function_operation.set_apply_filter(_axis_in_range)
    elif group == filter_group.FilterGroup.ONE_HOT_2:

        def _one_hot_indices_filter(arg_value):
            """Must be a nonnegative int64 tensor with rank below the max rank."""
            return (
                arg_value.is_tensor
                and arg_value.dtype == torch.int64
                and arg_value.min() >= 0
                and len(arg_value.shape) < limits.MAX_NUM_DIMENSIONS
            )

        def _one_hot_apply_filter(arg_values):
            """Checks the result size is small and indices fit in num_classes."""
            indices, num_classes = arg_values
            return (
                indices.num_elements() * int(num_classes.value)
                <= limits.MAX_TENSOR_ELEMENTS
                and indices.max() < num_classes.value
            )

        function_operation.add_value_filters(
            [_one_hot_indices_filter, INT_LENGTH_FILTER]
        )
        function_operation.set_apply_filter(_one_hot_apply_filter)
    elif group == filter_group.FilterGroup.PAD_2:
        function_operation.add_value_filters([TENSOR_FILTER, PADDINGS_FILTER])

        def _pad_2_apply_filter(arg_values):
            """Tensor must be non-scalar and paddings must not exceed its rank.

            Half of the paddings' leading length is compared against the
            rank (presumably because paddings are (before, after) pairs per
            dimension -- confirm against PADDINGS_FILTER).
            """
            tensor, paddings = arg_values
            paddings_shape = paddings.sequence_shape
            return (
                bool(tensor.shape)
                and paddings_shape[0] / 2 <= len(tensor.shape)
            )

        function_operation.set_apply_filter(_pad_2_apply_filter)
    elif group == filter_group.FilterGroup.RESHAPE_2:

        def _reshape_filter(arg_values):
            """The new size must be compatible with the original size.

            The tensor's element count must be a multiple of the product of
            the requested shape, and that product must not be 1.
            """
            tensor, shape = arg_values
            num_tensor_elements = torch.prod(torch.tensor(tensor.value.shape))
            num_shape_elements = torch.prod(torch.tensor(shape.value))
            return (
                num_tensor_elements % num_shape_elements == 0
                and num_shape_elements != 1
            )

        function_operation.add_value_filters([TENSOR_FILTER, SHAPE_FILTER])
        function_operation.set_apply_filter(_reshape_filter)
    elif group == filter_group.FilterGroup.SEARCHSORTED_2:

        def _sorted_last_dimension(arg_value):
            """Must be a numeric tensor that is sorted in the last dimension."""
            return (
                NONSCALAR_NUMERIC_TENSOR_FILTER(arg_value)
                and (
                    arg_value.has_float_dtype()
                    or arg_value.dtype in [torch.int32, torch.int64]
                )
                # torch.sort sorts along the last dimension by default, so
                # equality with the sorted copy means "already sorted".
                and bool(
                    torch.all(
                        torch.eq(
                            arg_value.value, torch.sort(arg_value.value)[0]
                        )
                    )
                )
            )

        function_operation.add_value_filters(
            [_sorted_last_dimension, NUMERIC_PRIMITIVE_OR_TENSOR_FILTER]
        )

        def _searchsorted_apply_filter(arg_values):
            """DTypes must match, dimension lengths equal except the last."""
            sorted_sequence, values = arg_values
            return (
                sorted_sequence.dtype == values.dtype
                and len(sorted_sequence.shape) == len(values.shape)
                and sorted_sequence.shape[:-1] == values.shape[:-1]
            )

        function_operation.set_apply_filter(_searchsorted_apply_filter)
    elif group == filter_group.FilterGroup.TILE_2:

        def _tile_apply_filter(arg_values):
            """Checks that the result will have a small number of elements.

            Also requires every multiple to be positive and at least one
            multiple to exceed 1.
            """
            tensor, multiples = arg_values
            return (
                multiples.min() > 0
                and multiples.max() > 1
                and multiples.reduce_prod() * tensor.num_elements()
                <= limits.MAX_TENSOR_ELEMENTS
            )

        function_operation.add_value_filters(
            [TENSOR_FILTER, AXIS_SEQUENCE_FILTER]
        )
        function_operation.set_apply_filter(_tile_apply_filter)
    elif group == filter_group.FilterGroup.SQUEEZE_2:

        def _very_squeezable_filter(arg_value):
            """Keeps tensors with more than 1 squeezable dimension."""
            # If a tensor only has 1 squeezable dimension, then this operation
            # is useless because it is simpler to use the one-arg version of
            # squeeze.
            return (
                TENSOR_FILTER(arg_value)
                and (arg_value.shape or []).count(1) >= 2
            )

        function_operation.add_value_filters(
            [_very_squeezable_filter, AXIS_FILTER]
        )

        def _squeeze_2_apply_filter(arg_values):
            """The axis must be in range and refer to a dimension of size 1."""
            tensor, axis = arg_values
            return (
                axis.value < len(tensor.shape)
                and tensor.shape[axis.value] == 1
            )

        function_operation.set_apply_filter(_squeeze_2_apply_filter)
    elif group == filter_group.FilterGroup.GATHER_3:
        function_operation.add_value_filters(
            [
                NON_SCALAR_TENSOR_FILTER,
                BATCH_DIMS_FILTER,
                GATHER_INDICES_FILTER,
            ]
        )

        def _gather_3_apply_filter(arg_values):
            """Checks index/shape compatibility and bounds the result size."""
            params, batch_dims, indices = arg_values
            batch_dims_int = batch_dims.value
            indices_shape = (
                indices.shape if indices.is_tensor else indices.sequence_shape
            )
            return (
                indices.is_tensor
                and batch_dims_int
                < min(len(indices_shape), len(params.shape))
                and params.shape[:batch_dims_int]
                == indices_shape[:batch_dims_int]
                and bool(indices_shape)
                # It is also required that index.size(d) <= input.size(d) for
                # all dimensions d != dim.
                and all(
                    indices_shape[d] <= params.shape[d]
                    or d == batch_dims_int
                    for d in range(
                        min(len(params.shape), len(indices_shape))
                    )
                )
                and indices.max() < params.shape[batch_dims_int]
                and
                # Upper bound on resulting tensor size.
                indices.num_elements() * params.num_elements()
                <= limits.MAX_TENSOR_ELEMENTS
            )

        function_operation.set_apply_filter(_gather_3_apply_filter)
    elif group == filter_group.FilterGroup.INDEX_SELECT_3:
        function_operation.add_value_filters(
            [
                NON_SCALAR_TENSOR_FILTER,
                BATCH_DIMS_FILTER,
                INDICES_FILTER,
            ]
        )

        def _index_select_3_apply_filter(arg_values):
            """Checks dim and indices are in range and bounds the result size."""
            params, dim, indices = arg_values
            dim_int = dim.value
            indices_shape = indices.shape
            return (
                dim_int < len(params.shape)
                and bool(indices_shape)
                # Indices select along `dim`, so they must be bounded by the
                # size of that dimension (not by the largest dimension).
                and indices.max() < params.shape[dim_int]
                and
                # Upper bound on resulting tensor size.
                indices.num_elements() * params.num_elements()
                <= limits.MAX_TENSOR_ELEMENTS
            )

        function_operation.set_apply_filter(_index_select_3_apply_filter)
    elif group == filter_group.FilterGroup.RANGE_3:

        def _range_3_apply_filter(arg_values):
            """Checks that the range is nonempty and has few elements."""
            start, limit, delta = arg_values
            return (
                # range() raises on a zero step, so rule it out first.
                delta.value != 0
                and 0
                < len(range(start.value, limit.value, delta.value))
                <= limits.MAX_DIMENSION_LENGTH
            )

        function_operation.add_value_filters([get_type_filter(int)] * 3)
        function_operation.set_apply_filter(_range_3_apply_filter)
    elif group == filter_group.FilterGroup.REPEAT_3:

        def _repeat_filter(arg_value):
            """Must pass the int-or-int-tensor filter and be strictly positive."""
            return INT_OR_INT_TENSOR_FILTER(arg_value) and arg_value.min() > 0

        def _repeat_3_apply_filter(arg_values):
            """Checks the first two arguments are broadcastable
            and the third argument passes the tensor-axis range filter."""
            return BROADCASTABLE_APPLY_FILTER(
                [arg_values[0], arg_values[1]]
            ) and TENSOR_AXIS_IN_RANGE_APPLY_FILTER(
                [arg_values[0], arg_values[2]]
            )

        function_operation.add_value_filters(
            [NUMERIC_TENSOR_FILTER, _repeat_filter, AXIS_FILTER]
        )
        function_operation.set_apply_filter(_repeat_3_apply_filter)
    elif group == filter_group.FilterGroup.ROLL_3:
        # The case where the shift and axis are both single integers.
        function_operation.add_value_filters(
            [TENSOR_FILTER, INT_OR_INT_TENSOR_FILTER, AXIS_FILTER]
        )
        # The case where the shift and axis are both sequences of integers.
        function_operation.add_value_filters(
            [TENSOR_FILTER, INTS_SEQUENCE_FILTER, AXIS_SEQUENCE_FILTER]
        )

        def _roll_apply_filter(arg_values):
            """Axes must be in range; sequence shifts and axes must pair up."""
            tensor, shift, axis = arg_values
            if axis.type is int:
                return axis.value < len(tensor.shape)
            else:
                return len(axis.value) == len(
                    shift.value
                ) and axis.max() < len(tensor.shape)

        function_operation.set_apply_filter(_roll_apply_filter)
    elif group == filter_group.FilterGroup.TENSORDOT_3:

        def _tensordot_arg_3_filter(arg_value):
            """The argument "axes" must have axis-like ints and the right shape."""
            if arg_value.type is int:
                # An int N means "sum over the last N axes of a and the first
                # N axes of b in order", so 0 <= N <= maximum rank.
                return 0 <= arg_value.value <= limits.MAX_NUM_DIMENSIONS
            if arg_value.elem_type is int:
                # List of length 2 is ok, elements must be valid axes.
                return (
                    len(arg_value.value) == 2
                    and 0 <= arg_value.min()
                    and arg_value.max() < limits.MAX_NUM_DIMENSIONS
                )
            # Otherwise, must be an int tensor of shape [2] or [2, k].
            return (
                arg_value.is_tensor
                and arg_value.has_int_dtype()
                and 1 <= len(arg_value.shape) <= 2
                and arg_value.shape[0] == 2
                and 0 <= arg_value.min()
                and arg_value.max() < limits.MAX_NUM_DIMENSIONS
            )

        function_operation.add_value_filters(
            [
                NONSCALAR_NUMERIC_TENSOR_FILTER,
                NONSCALAR_NUMERIC_TENSOR_FILTER,
                _tensordot_arg_3_filter,
            ]
        )

        def _tensordot_apply_filter(arg_value):
            """First two tensors must have same dtype, and axes must be in range."""
            a, b, axes = arg_value
            if (
                a.dtype != b.dtype
                or
                # This check is overly conservative for the sake of
                # efficiency; the resulting number of elements is most likely
                # smaller but will take effort to compute more precisely.
                a.num_elements() * b.num_elements()
                > limits.MAX_TENSOR_ELEMENTS
            ):
                return False
            a_rank = len(a.shape)
            b_rank = len(b.shape)
            min_rank = min(a_rank, b_rank)
            if axes.type is int:
                return axes.value <= min_rank
            elif axes.is_sequence or len(axes.shape) == 1:
                # axes is a list or tensor of shape [2].
                return axes.max() < min_rank
            else:  # axes is a tensor of shape [2, k].
                return (
                    axes.shape[1] <= min_rank
                    and tf_coder_utils.max_tensor_value(axes.value[0])
                    < a_rank
                    and tf_coder_utils.max_tensor_value(axes.value[1])
                    < b_rank
                )

        function_operation.set_apply_filter(_tensordot_apply_filter)
    elif group == filter_group.FilterGroup.TRANSPOSE_3:

        def _transpose_3_apply_filter(arg_values):
            """Both dims must be in range, with dim0 strictly less than dim1."""
            tensor, dim0, dim1 = arg_values
            return (
                dim0.value < len(tensor.shape)
                and dim1.value < len(tensor.shape)
                and dim0.value < dim1.value
            )

        function_operation.add_value_filters(
            [TENSOR_FILTER, BATCH_DIMS_FILTER, BATCH_DIMS_FILTER]
        )
        function_operation.set_apply_filter(_transpose_3_apply_filter)
    elif group == filter_group.FilterGroup.WHERE_TENSOR_3:

        def _where_apply_filter(arg_values):
            """Checks x/y type agreement and broadcasting with the condition.

            x and y must pass the same-types filter, each must be
            broadcastable with the condition, and x must differ from y.
            """
            condition, x, y = arg_values
            return (
                TENSOR_PRIMITIVE_SAME_TYPES_APPLY_FILTER([x, y])
                and broadcastable(condition.shape, x.shape)
                and broadcastable(condition.shape, y.shape)
                and x != y
            )

        function_operation.add_value_filters(
            [
                get_dtype_filter(torch.bool),
                NUMERIC_TENSOR_FILTER,
                NUMERIC_PRIMITIVE_OR_TENSOR_FILTER,
            ]
        )
        function_operation.set_apply_filter(_where_apply_filter)
    elif group == filter_group.FilterGroup.WHERE_NUMERIC_3:

        def _where_apply_filter(arg_values):
            """Checks x/y type agreement and broadcasting with the condition.

            x and y must pass the same-types filter, each must be
            broadcastable with the condition, and x must differ from y.
            """
            condition, x, y = arg_values
            return (
                TENSOR_PRIMITIVE_SAME_TYPES_APPLY_FILTER([x, y])
                and broadcastable(condition.shape, x.shape)
                and broadcastable(condition.shape, y.shape)
                and x != y
            )

        function_operation.add_value_filters(
            [
                get_dtype_filter(torch.bool),
                NUMERIC_PRIMITIVE_FILTER,
                NUMERIC_PRIMITIVE_OR_TENSOR_FILTER,
            ]
        )
        function_operation.set_apply_filter(_where_apply_filter)
    else:
        raise ValueError(
            "Unknown filter group {} for FunctionOperation {}.".format(
                group, function_operation.name
            )
        )