def _correlate_down_valid()

in tensorflow_compression/python/layers/signal_conv.py [0:0]


  def _correlate_down_valid(self, inputs, kernel):
    """Computes a "VALID" correlation of `inputs` with `kernel`, then downsamples.

    Dispatches on the spatial rank (`self._rank`) and `self.channel_separable`:

    - Not channel-separable, rank <= 3: a single `tf.nn.convolution` call with
      `strides=self.strides_down`.
    - Channel-separable, rank 1: there is no 1D depthwise op, so a size-1
      spatial axis is inserted, `tf.nn.depthwise_conv2d` is used, and the extra
      axis is squeezed out again.
    - Channel-separable, rank 2: `tf.nn.depthwise_conv2d` with a uniform stride
      equal to the greatest common factor of `self.strides_down`, followed by
      strided slicing for any remaining per-axis downsampling.

    Any other configuration calls `self._raise_notimplemented()` (presumably
    raises; if it ever returned, `outputs` would be unbound below).

    Args:
      inputs: Input tensor, laid out according to `self._op_data_format`.
      kernel: Correlation kernel. For the separable cases it is passed to
        `tf.nn.depthwise_conv2d` (after gaining a leading axis in the 1D case).

    Returns:
      The correlated and downsampled tensor.
    """
    data_format = self._op_data_format

    if self._rank <= 3 and not self.channel_separable:
      outputs = tf.nn.convolution(
          inputs, kernel,
          strides=self.strides_down, padding="VALID", data_format=data_format)
    elif self._rank == 1 and self.channel_separable:
      # There is no 1D depthwise correlation op, so if that is requested we
      # insert an extra spatial axis directly before the width axis and use
      # the 2D op.
      extradim = {"channels_first": 2, "channels_last": 1}[self.data_format]
      # Full-rank stride tuple; assumes `_padded_tuple` fills the batch and
      # channel positions with 1 -- TODO confirm against the helper.
      strides = self._padded_tuple(self.strides_down, 1)
      # Repeat the width stride for the new axis: the depthwise op requires
      # identical strides along both spatial axes. The inserted axis (and the
      # kernel along it) has size 1, so under "VALID" padding it still yields
      # exactly one output element whatever the stride is.
      strides = strides[:extradim] + (strides[extradim],) + strides[extradim:]
      # Presumably "NWC" -> "NHWC" or "NCW" -> "NCHW".
      data_format = data_format.replace("W", "HW")
      inputs = tf.expand_dims(inputs, extradim)
      kernel = tf.expand_dims(kernel, 0)  # Kernel of height 1.
      outputs = tf.nn.depthwise_conv2d(
          inputs, kernel,
          strides=strides, padding="VALID", data_format=data_format)
      outputs = tf.squeeze(outputs, [extradim])
    elif self._rank == 2 and self.channel_separable:
      # `tf.nn.depthwise_conv2d` performs channel-separable correlations
      # followed by optional downsampling. All strides must be identical. If
      # not, we downsample by the greatest common factor and then downsample
      # the result further.
      gcf = _greatest_common_factor(self.strides_down)
      strides = self._padded_tuple(self._rank * (gcf,), 1)
      outputs = tf.nn.depthwise_conv2d(
          inputs, kernel,
          strides=strides, padding="VALID", data_format=data_format)
      # Perform remaining downsampling. Keeping every (s // gcf)-th sample of
      # a stride-gcf "VALID" result gives the stride-s "VALID" result, given
      # that gcf divides every stride (assumes `_greatest_common_factor`
      # returns the GCD).
      slices = tuple(slice(None, None, s // gcf) for s in self.strides_down)
      if any(s.step > 1 for s in slices):
        outputs = outputs[self._padded_tuple(slices, slice(None))]
    else:
      self._raise_notimplemented()

    return outputs