in tensorflow_compression/python/layers/signal_conv.py [0:0]
def _correlate_down_valid(self, inputs, kernel):
  # Computes valid correlation followed by downsampling.
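  # For each spatial dimension of size N, kernel support K and downsampling
  # stride S, the result has size floor((N - K) / S) + 1: the usual "VALID"
  # output, with only every S-th sample kept.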
  data_format = self._op_data_format
  strides = self._padded_tuple(self.strides_down, 1)
  if self._rank <= 3 and not self.channel_separable:
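    # `tf.nn.convolution` handles 1D-3D kernels and strided (downsampled)
    # correlation natively, so the non-separable cases map onto it directly.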
    outputs = tf.nn.convolution(
        inputs, kernel,
        strides=self.strides_down, padding="VALID", data_format=data_format)
  elif self._rank == 1 and self.channel_separable:
    # There is no 1D depthwise correlation op, so if that is requested we
    # insert an extra dimension and use the 2D op.
    extradim = {"channels_first": 2, "channels_last": 1}[self.data_format]
    strides = strides[:extradim] + (strides[extradim],) + strides[extradim:]
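    # Example (assuming "channels_last" and strides_down == (2,)): the padded
    # strides (1, 2, 1) become (1, 2, 2, 1); the duplicated stride only acts
    # on the inserted size-1 height dimension, so it does not affect the
    # output.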
    data_format = data_format.replace("W", "HW")
    inputs = tf.expand_dims(inputs, extradim)
    kernel = tf.expand_dims(kernel, 0)
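    # After the expansions, `inputs` is a 4D tensor with a singleton height
    # dimension and `kernel` has shape (1, width, in_channels,
    # channel_multiplier), the layout `tf.nn.depthwise_conv2d` expects.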
    outputs = tf.nn.depthwise_conv2d(
        inputs, kernel,
        strides=strides, padding="VALID", data_format=data_format)
    outputs = tf.squeeze(outputs, [extradim])
  elif self._rank == 2 and self.channel_separable:
    # `tf.nn.depthwise_conv2d` performs channel-separable correlations
    # followed by optional downsampling, but requires identical strides in
    # both spatial dimensions. If the requested strides differ, we downsample
    # by their greatest common factor here and strided-slice the result
    # further below.
    gcf = _greatest_common_factor(self.strides_down)
    strides = self._padded_tuple(self._rank * (gcf,), 1)
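    # Example: strides_down == (2, 4) gives gcf == 2, so the op below
    # downsamples both dimensions by 2 and the strided slice afterwards keeps
    # every 2nd sample of the second spatial dimension (steps (1, 2)).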
    outputs = tf.nn.depthwise_conv2d(
        inputs, kernel,
        strides=strides, padding="VALID", data_format=data_format)
    # Perform remaining downsampling.
    slices = tuple(slice(None, None, s // gcf) for s in self.strides_down)
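    # Padding the spatial slices with `slice(None)` for the batch and channel
    # axes ensures only the spatial dimensions are subsampled further.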
    if any(s.step > 1 for s in slices):
      outputs = outputs[self._padded_tuple(slices, slice(None))]
  else:
    self._raise_notimplemented()
  return outputs
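

# A minimal sketch of the module-level helper `_greatest_common_factor` used
# above, assuming it simply returns the greatest common divisor of a tuple of
# positive ints (e.g. (2, 4) -> 2); the library's actual implementation may
# differ.
def _greatest_common_factor(values):
  # Pairwise Euclidean algorithm over all entries.
  a = values[0]
  for b in values[1:]:
    while b:
      a, b = b, a % b
  return a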