def make_convolution_transpose_fn_with_subkernels_matrix()

in tensorflow_probability/python/experimental/nn/util/convolution_util.py


def make_convolution_transpose_fn_with_subkernels_matrix(
    filter_shape, strides, padding, rank=2, dilations=None, dtype=tf.int32,
    validate_args=False, name=None):
  """Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`."""
  with tf.name_scope(
      name or 'make_convolution_transpose_fn_with_subkernels_matrix'):

    if tf.get_static_value(rank) != 2:
      raise NotImplementedError('Argument `rank` currently only supports `2`; '
                                'saw "{}".'.format(rank))

    strides = tf.get_static_value(strides)
    if not isinstance(strides, int):
      raise ValueError('Argument `strides` must be a statically known integer.'
                       ' Saw: {}'.format(strides))

    [
        filter_shape,
        rank,
        _,
        padding,
        dilations,
    ] = prepare_conv_args(
        filter_shape, rank=rank, strides=strides, padding=padding,
        dilations=dilations, is_transpose=True, validate_args=validate_args)

    fh, fw = filter_shape
    dh, dw = dilations

    # Determine maximum filter height and filter width of sub-kernels.
    sub_fh = (fh - 1) // strides + 1
    sub_fw = (fw - 1) // strides + 1

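    # A stride-`s` transposed convolution is computed here as `s**2` smaller
    # sub-convolutions whose outputs are interleaved by `depth_to_space`.
    # `loop_body` builds, one output phase `(i, j)` at a time, the index map
    # from each of the `fh * fw` kernel taps to its destination in the
    # scattered sub-kernel matrix.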
    def loop_body(i_, event_ind):
      i = i_ // strides
      j = i_ % strides

      i_ind = ps.range(
          i * fw, ps.maximum(i, fh) * fw, delta=strides * fw, dtype=dtype)
      j_ind = ps.range(j, ps.maximum(j, fw), delta=strides, dtype=dtype)

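      # `nc`: flat (row-major) positions of this phase's taps in the full
      # `fh x fw` kernel; the reversal below implements the kernel flip of
      # transposed convolution.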
      nc = cartesian_add([i_ind, j_ind])
      ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0])

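      # `k`: destination rows of these taps within the scattered sub-kernel
      # matrix (one row per tap and, later, per input channel).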
      k = ps.reshape(
          cartesian_add(
              [ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype),
               ps.range(ps.shape(nc)[1], dtype=dtype)]),
          shape=[-1])
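      # `last_i`/`last_j` select which of the `strides**2` sub-kernels (i.e.
      # which output-pixel phase) these taps belong to.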
      last_j = strides - (fw - j - 1) % strides - 1
      last_i = strides - (fh - i - 1) % strides - 1
      kernel_ind = ps.stack(
          [k, ps.ones_like(k) * last_i * strides + last_j], axis=1)
      event_ind = ps.tensor_scatter_nd_update(
          event_ind, ind[..., tf.newaxis], kernel_ind)

      return i_ + 1, event_ind

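    # Run all `strides**2` phases; afterwards `event_ind[t]` holds the
    # (matrix row, sub-kernel column) destination of kernel tap `t`.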
    event_ind = ps.zeros((fh * fw, 2), dtype=dtype)
    _, event_ind = tf.while_loop(
        lambda i, _: i < strides ** 2,
        loop_body,
        [tf.zeros([], dtype=dtype), event_ind])

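    # Convert the total transposed-conv padding into whole-stride padding of
    # the input (`padding_vals`) plus a sub-stride remainder that is trimmed
    # from the upsampled output (`truncate_top`/`truncate_left`).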
    tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding(
        fh, stride=strides, dilation=dh, padding=padding)
    tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding(
        fw, stride=strides, dilation=dw, padding=padding)

    pad_bottom = (tot_pad_bottom - 1) // strides + 1
    pad_top = (tot_pad_top - 1) // strides + 1
    pad_right = (tot_pad_right - 1) // strides + 1
    pad_left = (tot_pad_left - 1) // strides + 1
    padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right))

    truncate_top = pad_top * strides - tot_pad_top
    truncate_left = pad_left * strides - tot_pad_left

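    # The returned `op` implements the transposed convolution as: pad the
    # input, gather im2row patches, matmul with the scattered sub-kernel
    # matrix, interleave the `strides**2` phases via `depth_to_space`, and
    # crop to the exact output size.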
    def op(x, kernel):
      input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
      x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
      kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

      batch_shape, event_shape = ps.split(
          ps.shape(x), num_or_size_splits=[-1, 3])
      xh, xw, c_in = ps.unstack(event_shape, num=3)

      kernel_shape = ps.shape(kernel)
      c_out = kernel_shape[-1]
      kernel_batch = kernel_shape[:-2]
      assertions = _maybe_validate_input_shapes(
          kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw,
          validate_args=validate_args)

      with tf.control_dependencies(assertions):

        # If the kernel does not have batch shape, fall back to
        # `conv2d_transpose` (unless dilations > 1, which is not implemented in
        # `conv2d_transpose`).
        if (tf.get_static_value(ps.rank(kernel)) == 2
            and all(d == 1 for d in dilations)):
          return _call_conv2d_transpose(
              x, kernel=kernel, filter_shape=filter_shape,
              strides=(strides,) * rank, padding=padding, dilations=dilations,
              c_out=c_out, batch_shape=batch_shape, event_shape=event_shape)

        n = ps.maximum(0, ps.rank(x) - 3)
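        # Pad only the two spatial dims: prepend zero-padding rows for the `n`
        # leading batch dims and append one for the channel dim.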
        paddings = ps.pad(
            padding_vals,
            paddings=[[n, 1], [0, 0]],
            constant_values=0)

        x_pad = tf.pad(x, paddings=paddings, constant_values=0)
        x_pad_shape = ps.shape(x_pad)[:-3]
        flat_shape = ps.pad(x_pad_shape, paddings=[[0, 1]], constant_values=-1)
        flat_x = tf.reshape(x_pad, shape=flat_shape)

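        # Indices that gather sliding `(sub_fh, sub_fw)` patches of the padded
        # input as rows ('im2row'), turning the convolution into a matmul.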
        idx, s = im2row_index(
            (xh + tf.reduce_sum(padding_vals[0]),
             xw + tf.reduce_sum(padding_vals[1]), c_in),
            block_shape=(sub_fh, sub_fw), slice_step=(1, 1), dilations=dilations
            )

        x_ = tf.gather(flat_x, indices=idx, axis=-1)
        im_x = tf.reshape(x_, shape=ps.concat([x_pad_shape, s], axis=0))

        # Add channel offsets to the sub-kernel indices: scale each tap index
        # by `c_in`, then tile it with one offset per input channel.
        idx_event = event_ind * [[c_in, 1]]
        idx_event_channels = (
            idx_event[tf.newaxis]
            + tf.stack([ps.range(c_in), tf.zeros((c_in,), dtype=dtype)],
                       axis=-1)[:, tf.newaxis, :])
        idx_event = tf.squeeze(
            tf.batch_to_space(
                idx_event_channels, block_shape=[c_in], crops=[[0, 0]]), axis=0)
        idx_event_broadcast = tf.broadcast_to(
            idx_event,
            shape=ps.concat([kernel_batch, ps.shape(idx_event)], axis=0))

        # Add cartesian product of batch indices, since scatter_nd can only be
        # applied to leading dimensions.
        idx_batch = tf.stack(
            tf.meshgrid(
                *[ps.range(b_, delta=1, dtype=dtype)
                  for b_ in tf.unstack(kernel_batch)], indexing='ij'),
            axis=ps.size(kernel_batch))

        # When `kernel_batch` is empty, `tf.meshgrid` yields a float tensor.
        idx_batch = tf.cast(idx_batch, dtype=dtype)

        idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
            (ps.shape(idx_event)[0], 1), dtype=dtype)
        idx_kernel = tf.concat(
            [idx_batch_broadcast, idx_event_broadcast], axis=-1)

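        # Scatter the flattened kernels into a dense matrix of shape
        # `kernel_batch + [sub_fh * sub_fw * c_in, strides**2, c_out]`, with
        # one sub-kernel per middle-axis slice.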
        kernel_mat = tf.scatter_nd(
            idx_kernel,
            updates=kernel,
            shape=ps.cast(
                ps.concat([kernel_batch,
                           [sub_fh * sub_fw * c_in, strides ** 2, c_out]],
                          axis=0),
                dtype=dtype))

        kernel_mat = tf.reshape(
            kernel_mat,
            shape=ps.concat(
                [ps.shape(kernel_mat)[:-2], [strides ** 2 * c_out]], axis=0))

        kernel_mat = kernel_mat[..., tf.newaxis, :, :]
        out = tf.matmul(im_x, kernel_mat)
        broadcast_batch_shape = ps.broadcast_shape(batch_shape, kernel_batch)

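        # The matmul produced `strides**2 * c_out` channels per location; when
        # `strides > 1`, `depth_to_space` interleaves them into a
        # `strides x strides` upsampled spatial grid.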
        if strides > 1:
          tot_size = tf.reduce_prod(broadcast_batch_shape)
          flat_out = tf.reshape(
              out,
              shape=ps.concat([[tot_size], ps.shape(out)[-3:]], axis=0))
          out = tf.nn.depth_to_space(flat_out, block_size=strides)

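        # Compute the exact transposed-conv output size, then crop away the
        # sub-stride padding remainder recorded in
        # `truncate_top`/`truncate_left`.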
        out_height = _deconv_output_length(
            xh, filter_size=fh, padding=padding, output_padding=None,
            stride=strides, dilation=dh)
        out_width = _deconv_output_length(
            xw, filter_size=fw, padding=padding, output_padding=None,
            stride=strides, dilation=dw)

        out = out[..., truncate_top:truncate_top + out_height,
                  truncate_left:truncate_left + out_width, :]
        out = tf.reshape(
            out, shape=ps.concat(
                [broadcast_batch_shape, [out_height, out_width, c_out]],
                axis=0))
        return out
    return op
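
Usage sketch (assumes the private module at the path above is importable as
shown; shapes are illustrative only):

import tensorflow as tf
from tensorflow_probability.python.experimental.nn.util import convolution_util

# Build a 3x3, stride-2 transposed convolution with 'SAME' padding.
conv_transpose = (
    convolution_util.make_convolution_transpose_fn_with_subkernels_matrix(
        filter_shape=(3, 3), strides=2, padding='SAME'))

x = tf.random.normal([4, 8, 8, 5])            # four 8x8 inputs, 5 channels
kernel = tf.random.normal([4, 3 * 3 * 5, 6])  # a flattened kernel per example
y = conv_transpose(x, kernel)                 # shape: [4, 16, 16, 6]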