tensorflow_compression/python/layers/signal_conv.py
def _up_convolve_transpose_valid(self, inputs, kernel, prepadding):
  # Computes upsampling followed by convolution, via transpose convolution ops
  # in VALID mode. This is a relatively inefficient implementation of
  # upsampled convolutions, where we need to crop away a lot of the values
  # computed at the boundaries.

  # Transpose convolutions expect the output and input channels in reversed
  # order. We implement this by swapping those dimensions of the kernel.
  # For channel separable convolutions, we can't currently perform anything
  # other than one filter per channel, so the last dimension needs to be of
  # length one. Since this happens to be the format the op expects, we can
  # skip the transpose in that case.
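  # For example, with rank 2 and a non-separable kernel of shape
  # (kh, kw, in_channels, filters), the transpose below reorders it to
  # (kh, kw, filters, in_channels), which is the filter layout that
  # `tf.nn.conv_transpose` expects.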
  if not self.channel_separable:
    kernel = tf.transpose(
        kernel, list(range(self._rank)) + [self._rank + 1, self._rank])

  # Compute the shape of the temporary tensor (the output of the transpose
  # convolution before cropping).
  input_shape = tf.shape(inputs)
  temp_shape = [input_shape[0]] + (self._rank + 1) * [None]
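  # Tensor layout is (batch,) + spatial + (channels,) for "channels_last" and
  # (batch, channels) + spatial for "channels_first", hence the different
  # spatial axis indices below.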
  if self.data_format == "channels_last":
    spatial_axes = range(1, self._rank + 1)
    temp_shape[-1] = (
        input_shape[-1] if self.channel_separable else self.filters)
  else:
    spatial_axes = range(2, self._rank + 2)
    temp_shape[1] = input_shape[1] if self.channel_separable else self.filters
  if self.extra_pad_end:
    get_length = lambda l, s, k: l * s + (k - 1)
  else:
    get_length = lambda l, s, k: l * s + ((k - 1) - (s - 1))
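  # `get_length(l, s, k)` is the requested spatial size of the VALID transpose
  # convolution output for input length l, upsampling factor s, and kernel
  # support k. For example, l=4, s=2, k=3 gives 4*2 + 2 = 10 samples with
  # `extra_pad_end` and 4*2 + 1 = 9 samples without it.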
  for i, a in enumerate(spatial_axes):
    temp_shape[a] = get_length(
        input_shape[a], self.strides_up[i], self.kernel_support[i])

  data_format = self._op_data_format
  strides = self._padded_tuple(self.strides_up, 1)
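  # `_padded_tuple` presumably pads the per-dimension upsampling strides with
  # 1s at the batch and channel positions, so that `strides` has the rank + 2
  # entries the convolution ops expect.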

  # Compute convolution.
  if self._rank <= 3 and not self.channel_separable:
    outputs = tf.nn.conv_transpose(
        inputs, kernel, temp_shape,
        strides=strides, padding="VALID", data_format=data_format)
  elif self._rank == 1 and self.channel_separable and self.filters == 1:
    # There's no 1D equivalent to `depthwise_conv2d_backprop_input`, so we
    # insert an extra dimension and use the 2D op.
    extradim = {"channels_first": 2, "channels_last": 1}[self.data_format]
    data_format = data_format.replace("W", "HW")
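    # E.g. "NWC" becomes "NHWC" and "NCW" becomes "NCHW"; the new unit-length
    # dimension is inserted at position `extradim` in the tensors below.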
    strides = strides[:extradim] + (strides[extradim],) + strides[extradim:]
    temp_shape = temp_shape[:extradim] + [1] + temp_shape[extradim:]
    kernel = tf.expand_dims(kernel, 0)
    inputs = tf.expand_dims(inputs, extradim)
    outputs = tf.nn.depthwise_conv2d_backprop_input(
        temp_shape, kernel, inputs,
        strides=strides, padding="VALID", data_format=data_format)
    outputs = tf.squeeze(outputs, [extradim])
  elif (self._rank == 2 and self.channel_separable and
        self.filters == 1 and self.strides_up[0] == self.strides_up[1]):
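    # Equal strides in both spatial dimensions are required here, as the
    # underlying depthwise op appears to only support that case.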
    outputs = tf.nn.depthwise_conv2d_backprop_input(
        temp_shape, kernel, inputs,
        strides=strides, padding="VALID", data_format=data_format)
  else:
    self._raise_notimplemented()

  # Perform crop, taking into account any pre-padding that was applied.
  slices = (self._rank + 2) * [slice(None)]
  for i, a in enumerate(spatial_axes):
    if self.padding == "valid":
      # Take `kernel_support - 1` samples away from both sides. This leaves
      # just samples computed without any padding.
      start = stop = self.kernel_support[i] - 1
    else:  # same padding
      # Take half of the kernel size plus the pre-padding away from each side.
      start = prepadding[i][0] * self.strides_up[i]
      start += self.kernel_support[i] // 2
      stop = prepadding[i][1] * self.strides_up[i]
      stop += (self.kernel_support[i] - 1) // 2
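      # In total, start + stop remove strides_up[i] times the pre-padding plus
      # kernel_support[i] - 1 samples, splitting the kernel's two half-widths
      # between the leading and trailing side.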
    step = self.strides_down[i]
    start = start if start > 0 else None
    stop = -stop if stop > 0 else None
    step = step if step > 1 else None
    slices[a] = slice(start, stop, step)
  if not all(s.start is s.stop is s.step is None for s in slices):
    outputs = outputs[tuple(slices)]

  return outputs
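

# Illustrative sketch (not part of signal_conv.py): the "upsample via VALID
# transpose convolution, then crop" idea above, written out with concrete
# TensorFlow ops for the simplest case: rank 2, "channels_last",
# non-separable, no pre-padding, "valid" padding, no downsampling. The
# function and variable names below are hypothetical.
import tensorflow as tf


def upsampled_conv2d_valid_sketch(inputs, kernel, s=2):
  """Upsamples (N, H, W, Cin) `inputs` by factor `s`, then convolves, VALID."""
  _, h, w, _ = inputs.shape        # Spatial/channel dims assumed static.
  kh, kw, _, cout = kernel.shape   # Kernel layout assumed (kh, kw, Cin, Cout).
  n = tf.shape(inputs)[0]          # Batch size may be dynamic.
  # Transpose convolutions take filters as (kh, kw, Cout, Cin).
  kernel = tf.transpose(kernel, [0, 1, 3, 2])
  # Requested VALID transpose convolution size: l*s + (k - 1) - (s - 1).
  temp_h = h * s + kh - s
  temp_w = w * s + kw - s
  outputs = tf.nn.conv2d_transpose(
      inputs, kernel, [n, temp_h, temp_w, cout],
      strides=[1, s, s, 1], padding="VALID")
  # Crop kernel_support - 1 samples from both sides of each spatial dimension,
  # keeping only samples computed without implicit zero padding.
  return outputs[:, kh - 1:temp_h - (kh - 1), kw - 1:temp_w - (kw - 1), :]


# For example, upsampled_conv2d_valid_sketch(tf.zeros([1, 4, 4, 3]),
# tf.zeros([3, 3, 3, 8])) returns a tensor of shape (1, 5, 5, 8).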