in tensorflow_probability/python/math/interpolation.py [0:0]
def _interp_regular_1d_grid_impl(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis=-1,
                                 batch_y_ref=False,
                                 fill_value='constant_extension',
                                 fill_value_below=None,
                                 fill_value_above=None,
                                 grid_regularizing_transform=None,
                                 name=None):
  """1-D linear interpolation on a regular grid, with or without batching.

  `y_ref` holds function values on a grid that is regularly spaced between
  `x_ref_min` and `x_ref_max` (after applying `grid_regularizing_transform`,
  if given). Each `x` is mapped to a fractional index along `axis` of `y_ref`
  and the two neighboring reference values are linearly combined.

  Args:
    x: Points to interpolate at. Converted to a `Tensor` of the dtype common to
      `x`, `x_ref_min`, `x_ref_max` and `y_ref` (hint: `tf.float32`).
    x_ref_min: Value of `x` corresponding to index `0` along `axis`. Must be
      statically scalar when `batch_y_ref` is `False`.
    x_ref_max: Value of `x` corresponding to the last index along `axis`. Must
      be statically scalar when `batch_y_ref` is `False`.
    y_ref: Reference values; the grid runs along `axis`.
    axis: Scalar integer axis of `y_ref` holding the grid. Negative values are
      resolved against the rank of `y_ref`.
    batch_y_ref: Python `bool`. If `True`, leading dimensions of `x`,
      `x_ref_min`/`x_ref_max` and `y_ref` are treated as batch dimensions and
      gathering is done with `_batch_gather_with_broadcast`; otherwise a plain
      `tf.gather` is used.
    fill_value: What to return where `x` falls outside
      `[x_ref_min, x_ref_max]`. Either one of the strings
      `'constant_extension'` or `'extrapolate'`, or a value that is passed to
      `tf.where` as the replacement.
    fill_value_below: Same as `fill_value`, but only for `x < x_ref_min`.
      Defaults to `fill_value` whenever separate fills are needed.
    fill_value_above: Same as `fill_value`, but only for `x > x_ref_max`.
      Defaults to `fill_value` whenever separate fills are needed.
    grid_regularizing_transform: Optional callable `g`. If given, the
      fractional index is computed from `g(x)`, `g(x_ref_min)` and
      `g(x_ref_max)` instead of the raw values.
    name: Optional name scope; defaults to `'interp_regular_1d_grid_impl'`.

  Returns:
    y: `Tensor` of interpolated values. Without batching its shape is
      `y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]`. Entries are NaN
      wherever the fractional index of `x` was NaN.

  Raises:
    ValueError: If any fill value is a string other than
      `'constant_extension'` or `'extrapolate'`.
  """
  # Note: we do *not* make the no-batch version a special case of the batch
  # version, because that would be an inefficient use of batch_gather with
  # unnecessarily broadcast args.
  with tf.name_scope(name or 'interp_regular_1d_grid_impl'):

    # Arg checking.
    allowed_fv_st = ('constant_extension', 'extrapolate')
    for fv in (fill_value, fill_value_below, fill_value_above):
      if isinstance(fv, str) and fv not in allowed_fv_st:
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fv, allowed_fv_st))

    # Fill values may be non-strings (numbers, arrays, Tensors). Resolve the
    # string "mode" once, guarded by isinstance, so that `==` against a mode
    # name is never evaluated on a Tensor/array (consistent with the arg
    # check above). A mode of `None` means "use the value itself as the fill".
    fill_mode = fill_value if isinstance(fill_value, str) else None

    # Separate value fills for below/above incurs extra cost, so keep track of
    # whether this is needed.
    need_separate_fills = (
        fill_value_above is not None or fill_value_below is not None or
        fill_mode == 'extrapolate'  # always requires separate below/above
    )
    if need_separate_fills and fill_value_above is None:
      fill_value_above = fill_value
    if need_separate_fills and fill_value_below is None:
      fill_value_below = fill_value
    fill_mode_below = (
        fill_value_below if isinstance(fill_value_below, str) else None)
    fill_mode_above = (
        fill_value_above if isinstance(fill_value_above, str) else None)

    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    dtype_hint=tf.float32)
    x = tf.convert_to_tensor(x, name='x', dtype=dtype)

    x_ref_min = tf.convert_to_tensor(
        x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(
        x_ref_max, name='x_ref_max', dtype=dtype)
    if not batch_y_ref:
      _assert_ndims_statically(x_ref_min, expect_ndims=0)
      _assert_ndims_statically(x_ref_max, expect_ndims=0)

    y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

    if batch_y_ref:
      # If we're batching,
      #   x.shape ~ [A1,...,AN, D],  x_ref_min/max.shape ~ [A1,...,AN]
      # So to add together we'll append a singleton.
      # If not batching, x_ref_min/max are scalar, so this isn't an issue,
      # moreover, if not batching, x can be scalar, and expanding x_ref_min/max
      # would cause a bad expansion of x when added to x (confused yet?).
      x_ref_min = x_ref_min[..., tf.newaxis]
      x_ref_max = x_ref_max[..., tf.newaxis]

    axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
    axis = ps.non_negative_axis(axis, ps.rank(y_ref))
    _assert_ndims_statically(axis, expect_ndims=0)

    # Number of grid points, as a float so it can be used in index arithmetic.
    ny = tf.cast(tf.shape(y_ref)[axis], dtype)

    # Map [x_ref_min, x_ref_max] to [0, ny - 1].
    # This is the (fractional) index of x.
    if grid_regularizing_transform is None:
      g = lambda z: z
    else:
      g = grid_regularizing_transform
    fractional_idx = ((g(x) - g(x_ref_min)) / (g(x_ref_max) - g(x_ref_min)))
    x_idx_unclipped = fractional_idx * (ny - 1)

    # Wherever x is NaN, x_idx_unclipped will be NaN as well.
    # Keep track of the nan indices here (so we can impute NaN later).
    # Also eliminate any NaN indices, since NaN cannot be represented once the
    # indices are cast to int32 below.
    nan_idx = tf.math.is_nan(x_idx_unclipped)
    zero = tf.zeros((), dtype=dtype)
    x_idx_unclipped = tf.where(nan_idx, zero, x_idx_unclipped)
    x_idx = tf.clip_by_value(x_idx_unclipped, zero, ny - 1)

    # Get the index above and below x_idx.
    # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx),
    # however, this results in idx_below == idx_above whenever x is on a grid.
    # This in turn results in y_ref_below == y_ref_above, and then the gradient
    # at this point is zero.  So here we 'jitter' one of idx_below, idx_above,
    # so that they are at different values.  This jittering does not affect the
    # interpolated value, but does make the gradient nonzero (unless of course
    # the y_ref values are the same).
    idx_below = tf.floor(x_idx)
    idx_above = tf.minimum(idx_below + 1, ny - 1)
    idx_below = tf.maximum(idx_above - 1, 0)

    # These are the values of y_ref corresponding to above/below indices.
    idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
    idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)
    if batch_y_ref:
      # If y_ref.shape ~ [A1,...,AN, C, B1,...,BN],
      # and x.shape, x_ref_min/max.shape ~ [A1,...,AN, D]
      # Then y_ref_below.shape ~ [A1,...,AN, D, B1,...,BN]
      y_ref_below = _batch_gather_with_broadcast(y_ref, idx_below_int32, axis)
      y_ref_above = _batch_gather_with_broadcast(y_ref, idx_above_int32, axis)
    else:
      # Here, y_ref_below.shape =
      #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]
      y_ref_below = tf.gather(y_ref, idx_below_int32, axis=axis)
      y_ref_above = tf.gather(y_ref, idx_above_int32, axis=axis)

    # Use t to get a convex combination of the below/above values.
    t = x_idx - idx_below

    # x, and tensors shaped like x, need to be added to, and selected with
    # (using tf.where) the output y.  This requires appending singletons.
    # Make functions appropriate for batch/no-batch.
    if batch_y_ref:
      # In the batch case, the output shape is going to be
      #   Broadcast(y_ref.shape[:axis], x.shape[:-1]) +
      #   x.shape[-1:] + y_ref.shape[axis+1:]
      expand_x_fn = _make_expand_x_fn_for_batch_interpolation(y_ref, axis)
    else:
      # In the non-batch case, the output shape is going to be
      #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis+1:]
      expand_x_fn = _make_expand_x_fn_for_non_batch_interpolation(y_ref, axis)

    t = expand_x_fn(t)
    nan_idx = expand_x_fn(nan_idx, broadcast=True)
    x_idx_unclipped = expand_x_fn(x_idx_unclipped, broadcast=True)

    y = t * y_ref_above + (1 - t) * y_ref_below

    # Now begins a long excursion to fill values outside [x_min, x_max].

    # Re-insert NaN wherever x was NaN.
    y = tf.where(nan_idx, tf.constant(np.nan, y.dtype), y)

    if not need_separate_fills:
      # 'constant_extension' is already handled by clipping x_idx_unclipped
      # to obtain x_idx, so only a value fill needs any work here.
      if fill_mode != 'constant_extension':
        y = tf.where(
            (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1),
            fill_value, y)
    else:
      # Fill values below x_ref_min <==> x_idx_unclipped < 0.
      if fill_mode_below == 'constant_extension':
        pass  # Already handled by clipping x_idx_unclipped to obtain x_idx.
      elif fill_mode_below == 'extrapolate':
        if batch_y_ref:
          # For every batch member, gather the first two elements of y across
          # `axis`.
          y_0 = tf.gather(y_ref, [0], axis=axis)
          y_1 = tf.gather(y_ref, [1], axis=axis)
        else:
          # If not batching, we want to gather the first two elements, just like
          # above.  However, these results need to be replicated for every
          # member of x.  An easy way to do that is to gather using
          # indices = zeros/ones(x.shape).
          y_0 = tf.gather(
              y_ref, tf.zeros(tf.shape(x), dtype=tf.int32), axis=axis)
          y_1 = tf.gather(
              y_ref, tf.ones(tf.shape(x), dtype=tf.int32), axis=axis)
        # NOTE(review): the grid spacing and offset here use the raw x values,
        # not grid_regularizing_transform(x) -- confirm this is intended when
        # a transform is supplied.
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = expand_x_fn((x - x_ref_min) / x_delta, broadcast=True)
        y = tf.where(x_idx_unclipped < 0, y_0 + x_factor * (y_1 - y_0), y)
      else:
        y = tf.where(x_idx_unclipped < 0, fill_value_below, y)

      # Fill values above x_ref_max <==> x_idx_unclipped > ny - 1.
      if fill_mode_above == 'constant_extension':
        pass  # Already handled by clipping x_idx_unclipped to obtain x_idx.
      elif fill_mode_above == 'extrapolate':
        ny_int32 = tf.shape(y_ref)[axis]
        if batch_y_ref:
          # For every batch member, gather the last two elements of y across
          # `axis`.
          y_n1 = tf.gather(y_ref, [ny_int32 - 1], axis=axis)
          y_n2 = tf.gather(y_ref, [ny_int32 - 2], axis=axis)
        else:
          # Replicate the last two elements for every member of x by gathering
          # with constant index tensors shaped like x.
          y_n1 = tf.gather(
              y_ref, tf.fill(tf.shape(x), ny_int32 - 1), axis=axis)
          y_n2 = tf.gather(
              y_ref, tf.fill(tf.shape(x), ny_int32 - 2), axis=axis)
        # NOTE(review): as for the below-range case, this uses raw x spacing
        # rather than the transformed grid -- confirm.
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = expand_x_fn((x - x_ref_max) / x_delta, broadcast=True)
        y = tf.where(x_idx_unclipped > ny - 1,
                     y_n1 + x_factor * (y_n1 - y_n2), y)
      else:
        y = tf.where(x_idx_unclipped > ny - 1, fill_value_above, y)

    return y