def is_castable()

in Synthesis_incorporation/value_search/operation_filtering.py [0:0]


def is_castable(to_cast, dtype):
    """Returns whether `to_cast` (a Value) can be safely casted to the dtype.

    This filtering strategy is a workaround for undefined behavior in TensorFlow
    (b/119633897) when casting problematic floats to int dtypes.

    Args:
      to_cast: A Value object that would be casted.
      dtype: A Value containing a torch.dtype that `to_cast` would be casted to.

    Returns:
      True if the cast is considered safe, False otherwise. Casting to a
      non-int dtype is always considered safe. Casting float data to an int
      dtype is safe only if the data contains no NaN, passes the finiteness
      check (for tensors), and lies strictly within
      (min_int - 1, max_int + 1) for the target dtype.
    """
    if not dtype.is_int_dtype():
        return True  # We can always cast to a non-int dtype.

    to_cast_value = to_cast.value
    if to_cast.is_sparse_tensor:
        # NOTE(review): for torch sparse COO tensors `values` is a method
        # (`tensor.values()`), not an attribute -- confirm what `Value.value`
        # holds for sparse inputs.
        to_cast_value = to_cast.value.values

    if to_cast.is_tensor or to_cast.is_sparse_tensor:
        if not to_cast.has_float_dtype():
            return True  # Only float -> int is potentially unsafe.
        if not _check_tensor_finite(to_cast_value):
            return False  # Non-finite floats cannot be casted to int dtypes.
    elif to_cast.is_sequence:
        if to_cast.elem_type is float:
            # `float("nan") in seq` never matches: NaN compares unequal to every
            # NaN object other than itself, so check each element explicitly.
            if any(isinstance(x, float) and math.isnan(x) for x in to_cast_value):
                return False  # inf and -inf will be caught by the min/max logic.
        elif to_cast.elem_type_is_tensor:
            # Check each element against the *target* dtype. Empty tensors are
            # rejected (presumably because min()/max() are undefined for them).
            # `numel()` is used because `size()` is a torch.Size tuple whose
            # truthiness does not reflect emptiness (0-dim scalars give ()).
            return all(
                element.numel() > 0
                and is_castable(value_module.InputValue(element, "dummy"), dtype)
                for element in to_cast_value
            )
        elif to_cast.elem_type_is_sparse_tensor:
            return all(
                element.values.numel() > 0
                and is_castable(value_module.InputValue(element, "dummy"), dtype)
                for element in to_cast_value
            )
        else:
            return True  # Only lists of floats or float tensors can be unsafe.
    elif to_cast.type is float:
        if math.isnan(to_cast_value):
            return False
    else:
        return True

    min_int, max_int = tf_coder_utils.INT_DTYPE_MIN_MAX[dtype.value]

    # Floats are truncated when casted to int (nearest int in the zero direction).
    # Assuming min_int <= 0, the minimum safe float is (min_int - 1 + epsilon),
    # and the maximum safe float is (max_int + 1 - epsilon).
    return to_cast.min() > min_int - 1 and to_cast.max() < max_int + 1