def add_filters_to_function_operation()

in Synthesis_incorporation/value_search/operation_filtering.py [0:0]


def add_filters_to_function_operation(function_operation):
    """Adds filters to the FunctionOperation depending on its FilterGroup.

    Two kinds of filters are attached:
      * value filters (`add_value_filters`): a list with one single-argument
        predicate per operation argument, each taking one argument value;
      * an apply filter (`set_apply_filter`): a predicate taking the tuple of
        all argument values, used for checks that relate arguments to each
        other (matching dtypes, axis within rank, result-size limits, ...).

    Some groups call `add_value_filters` more than once (see ROLL_3), which
    appears to register alternative per-argument filter lists.
    NOTE(review): the exact semantics live in FunctionOperation — confirm.

    Args:
      function_operation: The FunctionOperation to configure. Its
        `function_info.filter_group` selects the filters. The operation is
        modified in place.

    Raises:
      ValueError: If the operation's filter group is not handled here.
    """
    group = function_operation.function_info.filter_group

    if group == filter_group.FilterGroup.NONE:
        # Do nothing.
        pass

    elif group == filter_group.FilterGroup.SHAPE_1:
        function_operation.add_value_filters([SHAPE_FILTER])
    elif group == filter_group.FilterGroup.TENSOR_1:
        function_operation.add_value_filters([TENSOR_FILTER])
    elif group == filter_group.FilterGroup.TENSORSEQUENCE_1:
        function_operation.add_value_filters([TENSOR_SEQUENCE_FILTER])
    elif group == filter_group.FilterGroup.FLOATTENSOR_1:
        function_operation.add_value_filters([FLOAT_TENSOR_FILTER])
    elif group == filter_group.FilterGroup.NUMERICTENSOR_1:
        function_operation.add_value_filters([NUMERIC_TENSOR_FILTER])
    elif group == filter_group.FilterGroup.PRIMITIVE_OR_TENSOR_1:
        function_operation.add_value_filters([PRIMITIVE_OR_TENSOR_FILTER])

    elif group == filter_group.FilterGroup.TENSOR_AXIS_2:
        function_operation.add_value_filters([TENSOR_FILTER, AXIS_FILTER])
        function_operation.set_apply_filter(TENSOR_AXIS_IN_RANGE_APPLY_FILTER)
    elif group == filter_group.FilterGroup.NUMERICTENSOR_AXIS_2:
        function_operation.add_value_filters([NUMERIC_TENSOR_FILTER, AXIS_FILTER])
        function_operation.set_apply_filter(TENSOR_AXIS_IN_RANGE_APPLY_FILTER)
    elif group == filter_group.FilterGroup.TENSORSEQUENCE_AXIS_2:
        function_operation.add_value_filters([TENSOR_SEQUENCE_FILTER, AXIS_FILTER])
    elif group == filter_group.FilterGroup.TENSOR_BOOLTENSOR_2:
        function_operation.add_value_filters(
            [TENSOR_FILTER, get_dtype_filter(torch.bool)]
        )
    elif group == filter_group.FilterGroup.SAME_SHAPES_NUMERICTENSOR_2:
        function_operation.add_value_filters([NUMERIC_TENSOR_FILTER] * 2)
        function_operation.set_apply_filter(SAME_SHAPES_APPLY_FILTER)
    elif group == filter_group.FilterGroup.SAME_DTYPE_NUMERIC_BROADCASTABLE_2:
        function_operation.add_value_filters([NUMERIC_TENSOR_FILTER] * 2)
        function_operation.set_apply_filter(SAME_DTYPES_BROADCASTABLE_APPLY_FILTER)
    elif group == filter_group.FilterGroup.ELEMENTWISE_COMPARISON_2:
        function_operation.add_value_filters(
            [NUMERIC_TENSOR_FILTER, PRIMITIVE_OR_TENSOR_FILTER]
        )
        function_operation.set_apply_filter(BROADCASTABLE_APPLY_FILTER)
    elif group == filter_group.FilterGroup.NE_BROADCASTABLE_2:
        function_operation.add_value_filters(
            [NUMERIC_TENSOR_FILTER, NONZERO_PRIMITIVE_OR_TENSOR_FILTER]
        )

        def _not_equal_broadcastable_filter(arg_values):
            """The two arguments must differ and be broadcastable together."""
            arg1, arg2 = arg_values
            return arg1 != arg2 and BROADCASTABLE_APPLY_FILTER(arg_values)

        function_operation.set_apply_filter(_not_equal_broadcastable_filter)

    # Operations with other special handling.

    elif group == filter_group.FilterGroup.BINCOUNT_1:

        def _bincount_filter(arg_value):
            """Keeps 1-D int tensors with nonnegative values and a small maximum."""
            # Only int tensors are accepted (not int lists or int primitives).
            # The cheap rank check runs before the min/max reductions.
            return (
                arg_value.is_tensor
                and arg_value.has_int_dtype()
                and len(arg_value.shape) == 1
                and arg_value.min() >= 0
                and arg_value.max() <= limits.MAX_DIMENSION_LENGTH
            )

        function_operation.add_value_filters([_bincount_filter])

    elif group == filter_group.FilterGroup.TENSORIZABLE_1:

        def _tensorizable_filter(arg_value):
            """Keeps primitives and sequences whose elements are not tensors."""
            if arg_value.is_primitive:
                return True
            if arg_value.is_sequence:
                return not arg_value.elem_type_is_tensor
            return False

        function_operation.add_value_filters([_tensorizable_filter])

    elif group == filter_group.FilterGroup.BMM_2:

        def _rank_3_filter(arg_value):
            """Must be a tensor of rank exactly 3 (dtype is checked jointly)."""
            return arg_value.is_tensor and len(arg_value.shape) == 3

        def _bmm_filter(arg_values):
            """Checks dtypes match and the batched shapes are compatible.

            The first dimensions (batch) must be equal, and the third dimension
            of the first tensor must equal the second dimension of the second.
            """
            first, second = arg_values[0], arg_values[1]
            return (
                SAME_DTYPES_APPLY_FILTER(arg_values)
                and first.shape[2] == second.shape[1]
                and first.shape[0] == second.shape[0]
            )

        function_operation.add_value_filters([_rank_3_filter] * 2)
        function_operation.set_apply_filter(_bmm_filter)

    elif group == filter_group.FilterGroup.CAT_TENSORSEQUENCE_AXIS_2:
        function_operation.add_value_filters([TENSOR_SEQUENCE_FILTER, AXIS_FILTER])

        def _cat_axis_in_range(arg_values):
            """Ensures the axis is less than the rank of the first tensor."""
            tensors, axis = arg_values
            # Guard against an empty sequence before indexing its first tensor.
            return (
                len(tensors.value) > 0
                and axis.value < len(tensors.value[0].shape)
            )

        function_operation.set_apply_filter(_cat_axis_in_range)

    elif group == filter_group.FilterGroup.CDIST_2:

        def _cdist_filter(arg_value):
            """Must be a float tensor of rank > 1."""
            return (
                arg_value.is_tensor
                and arg_value.has_float_dtype()
                and len(arg_value.shape) > 1
            )

        function_operation.add_value_filters([_cdist_filter] * 2)
        function_operation.set_apply_filter(SAME_SHAPES_APPLY_FILTER)

    elif group == filter_group.FilterGroup.EYE_1:
        function_operation.add_value_filters([SQUARE_MATRIX_SIZE_FILTER])

    elif group == filter_group.FilterGroup.RANGE_1:
        function_operation.add_value_filters([VECTOR_LENGTH_FILTER])

    elif group == filter_group.FilterGroup.EXPAND_DIMS_2:
        function_operation.add_value_filters([TENSOR_FILTER, AXIS_FILTER])

        def _axis_in_range(arg_values):
            """Ensures the axis is less than the rank of the tensor."""
            tensor, axis = arg_values
            return axis.value < len(tensor.shape)

        function_operation.set_apply_filter(_axis_in_range)

    elif group == filter_group.FilterGroup.EXPAND_DIMS_ADDITIONAL_2:
        function_operation.add_value_filters([TENSOR_FILTER, AXIS_FILTER])

        def _axis_in_range_inclusive(arg_values):
            """Ensures the axis is at most the rank of the tensor."""
            tensor, axis = arg_values
            return axis.value <= len(tensor.shape)

        function_operation.set_apply_filter(_axis_in_range_inclusive)

    elif group == filter_group.FilterGroup.EYE_ROWS_COLS_2:

        def _eye_rows_cols_apply_filter(arg_values):
            """Checks that the result will have a small number of elements."""
            num_rows, num_cols = arg_values
            return (
                int(num_rows.value) * int(num_cols.value) <= limits.MAX_TENSOR_ELEMENTS
            )

        function_operation.add_value_filters([VECTOR_LENGTH_FILTER] * 2)
        function_operation.set_apply_filter(_eye_rows_cols_apply_filter)

    elif group == filter_group.FilterGroup.MATMUL_2:

        def _min_rank_2_filter(arg_value):
            """Must be a tensor of rank >= 2 (dtype is checked jointly)."""
            return arg_value.is_tensor and len(arg_value.shape) >= 2

        function_operation.add_value_filters([_min_rank_2_filter] * 2)
        function_operation.set_apply_filter(SAME_DTYPES_APPLY_FILTER)

    elif group == filter_group.FilterGroup.MM_2:

        def _rank_2_filter(arg_value):
            """Must be a tensor of rank exactly 2 (dtype is checked jointly)."""
            return arg_value.is_tensor and len(arg_value.shape) == 2

        def _mm_filter(arg_values):
            """Checks dtypes match and the inner matrix dimensions agree.

            The second dimension of the first tensor must equal the first
            dimension of the second tensor.
            """
            return (
                SAME_DTYPES_APPLY_FILTER(arg_values)
                and arg_values[0].shape[1] == arg_values[1].shape[0]
            )

        function_operation.add_value_filters([_rank_2_filter] * 2)
        function_operation.set_apply_filter(_mm_filter)

    elif group == filter_group.FilterGroup.NORMALIZE_2:

        def _float_tensor_filter(arg_value):
            """Must be a tensor with a float dtype."""
            return arg_value.is_tensor and arg_value.has_float_dtype()

        function_operation.add_value_filters([_float_tensor_filter, AXIS_FILTER])

        def _normalize_axis_in_range(arg_values):
            """Ensures the axis is less than the rank of the tensor."""
            tensor, axis = arg_values
            return axis.value < len(tensor.shape)

        function_operation.set_apply_filter(_normalize_axis_in_range)

    elif group == filter_group.FilterGroup.ONE_HOT_2:

        def _one_hot_indices_filter(arg_value):
            """Must be a nonnegative int64 tensor of modest rank."""
            return (
                arg_value.is_tensor
                and arg_value.dtype == torch.int64
                and arg_value.min() >= 0
                and len(arg_value.shape) < limits.MAX_NUM_DIMENSIONS
            )

        def _one_hot_apply_filter(arg_values):
            """Checks the result is small and every index is a valid class."""
            indices, num_classes = arg_values
            return (
                indices.num_elements() * int(num_classes.value) <= limits.MAX_TENSOR_ELEMENTS
                and indices.max() < num_classes.value
            )

        function_operation.add_value_filters(
            [_one_hot_indices_filter, INT_LENGTH_FILTER]
        )
        function_operation.set_apply_filter(_one_hot_apply_filter)

    elif group == filter_group.FilterGroup.PAD_2:
        function_operation.add_value_filters([TENSOR_FILTER, PADDINGS_FILTER])

        def _pad_2_apply_filter(arg_values):
            """Checks the flat padding sequence fits the tensor.

            For `torch.nn.functional.pad` the number m of padding values must be
            even, with m / 2 <= rank of the input (and the input non-scalar).
            """
            tensor, paddings = arg_values
            num_paddings = paddings.sequence_shape[0]
            return (
                bool(tensor.shape)
                and num_paddings % 2 == 0
                and num_paddings // 2 <= len(tensor.shape)
            )

        function_operation.set_apply_filter(_pad_2_apply_filter)

    elif group == filter_group.FilterGroup.RESHAPE_2:

        def _product(dims):
            """Integer product of `dims` (1 for an empty iterable)."""
            result = 1
            for dim in dims:
                result *= int(dim)
            return result

        def _reshape_filter(arg_values):
            """The new shape must be compatible with the tensor's element count.

            At most one entry of the new shape may be -1 (inferred) and no entry
            may be below -1. Without a -1, the entries must multiply to exactly
            the number of elements; with a -1, the remaining entries must be
            nonzero in product and divide the number of elements. New shapes
            whose entries multiply to 1 (e.g. [] or [1]) are rejected.
            """
            tensor, shape = arg_values
            # as_tensor(...).reshape(-1) accepts an int, a list, or a tensor.
            new_dims = [int(dim) for dim in torch.as_tensor(shape.value).reshape(-1).tolist()]
            if _product(new_dims) == 1:
                return False
            if new_dims.count(-1) > 1 or any(dim < -1 for dim in new_dims):
                return False
            num_elements = _product(tensor.value.shape)
            known_product = _product(dim for dim in new_dims if dim != -1)
            if -1 in new_dims:
                return known_product != 0 and num_elements % known_product == 0
            return known_product == num_elements

        function_operation.add_value_filters([TENSOR_FILTER, SHAPE_FILTER])
        function_operation.set_apply_filter(_reshape_filter)

    elif group == filter_group.FilterGroup.SEARCHSORTED_2:

        def _sorted_last_dimension(arg_value):
            """Must be a numeric tensor that is sorted in the last dimension."""
            return (
                NONSCALAR_NUMERIC_TENSOR_FILTER(arg_value)
                and (
                    arg_value.has_float_dtype()
                    or arg_value.dtype in [torch.int32, torch.int64]
                )
                # torch.sort sorts along the last dimension by default.
                and bool(
                    torch.all(torch.eq(arg_value.value, torch.sort(arg_value.value)[0]))
                )
            )

        function_operation.add_value_filters(
            [_sorted_last_dimension, NUMERIC_PRIMITIVE_OR_TENSOR_FILTER]
        )

        def _searchsorted_apply_filter(arg_values):
            """DTypes must match, dimension lengths equal except the last."""
            sorted_sequence, values = arg_values
            return (
                sorted_sequence.dtype == values.dtype
                and len(sorted_sequence.shape) == len(values.shape)
                and sorted_sequence.shape[:-1] == values.shape[:-1]
            )

        function_operation.set_apply_filter(_searchsorted_apply_filter)

    elif group == filter_group.FilterGroup.TILE_2:

        def _tile_apply_filter(arg_values):
            """Checks multiples are positive, nontrivial, and the result small."""
            tensor, multiples = arg_values
            return (
                multiples.min() > 0
                and multiples.max() > 1
                and multiples.reduce_prod() * tensor.num_elements()
                <= limits.MAX_TENSOR_ELEMENTS
            )

        function_operation.add_value_filters([TENSOR_FILTER, AXIS_SEQUENCE_FILTER])
        function_operation.set_apply_filter(_tile_apply_filter)

    elif group == filter_group.FilterGroup.SQUEEZE_2:

        def _very_squeezable_filter(arg_value):
            """Keeps tensors with more than 1 squeezable dimension."""
            # If a tensor only has 1 squeezable dimension, then this operation is
            # useless because it is simpler to use the one-arg version of squeeze.
            return TENSOR_FILTER(arg_value) and (arg_value.shape or []).count(1) >= 2

        function_operation.add_value_filters([_very_squeezable_filter, AXIS_FILTER])

        def _squeeze_2_apply_filter(arg_values):
            """The axis must be in range and refer to a size-1 dimension."""
            tensor, axis = arg_values
            return axis.value < len(tensor.shape) and tensor.shape[axis.value] == 1

        function_operation.set_apply_filter(_squeeze_2_apply_filter)

    elif group == filter_group.FilterGroup.GATHER_3:
        function_operation.add_value_filters(
            [
                NON_SCALAR_TENSOR_FILTER,
                BATCH_DIMS_FILTER,
                GATHER_INDICES_FILTER,
            ]
        )

        def _gather_3_apply_filter(arg_values):
            """Checks the (input, dim, index) shape requirements of a gather."""
            params, batch_dims, indices = arg_values
            if not indices.is_tensor:
                return False
            dim = batch_dims.value
            params_shape = params.shape
            indices_shape = indices.shape
            return (
                bool(indices_shape)
                # input and index must have the same number of dimensions.
                and len(indices_shape) == len(params_shape)
                and dim < len(params_shape)
                and params_shape[:dim] == indices_shape[:dim]
                # It is also required that index.size(d) <= input.size(d) for
                # all dimensions d != dim.
                and all(
                    d == dim or indices_shape[d] <= params_shape[d]
                    for d in range(len(params_shape))
                )
                and indices.max() < params_shape[dim]
                # Upper bound on resulting tensor size.
                and indices.num_elements() * params.num_elements()
                <= limits.MAX_TENSOR_ELEMENTS
            )

        function_operation.set_apply_filter(_gather_3_apply_filter)

    elif group == filter_group.FilterGroup.INDEX_SELECT_3:
        function_operation.add_value_filters(
            [
                NON_SCALAR_TENSOR_FILTER,
                BATCH_DIMS_FILTER,
                INDICES_FILTER,
            ]
        )

        def _index_select_3_apply_filter(arg_values):
            """Checks the dim is valid and every index fits along that dim."""
            params, dim, indices = arg_values
            dim_int = dim.value
            rank = len(params.shape)
            return (
                -rank <= dim_int < rank
                and bool(indices.shape)
                # Indices select along `dim`, so they must be < params.shape[dim]
                # (not merely smaller than the largest dimension).
                and indices.max() < params.shape[dim_int]
                # Upper bound on resulting tensor size.
                and indices.num_elements() * params.num_elements()
                <= limits.MAX_TENSOR_ELEMENTS
            )

        function_operation.set_apply_filter(_index_select_3_apply_filter)

    elif group == filter_group.FilterGroup.RANGE_3:

        def _range_3_apply_filter(arg_values):
            """Checks that the range will end up having a small number of elements."""
            start, limit, delta = arg_values
            return (
                delta.value != 0
                and 0
                < len(range(start.value, limit.value, delta.value))
                <= limits.MAX_DIMENSION_LENGTH
            )

        function_operation.add_value_filters([get_type_filter(int)] * 3)
        function_operation.set_apply_filter(_range_3_apply_filter)

    elif group == filter_group.FilterGroup.REPEAT_3:

        def _repeat_filter(arg_value):
            """Must be an int or int tensor with strictly positive values."""
            return INT_OR_INT_TENSOR_FILTER(arg_value) and arg_value.min() > 0

        def _repeat_3_apply_filter(arg_values):
            """Checks the first two arguments are broadcastable
            and the third argument is in range for the tensor's rank."""
            return (
                BROADCASTABLE_APPLY_FILTER([arg_values[0], arg_values[1]])
                and TENSOR_AXIS_IN_RANGE_APPLY_FILTER([arg_values[0], arg_values[2]])
            )

        function_operation.add_value_filters(
            [NUMERIC_TENSOR_FILTER, _repeat_filter, AXIS_FILTER]
        )
        function_operation.set_apply_filter(_repeat_3_apply_filter)

    elif group == filter_group.FilterGroup.ROLL_3:
        # The case where the shift and axis are both single integers.
        function_operation.add_value_filters(
            [TENSOR_FILTER, INT_OR_INT_TENSOR_FILTER, AXIS_FILTER]
        )
        # The case where the shift and axis are both sequences of integers.
        function_operation.add_value_filters(
            [TENSOR_FILTER, INTS_SEQUENCE_FILTER, AXIS_SEQUENCE_FILTER]
        )

        def _roll_apply_filter(arg_values):
            """Axes must be in range; axis and shift sequences must match in length."""
            tensor, shift, axis = arg_values
            if axis.type is int:
                return axis.value < len(tensor.shape)
            return len(axis.value) == len(shift.value) and axis.max() < len(
                tensor.shape
            )

        function_operation.set_apply_filter(_roll_apply_filter)

    elif group == filter_group.FilterGroup.TENSORDOT_3:

        def _tensordot_arg_3_filter(arg_value):
            """The argument "axes" must have axis-like ints and the right shape."""
            if arg_value.type is int:
                # An int N means "sum over the last N axes of a and the first N axes of
                # b in order", so 0 <= N <= maximum rank.
                return 0 <= arg_value.value <= limits.MAX_NUM_DIMENSIONS
            if arg_value.elem_type is int:
                # List of length 2 is ok, elements must be valid axes.
                return (
                    len(arg_value.value) == 2
                    and 0 <= arg_value.min()
                    and arg_value.max() < limits.MAX_NUM_DIMENSIONS
                )
            # Otherwise, must be an int tensor of shape [2] or [2, k].
            return (
                arg_value.is_tensor
                and arg_value.has_int_dtype()
                and 1 <= len(arg_value.shape) <= 2
                and arg_value.shape[0] == 2
                and 0 <= arg_value.min()
                and arg_value.max() < limits.MAX_NUM_DIMENSIONS
            )

        function_operation.add_value_filters(
            [
                NONSCALAR_NUMERIC_TENSOR_FILTER,
                NONSCALAR_NUMERIC_TENSOR_FILTER,
                _tensordot_arg_3_filter,
            ]
        )

        def _tensordot_apply_filter(arg_value):
            """First two tensors must have same dtype, and axes must be in range."""
            a, b, axes = arg_value
            if (
                a.dtype != b.dtype
                or
                # This check is overly conservative for the sake of efficiency; the
                # resulting number of elements is most likely smaller but will take
                # effort to compute more precisely.
                a.num_elements() * b.num_elements() > limits.MAX_TENSOR_ELEMENTS
            ):
                return False
            a_rank = len(a.shape)
            b_rank = len(b.shape)
            min_rank = min(a_rank, b_rank)
            if axes.type is int:
                return axes.value <= min_rank
            elif axes.is_sequence or len(axes.shape) == 1:
                # axes is a list or tensor of shape [2].
                return axes.max() < min_rank
            else:  # axes is a tensor of shape [2, k].
                return (
                    axes.shape[1] <= min_rank
                    and tf_coder_utils.max_tensor_value(axes.value[0]) < a_rank
                    and tf_coder_utils.max_tensor_value(axes.value[1]) < b_rank
                )

        function_operation.set_apply_filter(_tensordot_apply_filter)

    elif group == filter_group.FilterGroup.TRANSPOSE_3:

        def _transpose_3_apply_filter(arg_values):
            """Both dims must be in range and given in increasing order."""
            # Requiring dim0 < dim1 avoids generating the equivalent swapped call.
            tensor, dim0, dim1 = arg_values
            return (
                dim0.value < len(tensor.shape)
                and dim1.value < len(tensor.shape)
                and dim0.value < dim1.value
            )

        function_operation.add_value_filters(
            [TENSOR_FILTER, BATCH_DIMS_FILTER, BATCH_DIMS_FILTER]
        )
        function_operation.set_apply_filter(_transpose_3_apply_filter)

    elif group in (
        filter_group.FilterGroup.WHERE_TENSOR_3,
        filter_group.FilterGroup.WHERE_NUMERIC_3,
    ):
        # Both groups share the apply filter; they differ only in whether the
        # second argument is a numeric tensor or a numeric primitive.
        if group == filter_group.FilterGroup.WHERE_TENSOR_3:
            x_filter = NUMERIC_TENSOR_FILTER
        else:
            x_filter = NUMERIC_PRIMITIVE_FILTER

        def _where_apply_filter(arg_values):
            """x and y must have the same types, broadcast with the condition,
            and differ from each other."""
            condition, x, y = arg_values
            return (
                TENSOR_PRIMITIVE_SAME_TYPES_APPLY_FILTER([x, y])
                and broadcastable(condition.shape, x.shape)
                and broadcastable(condition.shape, y.shape)
                and x != y
            )

        function_operation.add_value_filters(
            [
                get_dtype_filter(torch.bool),
                x_filter,
                NUMERIC_PRIMITIVE_OR_TENSOR_FILTER,
            ]
        )
        function_operation.set_apply_filter(_where_apply_filter)

    else:
        raise ValueError(
            "Unknown filter group {} for FunctionOperation {}.".format(
                group, function_operation.name
            )
        )