in tensorflow_probability/python/experimental/nn/util/convolution_util.py [0:0]
def make_convolution_transpose_fn_with_subkernels_matrix(
    filter_shape, strides, padding, rank=2, dilations=None, dtype=tf.int32,
    validate_args=False, name=None):
  """Like `tf.nn.conv2d_transpose` except applies batch of kernels to batch
  of `x`."""
  with tf.name_scope(
      name or 'make_convolution_transpose_fn_with_subkernels_matrix'):
    if tf.get_static_value(rank) != 2:
      raise NotImplementedError('Argument `rank` currently only supports `2`; '
                                'saw "{}".'.format(rank))
    strides = tf.get_static_value(strides)
    if not isinstance(strides, int):
      raise ValueError('Argument `strides` must be a statically known integer.'
                       ' Saw: {}'.format(strides))
    [
        filter_shape,
        rank,
        _,
        padding,
        dilations,
    ] = prepare_conv_args(
        filter_shape, rank=rank, strides=strides, padding=padding,
        dilations=dilations, is_transpose=True, validate_args=validate_args)

    fh, fw = filter_shape
    dh, dw = dilations

    # Determine maximum filter height and filter width of sub-kernels.
    sub_fh = (fh - 1) // strides + 1
    sub_fw = (fw - 1) // strides + 1
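
    # A transposed convolution with stride `strides` can be computed as a
    # regular convolution over `strides**2` interleaved sub-kernels, each at
    # most `sub_fh x sub_fw`, whose partial outputs are re-interleaved
    # spatially at the end. `loop_body` builds `event_ind`, which maps every
    # position of the full `fh x fw` kernel to a pair (row within the
    # flattened sub-kernel block, sub-kernel id); this mapping is used below
    # to scatter the kernel into matrix form.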
    def loop_body(i_, event_ind):
      i = i_ // strides
      j = i_ % strides
      i_ind = ps.range(
          i * fw, ps.maximum(i, fh) * fw, delta=strides * fw, dtype=dtype)
      j_ind = ps.range(j, ps.maximum(j, fw), delta=strides, dtype=dtype)
      nc = cartesian_add([i_ind, j_ind])
      ind = ps.reverse(ps.reshape(nc, shape=[-1]), axis=[0])
      k = ps.reshape(
          cartesian_add(
              [ps.range(ps.shape(nc)[0] * sub_fw, delta=sub_fw, dtype=dtype),
               ps.range(ps.shape(nc)[1], dtype=dtype)]),
          shape=[-1])
      last_j = strides - (fw - j - 1) % strides - 1
      last_i = strides - (fh - i - 1) % strides - 1
      kernel_ind = ps.stack(
          [k, ps.ones_like(k) * last_i * strides + last_j], axis=1)
      event_ind = ps.tensor_scatter_nd_update(
          event_ind, ind[..., tf.newaxis], kernel_ind)
      return i_ + 1, event_ind
    event_ind = ps.zeros((fh * fw, 2), dtype=dtype)
    _, event_ind = tf.while_loop(
        lambda i, _: i < strides ** 2,
        loop_body,
        [tf.zeros([], dtype=dtype), event_ind])

    tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding(
        fh, stride=strides, dilation=dh, padding=padding)
    tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding(
        fw, stride=strides, dilation=dw, padding=padding)

    pad_bottom = (tot_pad_bottom - 1) // strides + 1
    pad_top = (tot_pad_top - 1) // strides + 1
    pad_right = (tot_pad_right - 1) // strides + 1
    pad_left = (tot_pad_left - 1) // strides + 1
    padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right))

    truncate_top = pad_top * strides - tot_pad_top
    truncate_left = pad_left * strides - tot_pad_left
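
    # `op` is the returned transposed-convolution function. It pads the
    # input, extracts `sub_fh x sub_fw` patches (im2row), multiplies them by
    # the kernel rearranged into matrix form via `event_ind`, and finally
    # re-interleaves the `strides**2` partial outputs with `depth_to_space`.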
    def op(x, kernel):
      input_dtype = dtype_util.common_dtype([x, kernel],
                                            dtype_hint=tf.float32)
      x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
      kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')

      batch_shape, event_shape = ps.split(
          ps.shape(x), num_or_size_splits=[-1, 3])
      xh, xw, c_in = ps.unstack(event_shape, num=3)

      kernel_shape = ps.shape(kernel)
      c_out = kernel_shape[-1]
      kernel_batch = kernel_shape[:-2]
      assertions = _maybe_validate_input_shapes(
          kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw,
          validate_args=validate_args)

      with tf.control_dependencies(assertions):
        # If the kernel does not have batch shape, fall back to
        # `conv2d_transpose` (unless dilations > 1, which is not implemented in
        # `conv2d_transpose`).
        if (tf.get_static_value(ps.rank(kernel)) == 2
            and all(d == 1 for d in dilations)):
          return _call_conv2d_transpose(
              x, kernel=kernel, filter_shape=filter_shape,
              strides=(strides,) * rank, padding=padding, dilations=dilations,
              c_out=c_out, batch_shape=batch_shape, event_shape=event_shape)

        n = ps.maximum(0, ps.rank(x) - 3)
        paddings = ps.pad(
            padding_vals,
            paddings=[[n, 1], [0, 0]],
            constant_values=0)

        x_pad = tf.pad(x, paddings=paddings, constant_values=0)
        x_pad_shape = ps.shape(x_pad)[:-3]
        flat_shape = ps.pad(x_pad_shape, paddings=[[0, 1]], constant_values=-1)
        flat_x = tf.reshape(x_pad, shape=flat_shape)

        idx, s = im2row_index(
            (xh + tf.reduce_sum(padding_vals[0]),
             xw + tf.reduce_sum(padding_vals[1]), c_in),
            block_shape=(sub_fh, sub_fw), slice_step=(1, 1),
            dilations=dilations)

        x_ = tf.gather(flat_x, indices=idx, axis=-1)
        im_x = tf.reshape(x_, shape=ps.concat([x_pad_shape, s], axis=0))
        # Add channels to subkernel indices
        idx_event = event_ind * [[c_in, 1]]
        idx_event_channels = (
            idx_event[tf.newaxis]
            + tf.stack([ps.range(c_in), tf.zeros((c_in,), dtype=dtype)],
                       axis=-1)[:, tf.newaxis, :])
        idx_event = tf.squeeze(
            tf.batch_to_space(
                idx_event_channels, block_shape=[c_in], crops=[[0, 0]]),
            axis=0)
        idx_event_broadcast = tf.broadcast_to(
            idx_event,
            shape=ps.concat([kernel_batch, ps.shape(idx_event)], axis=0))

        # Add cartesian product of batch indices, since scatter_nd can only be
        # applied to leading dimensions.
        idx_batch = tf.stack(
            tf.meshgrid(
                *[ps.range(b_, delta=1, dtype=dtype)
                  for b_ in tf.unstack(kernel_batch)], indexing='ij'),
            axis=ps.size(kernel_batch))
        idx_batch = tf.cast(idx_batch, dtype=dtype)  # empty tensor is float

        idx_batch_broadcast = idx_batch[..., tf.newaxis, :] + tf.zeros(
            (ps.shape(idx_event)[0], 1), dtype=dtype)
        idx_kernel = tf.concat(
            [idx_batch_broadcast, idx_event_broadcast], axis=-1)
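
        # Scatter the kernel entries into a dense
        # `[*kernel_batch, sub_fh * sub_fw * c_in, strides**2, c_out]` tensor
        # using the (batch, event, sub-kernel) indices assembled above, then
        # flatten the last two dimensions into a single matrix column axis.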
        kernel_mat = tf.scatter_nd(
            idx_kernel,
            updates=kernel,
            shape=ps.cast(
                ps.concat([kernel_batch,
                           [sub_fh * sub_fw * c_in, strides ** 2, c_out]],
                          axis=0),
                dtype=dtype))
        kernel_mat = tf.reshape(
            kernel_mat,
            shape=ps.concat(
                [ps.shape(kernel_mat)[:-2], [strides ** 2 * c_out]], axis=0))
        kernel_mat = kernel_mat[..., tf.newaxis, :, :]

        out = tf.matmul(im_x, kernel_mat)
        broadcast_batch_shape = ps.broadcast_shape(batch_shape, kernel_batch)
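
        # The matmul produced `strides**2 * c_out` channels per spatial
        # location; `depth_to_space` re-interleaves them onto the upsampled
        # spatial grid before cropping to the final output size.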
        if strides > 1:
          tot_size = tf.reduce_prod(broadcast_batch_shape)
          flat_out = tf.reshape(
              out,
              shape=ps.concat([[tot_size], ps.shape(out)[-3:]], axis=0))
          out = tf.nn.depth_to_space(flat_out, block_size=strides)

        out_height = _deconv_output_length(
            xh, filter_size=fh, padding=padding, output_padding=None,
            stride=strides, dilation=dh)
        out_width = _deconv_output_length(
            xw, filter_size=fw, padding=padding, output_padding=None,
            stride=strides, dilation=dw)

        out = out[..., truncate_top:truncate_top + out_height,
                  truncate_left:truncate_left + out_width, :]
        out = tf.reshape(
            out,
            shape=ps.concat(
                [broadcast_batch_shape, [out_height, out_width, c_out]],
                axis=0))
        return out

    return op
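
# Illustrative usage sketch (not part of the original module). It assumes the
# shape conventions implied above: `x` is `[batch, height, width, c_in]` and
# `kernel` is `[batch, fh * fw * c_in, c_out]`; the example shapes are
# arbitrary.
#
#   deconv_fn = make_convolution_transpose_fn_with_subkernels_matrix(
#       filter_shape=(3, 3), strides=2, padding='SAME')
#   x = tf.random.normal([8, 16, 16, 4])          # [batch, h, w, c_in]
#   kernel = tf.random.normal([8, 3 * 3 * 4, 6])  # [batch, fh * fw * c_in, c_out]
#   y = deconv_fn(x, kernel)                      # expected shape: [8, 32, 32, 6]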