def make_convolution_transpose_fn_with_subkernels()

in tensorflow_probability/python/experimental/nn/util/convolution_util.py


def make_convolution_transpose_fn_with_subkernels(
    filter_shape, strides, padding, rank=2, dilations=None, dtype=tf.int32,
    validate_args=False, name=None):
  """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`."""
  with tf.name_scope(name or 'make_convolution_transpose_fn_with_subkernels'):

    # Enable v2 control flow to avoid None gradients through TensorArray.
    tf.compat.v1.enable_control_flow_v2()

    if tf.get_static_value(rank) != 2:
      raise NotImplementedError('Argument `rank` currently only supports `2`; '
                                'saw "{}".'.format(rank))
    [
        filter_shape,
        rank,
        strides,
        padding,
        dilations,
    ] = prepare_conv_args(
        filter_shape, rank=rank, strides=strides, padding=padding,
        dilations=dilations, is_transpose=True, validate_args=validate_args)

    sh, sw = strides
    fh, fw = filter_shape
    dh, dw = dilations

    # Determine maximum filter height and filter width of sub-kernels.
    sub_fh = (fh - 1) // sh + 1
    sub_fw = (fw - 1) // sw + 1
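    # `(n - 1) // s + 1` is ceil(n / s): each subkernel keeps every `sh`-th
    # row and every `sw`-th column of the full kernel, so it spans at most
    # `ceil(fh / sh) x ceil(fw / sw)` taps.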

    def loop_body(i_, kernels_ind):
      i = i_ // sw
      j = i_ % sw

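      # Flattened (row-major over the `fh x fw` grid) positions of the taps
      # that belong to the subkernel with row phase `i` and column phase `j`.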
      i_ind = ps.range(
          i * fw, ps.maximum(i, fh) * fw, delta=sh * fw, dtype=dtype)
      j_ind = ps.range(j, ps.maximum(j, fw), delta=sw, dtype=dtype)

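      # Output phase this subkernel produces, i.e. the slot at which it must
      # be stored so that `batch_to_space` later interleaves the sub-outputs
      # in the correct spatial order.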
      last_j = sw - (fw - j - 1) % sw - 1
      last_i = sh - (fh - i - 1) % sh - 1
      pos = last_i * sw + last_j

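      # The Cartesian sum yields the subkernel's flattened tap indices;
      # reversing both axes flips the subkernel, since transpose convolution
      # correlates with the spatially flipped kernel.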
      nc = cartesian_add([i_ind, j_ind])
      kernels_ind = kernels_ind.write(
          pos, ps.reverse(ps.reverse(nc, [0]), [1]))
      return i_ + 1, kernels_ind

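    # One entry per subkernel: the flattened kernel-tap indices for each of
    # the `sh * sw` output phases.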
    kernels_ind = tf.TensorArray(dtype=dtype, infer_shape=False, size=sh * sw)

    _, kernels_ind = tf.while_loop(
        lambda i, _: i < sh * sw,
        loop_body,
        [0, kernels_ind])

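    # Total padding required on each side of the (implicitly dilated) input
    # for the transpose convolution.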
    tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding(
        fh, stride=sh, dilation=dh, padding=padding)
    tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding(
        fw, stride=sw, dilation=dw, padding=padding)

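    # Per-subkernel padding is the total padding divided by the stride,
    # rounded up; `truncate_top` / `truncate_left` record the overshoot to
    # crop after `batch_to_space`.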
    pad_bottom = (tot_pad_bottom - 1) // sh + 1
    pad_top = (tot_pad_top - 1) // sh + 1
    pad_right = (tot_pad_right - 1) // sw + 1
    pad_left = (tot_pad_left - 1) // sw + 1
    padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right))

    truncate_top = pad_top * sh - tot_pad_top
    truncate_left = pad_left * sw - tot_pad_left

    def op(x, kernel):
      input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
      x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
      kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

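      # The trailing [height, width, channels] dims are the event; all
      # leading dims are treated as batch.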
      batch_shape, event_shape = ps.split(
          ps.shape(x), num_or_size_splits=[-1, 3])
      xh, xw, c_in = ps.unstack(event_shape, num=3)

      kernel_shape = ps.shape(kernel)
      c_out = kernel_shape[-1]
      kernel_batch = kernel_shape[:-2]
      assertions = _maybe_validate_input_shapes(
          kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw,
          validate_args=validate_args)

      with tf.control_dependencies(assertions):
        # If the kernel does not have batch shape, fall back to
        # `conv2d_transpose` (unless dilations > 1, which is not implemented in
        # `conv2d_transpose`).
        if (tf.get_static_value(ps.rank(kernel)) == 2
            and all(d == 1 for d in dilations)):
          return _call_conv2d_transpose(
              x, kernel, filter_shape, strides, padding, dilations, c_out,
              batch_shape, event_shape)

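        # Pad only the two spatial dims: grow `padding_vals` with `[0, 0]`
        # rows, one per leading batch dim and one for channels.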
        n = ps.maximum(0, ps.rank(x) - 3)
        paddings = ps.pad(
            padding_vals,
            paddings=[[n, 1], [0, 0]],
            constant_values=0)
        x_pad = tf.pad(x, paddings=paddings, constant_values=0)

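        # Stride-1 output extent for the largest possible subkernel; each
        # actual subkernel reads an `eh x ew` slice of the padded input
        # (computed in `loop_body` below).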
        ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1
        ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1

        def loop_body(i, outputs):
          subkernel_ind = kernels_ind.read(i)
          fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
          eh = ex_h + fh_ - 1
          ew = ex_w + fw_ - 1

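          # Expand tap indices to per-channel indices into the flattened
          # `fh * fw * c_in` kernel axis.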
          subkernel_ind = ps.reshape(
              ps.reshape(subkernel_ind * c_in, shape=[-1])[:, tf.newaxis]
              + ps.range(c_in), shape=[-1])

          k = tf.gather(kernel, subkernel_ind, axis=-2)
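          # `im2row_index` produces gather indices that lay each sliding
          # window of `x` out as a row, turning the convolution into a
          # single matmul per subkernel.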
          ind, shape = im2row_index(
              [eh, ew, c_in],
              block_shape=(fh_, fw_),
              slice_step=(1, 1),
              dilations=dilations)
          x_i = x_pad[..., :eh, :ew, :]
          x_i_shape = ps.shape(x_i)
          flat_shape = ps.pad(
              x_i_shape[:-3], paddings=[[0, 1]], constant_values=-1)
          flat_x = tf.reshape(x_i, flat_shape)
          x_ = tf.gather(flat_x, ind, axis=-1)
          im_x = tf.reshape(x_, ps.concat([x_i_shape[:-3], shape], axis=0))
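          # Patches-as-rows times the flattened subkernel: the last axis of
          # `im_x` holds one window's `fh_ * fw_ * c_in` values.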
          outputs = outputs.write(
              i,
              tf.matmul(
                  im_x,
                  tf.reshape(
                      k, ps.concat(
                          [kernel_batch, [1, fh_ * fw_ * c_in, c_out]], axis=0)))
              )
          return i + 1, outputs

        outputs = tf.TensorArray(dtype=input_dtype, size=sh * sw)

        _, outputs = tf.while_loop(
            lambda i, _: i < sh * sw,
            loop_body,
            [0, outputs])

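        # Concatenate the `sh * sw` sub-outputs along a leading block axis
        # and let `batch_to_space` interleave them into the upsampled grid.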
        y = outputs.concat()

        m = tf.reduce_prod(ps.shape(y)[:-3])
        y_ = tf.reshape(y, shape=ps.concat([[m], ps.shape(y)[-3:]], axis=0))
        y2 = tf.batch_to_space(
            y_, strides, crops=tf.zeros([2, 2], dtype=tf.int64))
        broadcast_batch_shape = ps.broadcast_shape(batch_shape, kernel_batch)
        y2 = tf.reshape(y2, ps.concat(
            [broadcast_batch_shape, ps.shape(y2)[-3:]], axis=0))

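        # Standard transpose-convolution output size; crop away the padding
        # overshoot recorded in `truncate_top` / `truncate_left`.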
        out_height = _deconv_output_length(
            xh, filter_size=fh, padding=padding, output_padding=None,
            stride=sh, dilation=dh)
        out_width = _deconv_output_length(
            xw, filter_size=fw, padding=padding, output_padding=None,
            stride=sw, dilation=dw)

        return y2[..., truncate_top:truncate_top+out_height,
                  truncate_left:truncate_left+out_width, :]
    return op
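
Example usage (a minimal sketch; the shapes follow the expectations above, and
the `tfp.experimental.nn.util` export path is an assumption):

  import tensorflow as tf
  import tensorflow_probability as tfp

  # Build the op: 3x3 kernel, stride 2, 'SAME' padding.
  # NOTE: export path is assumed; the function lives in
  # tensorflow_probability/python/experimental/nn/util/convolution_util.py.
  conv_transpose = (
      tfp.experimental.nn.util.make_convolution_transpose_fn_with_subkernels(
          filter_shape=(3, 3), strides=(2, 2), padding='SAME'))

  # Batch of five 8x8 inputs with 2 channels, and a matching batch of five
  # kernels, each flattened to [fh * fw * c_in, c_out] = [18, 4].
  x = tf.random.normal([5, 8, 8, 2])
  kernel = tf.random.normal([5, 3 * 3 * 2, 4])
  y = conv_transpose(x, kernel)  # => shape [5, 16, 16, 4]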