def _interp_regular_1d_grid_impl()

in tensorflow_probability/python/math/interpolation.py [0:0]


def _interp_regular_1d_grid_impl(x,
                                 x_ref_min,
                                 x_ref_max,
                                 y_ref,
                                 axis=-1,
                                 batch_y_ref=False,
                                 fill_value='constant_extension',
                                 fill_value_below=None,
                                 fill_value_above=None,
                                 grid_regularizing_transform=None,
                                 name=None):
  """1-D interpolation that works with/without batching."""
  # Note: we do *not* make the no-batch version a special case of the batch
  # version, because that would be an inefficient use of batch_gather with
  # unnecessarily broadcast args.
  with tf.name_scope(name or 'interp_regular_1d_grid_impl'):

    # Arg checking.
    allowed_fv_st = ('constant_extension', 'extrapolate')
    for fv in (fill_value, fill_value_below, fill_value_above):
      if isinstance(fv, str) and fv not in allowed_fv_st:
        raise ValueError(
            'A fill value ({}) was not an allowed string ({})'.format(
                fv, allowed_fv_st))

    # Separate value fills for below/above incur extra cost, so keep track of
    # whether this is needed.
    need_separate_fills = (
        fill_value_above is not None or fill_value_below is not None or
        fill_value == 'extrapolate'  # always requires separate below/above
    )
    if need_separate_fills and fill_value_above is None:
      fill_value_above = fill_value
    if need_separate_fills and fill_value_below is None:
      fill_value_below = fill_value
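    # For example (illustration only, not part of the library code):
    #   fill_value='extrapolate' with fill_value_below/above=None resolves to
    #     fill_value_below='extrapolate' and fill_value_above='extrapolate';
    #   fill_value=0. with fill_value_above=1. resolves to
    #     fill_value_below=0. and fill_value_above=1.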

    dtype = dtype_util.common_dtype([x, x_ref_min, x_ref_max, y_ref],
                                    dtype_hint=tf.float32)
    x = tf.convert_to_tensor(x, name='x', dtype=dtype)

    x_ref_min = tf.convert_to_tensor(
        x_ref_min, name='x_ref_min', dtype=dtype)
    x_ref_max = tf.convert_to_tensor(
        x_ref_max, name='x_ref_max', dtype=dtype)
    if not batch_y_ref:
      _assert_ndims_statically(x_ref_min, expect_ndims=0)
      _assert_ndims_statically(x_ref_max, expect_ndims=0)

    y_ref = tf.convert_to_tensor(y_ref, name='y_ref', dtype=dtype)

    if batch_y_ref:
      # If we're batching,
      #   x.shape ~ [A1,...,AN, D],  x_ref_min/max.shape ~ [A1,...,AN],
      # so to combine them below we append a singleton dim to x_ref_min/max.
      # If not batching, x_ref_min/max are scalars, so this isn't an issue;
      # moreover, x itself may then be a scalar, and expanding x_ref_min/max
      # would wrongly introduce an extra dimension when combined with x.
      x_ref_min = x_ref_min[..., tf.newaxis]
      x_ref_max = x_ref_max[..., tf.newaxis]
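      # E.g. (shapes only, illustration): if x.shape = [2, 3, 50] and
      # x_ref_min.shape = [2, 3], then after this expansion
      # x_ref_min.shape = [2, 3, 1], which broadcasts against x in
      # `g(x) - g(x_ref_min)` below.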

    axis = ps.convert_to_shape_tensor(axis, name='axis', dtype=tf.int32)
    axis = ps.non_negative_axis(axis, ps.rank(y_ref))
    _assert_ndims_statically(axis, expect_ndims=0)

    ny = tf.cast(tf.shape(y_ref)[axis], dtype)

    # Map [x_ref_min, x_ref_max] to [0, ny - 1].
    # This is the (fractional) index of x.
    if grid_regularizing_transform is None:
      g = lambda x: x
    else:
      g = grid_regularizing_transform
    fractional_idx = ((g(x) - g(x_ref_min)) / (g(x_ref_max) - g(x_ref_min)))
    x_idx_unclipped = fractional_idx * (ny - 1)
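    # Worked example (illustration only): with x_ref_min = 0., x_ref_max = 4.
    # and ny = 5 (grid points 0., 1., 2., 3., 4.), a query x = 1.5 gives
    #   fractional_idx = (1.5 - 0.) / (4. - 0.) = 0.375
    #   x_idx_unclipped = 0.375 * (5 - 1) = 1.5,
    # i.e. halfway between grid indices 1 and 2.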

    # Wherever x is NaN, x_idx_unclipped will be NaN as well.
    # Keep track of the NaN indices here (so we can impute NaN later).
    # Also replace the NaN indices with zero, since NaN cannot be cast to an
    # int32 index below.
    nan_idx = tf.math.is_nan(x_idx_unclipped)
    zero = tf.zeros((), dtype=dtype)
    x_idx_unclipped = tf.where(nan_idx, zero, x_idx_unclipped)
    x_idx = tf.clip_by_value(x_idx_unclipped, zero, ny - 1)

    # Get the index above and below x_idx.
    # Naively we could set idx_below = floor(x_idx), idx_above = ceil(x_idx);
    # however, this results in idx_below == idx_above whenever x lies exactly
    # on a grid point.  This in turn results in y_ref_below == y_ref_above,
    # and then the gradient at this point is zero.  So here we 'jitter' one of
    # idx_below, idx_above so that they take different values.  This jittering
    # does not affect the interpolated value, but it does make the gradient
    # nonzero (unless of course the y_ref values are the same).
    idx_below = tf.floor(x_idx)
    idx_above = tf.minimum(idx_below + 1, ny - 1)
    idx_below = tf.maximum(idx_above - 1, 0)
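    # E.g. (illustration only), with ny = 5:
    #   x_idx = 2.0 (on the grid) gives idx_below = 2, idx_above = 3, so below
    #   t = x_idx - idx_below = 0 and the interpolant is exactly y_ref at
    #   index 2, yet the gradient w.r.t. t, y_ref_above - y_ref_below, is
    #   generally nonzero.
    #   At the right endpoint x_idx = 4.0 we instead get idx_above = 4,
    #   idx_below = 3 and t = 1.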

    # These are the values of y_ref corresponding to above/below indices.
    idx_below_int32 = tf.cast(idx_below, dtype=tf.int32)
    idx_above_int32 = tf.cast(idx_above, dtype=tf.int32)
    if batch_y_ref:
      # If y_ref.shape ~ [A1,...,AN, C, B1,...,BN],
      # and x.shape, x_ref_min/max.shape ~ [A1,...,AN, D]
      # Then y_ref_below.shape ~ [A1,...,AN, D, B1,...,BN]
      y_ref_below = _batch_gather_with_broadcast(y_ref, idx_below_int32, axis)
      y_ref_above = _batch_gather_with_broadcast(y_ref, idx_above_int32, axis)
    else:
      # Here, y_ref_below.shape =
      #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis + 1:]
      y_ref_below = tf.gather(y_ref, idx_below_int32, axis=axis)
      y_ref_above = tf.gather(y_ref, idx_above_int32, axis=axis)
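    # E.g. (shapes only, illustration, following the contracts noted above):
    # non-batch: y_ref.shape = [5, 100, 7], axis = 1, x.shape = [20]
    #   ==> y_ref_below.shape = [5, 20, 7];
    # batch: y_ref.shape = [2, 3, 100, 7], axis = 2, x.shape = [2, 3, 50]
    #   ==> y_ref_below.shape = [2, 3, 50, 7].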

    # Use t to get a convex combination of the below/above values.
    t = x_idx - idx_below

    # x, and tensors shaped like x, need to be combined with, and selected
    # against (via tf.where), the output y.  This requires appending singleton
    # dims.  Make functions appropriate for batch/no-batch.
    if batch_y_ref:
      # In the batch case, the output shape is going to be
      #   Broadcast(y_ref.shape[:axis], x.shape[:-1]) +
      #   x.shape[-1:] + y_ref.shape[axis+1:]
      expand_x_fn = _make_expand_x_fn_for_batch_interpolation(y_ref, axis)
    else:
      # In the non-batch case, the output shape is going to be
      #   y_ref.shape[:axis] + x.shape + y_ref.shape[axis+1:]
      expand_x_fn = _make_expand_x_fn_for_non_batch_interpolation(y_ref, axis)

    t = expand_x_fn(t)
    nan_idx = expand_x_fn(nan_idx, broadcast=True)
    x_idx_unclipped = expand_x_fn(x_idx_unclipped, broadcast=True)

    y = t * y_ref_above + (1 - t) * y_ref_below
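    # E.g. (illustration only), continuing the hypothetical non-batch shapes
    # above: t has shape [20] and is expanded with singleton dims so that it
    # broadcasts against y_ref_above of shape [5, 20, 7], giving
    # y.shape = [5, 20, 7].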

    # Now begins a long excursion to fill values outside [x_min, x_max].

    # Re-insert NaN wherever x was NaN.
    y = tf.where(nan_idx, tf.constant(np.nan, y.dtype), y)

    if not need_separate_fills:
      if fill_value == 'constant_extension':
        pass  # Already handled by clipping x_idx_unclipped.
      else:
        y = tf.where(
            (x_idx_unclipped < 0) | (x_idx_unclipped > ny - 1),
            fill_value, y)
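        # E.g. (illustration only): fill_value = -1. replaces every entry of y
        # whose x fell outside [x_ref_min, x_ref_max] with -1.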
    else:
      # Fill values below x_ref_min <==> x_idx_unclipped < 0.
      if fill_value_below == 'constant_extension':
        pass  # Already handled by the clipping that created x_idx_unclipped.
      elif fill_value_below == 'extrapolate':
        if batch_y_ref:
          # For every batch member, gather the first two elements of y_ref
          # along `axis`.
          y_0 = tf.gather(y_ref, [0], axis=axis)
          y_1 = tf.gather(y_ref, [1], axis=axis)
        else:
          # If not batching, we want to gather the first two elements, just like
          # above.  However, these results need to be replicated for every
          # member of x.  An easy way to do that is to gather using
          # indices = zeros/ones(x.shape).
          y_0 = tf.gather(
              y_ref, tf.zeros(tf.shape(x), dtype=tf.int32), axis=axis)
          y_1 = tf.gather(
              y_ref, tf.ones(tf.shape(x), dtype=tf.int32), axis=axis)
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = expand_x_fn((x - x_ref_min) / x_delta, broadcast=True)
        y = tf.where(x_idx_unclipped < 0, y_0 + x_factor * (y_1 - y_0), y)
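        # Worked example (illustration only): with x_ref_min = 0.,
        # x_ref_max = 4. and ny = 5, x_delta = 1., so a query at x = -1. has
        # x_factor = -1. and extrapolates to y_0 + (-1.) * (y_1 - y_0)
        # = 2. * y_0 - y_1, i.e. the line through the first two grid values
        # evaluated at x = -1.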
      else:
        y = tf.where(x_idx_unclipped < 0, fill_value_below, y)
      # Fill values above x_ref_max <==> x_idx_unclipped > ny - 1.
      if fill_value_above == 'constant_extension':
        pass  # Already handled by the clipping that created x_idx_unclipped.
      elif fill_value_above == 'extrapolate':
        ny_int32 = tf.shape(y_ref)[axis]
        if batch_y_ref:
          y_n1 = tf.gather(y_ref, [ny_int32 - 1], axis=axis)
          y_n2 = tf.gather(y_ref, [ny_int32 - 2], axis=axis)
        else:
          y_n1 = tf.gather(
              y_ref, tf.fill(tf.shape(x), ny_int32 - 1), axis=axis)
          y_n2 = tf.gather(
              y_ref, tf.fill(tf.shape(x), ny_int32 - 2), axis=axis)
        x_delta = (x_ref_max - x_ref_min) / (ny - 1)
        x_factor = expand_x_fn((x - x_ref_max) / x_delta, broadcast=True)
        y = tf.where(x_idx_unclipped > ny - 1,
                     y_n1 + x_factor * (y_n1 - y_n2), y)
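        # Worked example (illustration only): with x_ref_min = 0.,
        # x_ref_max = 4. and ny = 5, x_delta = 1., so a query at x = 5. has
        # x_factor = 1. and extrapolates to y_n1 + 1. * (y_n1 - y_n2)
        # = 2. * y_n1 - y_n2, i.e. the line through the last two grid values
        # evaluated at x = 5.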
      else:
        y = tf.where(x_idx_unclipped > ny - 1, fill_value_above, y)

    return y
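
# Usage sketch (hypothetical, for illustration only; in practice this private
# helper is normally reached through the public
# `tfp.math.interp_regular_1d_grid` wrapper):
#
#   x_ref = tf.linspace(0., 2., num=100)
#   y_ref = tf.exp(x_ref)
#   y = _interp_regular_1d_grid_impl(
#       x=[0.5, 1.5], x_ref_min=0., x_ref_max=2., y_ref=y_ref)
#   # y approximates tf.exp([0.5, 1.5]) via linear interpolation on the grid.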