in tensorflow_probability/python/experimental/nn/util/convolution_util.py
def make_convolution_transpose_fn_with_subkernels(
filter_shape, strides, padding, rank=2, dilations=None, dtype=tf.int32,
validate_args=False, name=None):
"""Like `tf.nn.conv2d` except applies batch of kernels to batch of `x`."""
  with tf.name_scope(name or 'make_convolution_transpose_fn_with_subkernels'):
# Enable v2 control flow to avoid None gradients through TensorArray.
tf.compat.v1.enable_control_flow_v2()
if tf.get_static_value(rank) != 2:
raise NotImplementedError('Argument `rank` currently only supports `2`; '
'saw "{}".'.format(rank))
[
filter_shape,
rank,
strides,
padding,
dilations,
] = prepare_conv_args(
filter_shape, rank=rank, strides=strides, padding=padding,
dilations=dilations, is_transpose=True, validate_args=validate_args)
sh, sw = strides
fh, fw = filter_shape
dh, dw = dilations
# Determine maximum filter height and filter width of sub-kernels.
sub_fh = (fh - 1) // sh + 1
sub_fw = (fw - 1) // sw + 1
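    # For example (illustrative numbers): with `filter_shape=[4, 4]` and
    # `strides=[2, 2]`, each of the `sh * sw = 4` sub-kernels covers at most
    # `sub_fh = sub_fw = 2` taps, since outputs of a given stride phase only
    # ever see every `s`-th filter tap.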
def loop_body(i_, kernels_ind):
i = i_ // sw
j = i_ % sw
i_ind = ps.range(
i * fw, ps.maximum(i, fh) * fw, delta=sh * fw, dtype=dtype)
j_ind = ps.range(j, ps.maximum(j, fw), delta=sw, dtype=dtype)
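      # Map this (i, j) stride phase to the TensorArray slot (`pos` below)
      # whose order matches how `batch_to_space` later re-interleaves the
      # per-phase outputs.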
last_j = sw - (fw - j - 1) % sw - 1
last_i = sh - (fh - i - 1) % sh - 1
pos = last_i * sw + last_j
nc = cartesian_add([i_ind, j_ind])
kernels_ind = kernels_ind.write(
pos, ps.reverse(ps.reverse(nc, [0]), [1]))
return i_ + 1, kernels_ind
kernels_ind = tf.TensorArray(dtype=dtype, infer_shape=False, size=sh * sw)
_, kernels_ind = tf.while_loop(
lambda i, _: i < sh * sw,
loop_body,
[0, kernels_ind])
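    # `kernels_ind` now holds, for each of the `sh * sw` stride phases, the
    # (spatially flipped) flat indices of the filter taps contributing to that
    # phase; `op` below gathers these rows out of the flattened kernel.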
tot_pad_top, tot_pad_bottom = _get_transpose_conv_dilated_padding(
fh, stride=sh, dilation=dh, padding=padding)
tot_pad_left, tot_pad_right = _get_transpose_conv_dilated_padding(
fw, stride=sw, dilation=dw, padding=padding)
pad_bottom = (tot_pad_bottom - 1) // sh + 1
pad_top = (tot_pad_top - 1) // sh + 1
pad_right = (tot_pad_right - 1) // sw + 1
pad_left = (tot_pad_left - 1) // sw + 1
padding_vals = ((pad_top, pad_bottom), (pad_left, pad_right))
truncate_top = pad_top * sh - tot_pad_top
truncate_left = pad_left * sw - tot_pad_left
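    # Padding is rounded up to whole strides, so `truncate_top` and
    # `truncate_left` record the overshoot to slice off after
    # `batch_to_space` re-interleaves the sub-kernel outputs.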
def op(x, kernel):
input_dtype = dtype_util.common_dtype([x, kernel], dtype_hint=tf.float32)
x = tf.convert_to_tensor(x, dtype=input_dtype, name='x')
kernel = tf.convert_to_tensor(kernel, dtype=input_dtype, name='kernel')
batch_shape, event_shape = ps.split(
ps.shape(x), num_or_size_splits=[-1, 3])
xh, xw, c_in = ps.unstack(event_shape, num=3)
kernel_shape = ps.shape(kernel)
c_out = kernel_shape[-1]
kernel_batch = kernel_shape[:-2]
assertions = _maybe_validate_input_shapes(
kernel_shape, channels_in=c_in, filter_height=fh, filter_width=fw,
validate_args=validate_args)
with tf.control_dependencies(assertions):
# If the kernel does not have batch shape, fall back to
# `conv2d_transpose` (unless dilations > 1, which is not implemented in
# `conv2d_transpose`).
if (tf.get_static_value(ps.rank(kernel)) == 2
and all(d == 1 for d in dilations)):
return _call_conv2d_transpose(
x, kernel, filter_shape, strides, padding, dilations, c_out,
batch_shape, event_shape)
n = ps.maximum(0, ps.rank(x) - 3)
paddings = ps.pad(
padding_vals,
paddings=[[n, 1], [0, 0]],
constant_values=0)
x_pad = tf.pad(x, paddings=paddings, constant_values=0)
ex_h = xh + tf.reduce_sum(padding_vals[0]) - sub_fh + 1
ex_w = xw + tf.reduce_sum(padding_vals[1]) - sub_fw + 1
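        # `ex_h` and `ex_w` are the common per-phase output sizes: a VALID
        # convolution with a height-`fh_` sub-kernel over `eh = ex_h + fh_ - 1`
        # rows yields `ex_h` outputs (and likewise for width).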
def loop_body(i, outputs):
subkernel_ind = kernels_ind.read(i)
fh_, fw_ = ps.unstack(ps.shape(subkernel_ind), num=2)
eh = ex_h + fh_ - 1
ew = ex_w + fw_ - 1
subkernel_ind = ps.reshape(
ps.reshape(subkernel_ind * c_in, shape=[-1])[:, tf.newaxis]
+ ps.range(c_in), shape=[-1])
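          # Expand each spatial tap index into `c_in` consecutive channel
          # indices, so the gather pulls whole rows of the flattened kernel.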
k = tf.gather(kernel, subkernel_ind, axis=-2)
ind, shape = im2row_index(
[eh, ew, c_in],
block_shape=(fh_, fw_),
slice_step=(1, 1),
dilations=dilations)
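          # `im2row_index` produces gather indices that unfold image patches
          # into rows, turning the sub-convolution into a single matmul.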
x_i = x_pad[..., :eh, :ew, :]
x_i_shape = ps.shape(x_i)
flat_shape = ps.pad(
x_i_shape[:-3], paddings=[[0, 1]], constant_values=-1)
flat_x = tf.reshape(x_i, flat_shape)
x_ = tf.gather(flat_x, ind, axis=-1)
im_x = tf.reshape(x_, ps.concat([x_i_shape[:-3], shape], axis=0))
outputs = outputs.write(
i,
tf.matmul(
im_x,
tf.reshape(
k, ps.concat(
                      [kernel_batch, [1, fh_ * fw_ * c_in, c_out]], axis=0)))
)
return i + 1, outputs
outputs = tf.TensorArray(dtype=input_dtype, size=sh * sw)
_, outputs = tf.while_loop(
lambda i, _: i < sh * sw,
loop_body,
[0, outputs])
y = outputs.concat()
m = tf.reduce_prod(ps.shape(y)[:-3])
y_ = tf.reshape(y, shape=ps.concat([[m], ps.shape(y)[-3:]], axis=0))
y2 = tf.batch_to_space(
y_, strides, crops=tf.zeros([2, 2], dtype=tf.int64))
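      # `batch_to_space` interleaves the `sh * sw` per-phase outputs back into
      # one spatial grid, inverting the sub-kernel decomposition.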
broadcast_batch_shape = ps.broadcast_shape(batch_shape, kernel_batch)
y2 = tf.reshape(y2, ps.concat(
[broadcast_batch_shape, ps.shape(y2)[-3:]], axis=0))
out_height = _deconv_output_length(
xh, filter_size=fh, padding=padding, output_padding=None,
stride=sh, dilation=dh)
out_width = _deconv_output_length(
xw, filter_size=fw, padding=padding, output_padding=None,
stride=sw, dilation=dw)
return y2[..., truncate_top:truncate_top+out_height,
truncate_left:truncate_left+out_width, :]
return op
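
# A minimal usage sketch (shapes are illustrative, not from the code above).
# Per `op`'s shape handling, `x` has event shape `[height, width, c_in]` and
# `kernel` is flattened to `[..., fh * fw * c_in, c_out]`:
#
#   conv_transpose = make_convolution_transpose_fn_with_subkernels(
#       filter_shape=[3, 3], strides=[2, 2], padding='SAME')
#   x = tf.random.normal([7, 8, 8, 4])             # batch of 7 inputs
#   kernel = tf.random.normal([7, 3 * 3 * 4, 16])  # one kernel per input
#   y = conv_transpose(x, kernel)                  # => shape [7, 16, 16, 16]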