in Synthesis_incorporation/value_search/operation_filtering.py [0:0]
def is_castable(to_cast, dtype):
    """Returns whether `to_cast` (a Value) can be safely casted to the dtype.

    This filtering strategy is a workaround for undefined behavior when casting
    non-finite or out-of-range floats to integer dtypes (originally tracked for
    TensorFlow as b/119633897). Casts to non-int dtypes are always accepted;
    only float -> int conversions are ever rejected.

    Args:
      to_cast: A Value object that would be casted.
      dtype: A Value containing a torch.dtype that `to_cast` would be casted to.

    Returns:
      True if the cast is considered safe, False otherwise.
    """
    if not dtype.is_int_dtype():
        return True  # We can always cast to a non-int dtype.

    to_cast_value = to_cast.value
    if to_cast.is_sparse_tensor:
        # NOTE(review): assumes `.values` is an attribute holding the nonzero
        # values (tf.SparseTensor style); on a torch sparse tensor it is the
        # method `.values()` -- confirm the representation used here.
        to_cast_value = to_cast.value.values

    if to_cast.is_tensor or to_cast.is_sparse_tensor:
        if not to_cast.has_float_dtype():
            return True  # Only float -> int is potentially unsafe.
        if not _check_tensor_finite(to_cast_value):
            return False  # Non-finite floats cannot be casted to int dtypes.
    elif to_cast.is_sequence:
        if to_cast.elem_type is float:
            # `float("nan") in seq` can never match: a fresh NaN object is not
            # identical to any element and NaN compares unequal to everything,
            # so each element must be tested explicitly.
            if any(math.isnan(x) for x in to_cast_value):
                return False  # inf and -inf will be caught by the min/max logic.
        elif to_cast.elem_type_is_tensor:
            # Every element must be non-empty (an empty tensor has no min/max)
            # and individually castable to the same target dtype.
            return all(
                element.numel() > 0
                and is_castable(value_module.InputValue(element, "dummy"), dtype)
                for element in to_cast_value
            )
        elif to_cast.elem_type_is_sparse_tensor:
            return all(
                element.values.numel() > 0
                and is_castable(value_module.InputValue(element, "dummy"), dtype)
                for element in to_cast_value
            )
        else:
            return True  # Only lists of floats or float tensors can be unsafe.
    elif to_cast.type is float:
        if math.isnan(to_cast_value):
            return False
    else:
        return True

    min_int, max_int = tf_coder_utils.INT_DTYPE_MIN_MAX[dtype.value]
    # Floats are truncated when casted to int (nearest int in the zero direction).
    # Assuming min_int <= 0, the minimum safe float is (min_int - 1 + epsilon),
    # and the maximum safe float is (max_int + 1 - epsilon).
    return to_cast.min() > min_int - 1 and to_cast.max() < max_int + 1