def _up_convolve_transpose_valid()

in tensorflow_compression/python/layers/signal_conv.py [0:0]


  def _up_convolve_transpose_valid(self, inputs, kernel, prepadding):
    """Upsamples and convolves `inputs` with VALID-mode transpose convolutions.

    This is a relatively inefficient way of computing upsampled convolutions.
    The transpose convolution is evaluated in VALID mode, which computes many
    boundary values. These are cropped away afterwards. The crop also uses
    `self.strides_down` as its slicing step.

    Args:
      inputs: Input tensor. Its spatial axes are located according to
        `self.data_format`.
      kernel: Convolution kernel. Unless the layer is channel separable, its
        last two axes are swapped before use.
      prepadding: One `(before, after)` pair per spatial axis, describing
        pre-padding that was already applied. When cropping, each amount is
        multiplied by the corresponding `self.strides_up` entry. Only read
        when `self.padding` is not "valid".

    Returns:
      The upsampled, convolved and cropped tensor. It is also subsampled along
      any axis whose `strides_down` entry is greater than one.

    Raises:
      Whatever `self._raise_notimplemented()` raises. It is called for
      configurations not handled below:
      - non-separable with rank > 3;
      - separable with `filters != 1` or rank > 2;
      - separable rank 2 with unequal `strides_up`.
    """
    rank = self._rank
    separable = self.channel_separable

    # Transpose convolution ops expect the output and input channel axes of
    # the kernel in reversed order, so swap the kernel's last two axes.
    # Channel separable convolutions only support one filter per channel, so
    # the last axis has length one. That is already the layout the op
    # expects, so no transpose is needed in that case.
    if not separable:
      perm = list(range(rank)) + [rank + 1, rank]
      kernel = tf.transpose(kernel, perm)

    # Work out the (uncropped) shape of the transpose convolution result.
    in_shape = tf.shape(inputs)
    if self.data_format == "channels_last":
      spatial_axes = range(1, rank + 1)
      channel_axis = -1
    else:
      spatial_axes = range(2, rank + 2)
      channel_axis = 1
    out_shape = [in_shape[0]] + [None] * (rank + 1)
    out_shape[channel_axis] = (
        in_shape[channel_axis] if separable else self.filters)

    # Along each spatial axis, take an input length l, upsampling stride s and
    # kernel support k. This produces (l - 1) * s + k samples, or s - 1 more
    # than that when `extra_pad_end` is set.
    pad_end = self.extra_pad_end
    for i, axis in enumerate(spatial_axes):
      up = self.strides_up[i]
      offset = self.kernel_support[i] - 1
      if not pad_end:
        offset -= up - 1
      out_shape[axis] = in_shape[axis] * up + offset

    op_format = self._op_data_format
    op_strides = self._padded_tuple(self.strides_up, 1)

    if rank <= 3 and not separable:
      # Dense transpose convolution.
      outputs = tf.nn.conv_transpose(
          inputs, kernel, out_shape,
          strides=op_strides, padding="VALID", data_format=op_format)
    elif rank == 1 and separable and self.filters == 1:
      # `depthwise_conv2d_backprop_input` only exists in 2D. Give the data a
      # dummy spatial axis of length one, run the 2D op, and then remove the
      # dummy axis again.
      dummy = {"channels_last": 1, "channels_first": 2}[self.data_format]
      op_format = op_format.replace("W", "HW")
      # The dummy axis reuses the stride of the real spatial axis. Input,
      # kernel and output all have length one along it, so its value should
      # not affect the result.
      op_strides = (
          op_strides[:dummy] + (op_strides[dummy],) + op_strides[dummy:])
      out_shape = out_shape[:dummy] + [1] + out_shape[dummy:]
      kernel = tf.expand_dims(kernel, 0)
      inputs = tf.expand_dims(inputs, dummy)
      outputs = tf.nn.depthwise_conv2d_backprop_input(
          out_shape, kernel, inputs,
          strides=op_strides, padding="VALID", data_format=op_format)
      outputs = tf.squeeze(outputs, [dummy])
    elif (rank == 2 and separable and self.filters == 1 and
          self.strides_up[0] == self.strides_up[1]):
      outputs = tf.nn.depthwise_conv2d_backprop_input(
          out_shape, kernel, inputs,
          strides=op_strides, padding="VALID", data_format=op_format)
    else:
      self._raise_notimplemented()

    # Crop the boundaries, accounting for any pre-padding, and subsample by
    # `strides_down`.
    index = [slice(None)] * (rank + 2)
    for i, axis in enumerate(spatial_axes):
      support = self.kernel_support[i]
      if self.padding == "valid":
        # Remove `kernel_support - 1` samples from both ends. This keeps only
        # the samples that were computed without any padding.
        head = tail = support - 1
      else:  # same
        # Remove half the kernel support plus the upsampled pre-padding from
        # each end.
        up = self.strides_up[i]
        head = prepadding[i][0] * up + support // 2
        tail = prepadding[i][1] * up + (support - 1) // 2
      step = self.strides_down[i]
      index[axis] = slice(
          head if head > 0 else None,
          -tail if tail > 0 else None,
          step if step > 1 else None)
    if any(s != slice(None) for s in index):
      outputs = outputs[tuple(index)]
    return outputs