in MTRF/algorithms/softlearning/models/convnet.py [0:0]
def convnet_transpose_model(
        output_shape=(32, 32, 3),
        conv_filters=(64, 64, 64),
        conv_kernel_sizes=(3, 3, 3),
        conv_strides=(2, 2, 2),
        padding="SAME",
        normalization_type=None,
        normalization_kwargs=None,
        downsampling_type='conv',
        activation=layers.LeakyReLU,
        output_activation='tanh',
        kernel_regularizer=None,
        name='convnet_transpose',
        *args,
        **kwargs):
    """Build a sequential transposed-convolution network.

    The model is assembled, in order, from:

    1. a ``Dense`` layer with ``prod(base_shape)`` units, followed by
       ``activation``;
    2. a ``Reshape`` to ``base_shape``, where ``base_shape`` is
       ``(w // prod(conv_strides), h // prod(conv_strides), conv_filters[0])``;
    3. one transposed-convolution block per
       ``(conv_filters[i], conv_kernel_sizes[i], conv_strides[i])`` triple;
    4. a final block with ``c`` filters, kernel size
       ``conv_kernel_sizes[-1]``, stride 1 and ``output_activation``.

    Every block (including the final one) consists of a linear
    ``Conv2DTranspose``, an optional normalization layer, an activation
    layer, and, when ``downsampling_type`` is a key of ``POOLING_TYPES``,
    a pooling layer.

    Args:
        output_shape: ``(w, h, c)`` triple; must have exactly 3 entries.
        conv_filters: Filter count for each intermediate block.
        conv_kernel_sizes: Kernel size for each intermediate block; the
            last entry is reused for the final block.
        conv_strides: Stride for each intermediate block. The product of
            all entries determines the spatial size of the reshaped
            Dense output.
        padding: Padding mode handed to ``Conv2DTranspose``. Only
            ``'SAME'`` is accepted.
        normalization_type: Key looked up in ``NORMALIZATION_TYPES``. If
            the looked-up value is ``None`` no normalization layer is
            added.
        normalization_kwargs: Keyword arguments for the normalization
            layer constructor. ``None`` (the default) means no extra
            arguments.
        downsampling_type: ``'conv'`` applies the stride inside the
            transposed convolution; any other value uses a convolution
            stride of 1. If the value is a key of ``POOLING_TYPES``, the
            corresponding pooling layer is appended with
            ``pool_size=stride, strides=stride``.
        activation: Activation for the Dense layer and the intermediate
            blocks. Either a string (wrapped in ``layers.Activation``) or
            a zero-argument callable that returns a layer.
        output_activation: Activation for the final block, same
            conventions as ``activation``.
        kernel_regularizer: Passed to the Dense layer and to every
            ``Conv2DTranspose`` layer.
        name: Name given to the returned model.
        *args: Accepted but not used.
        **kwargs: Accepted but not used.

    Returns:
        A ``PicklableSequential`` containing the layers described above.

    Raises:
        AssertionError: If ``output_shape`` does not have 3 entries.
        NotImplementedError: If ``padding`` is not ``'SAME'``.

    Note:
        The three ``conv_*`` sequences are combined with ``zip``, so if
        their lengths differ the extra entries are ignored for block
        construction, while ``base_shape`` still uses the product of all
        ``conv_strides``.
    """
    # Use a None sentinel instead of a shared mutable default dict.
    if normalization_kwargs is None:
        normalization_kwargs = {}

    normalization_layer = NORMALIZATION_TYPES[normalization_type]

    def conv_transpose_block(n_filters,
                             kernel_size,
                             stride,
                             block_activation,
                             name='conv_transpose_block'):
        """Return one conv-transpose -> [norm] -> activation -> [pool] block."""
        # The stride is only applied by the convolution itself in 'conv'
        # mode; otherwise it is forwarded to the pooling layer (if any).
        conv_stride = stride if downsampling_type == 'conv' else 1
        block_parts = [
            tfkl.Conv2DTranspose(
                filters=n_filters,
                kernel_size=kernel_size,
                padding=padding,
                strides=conv_stride,
                activation='linear',
                kernel_regularizer=kernel_regularizer,
            )
        ]

        if normalization_layer is not None:
            block_parts.append(normalization_layer(**normalization_kwargs))

        if isinstance(block_activation, str):
            block_parts.append(
                layers.Activation(block_activation, name=block_activation))
        else:
            block_parts.append(block_activation())

        if downsampling_type in POOLING_TYPES:
            block_parts.append(
                POOLING_TYPES[downsampling_type](
                    pool_size=stride, strides=stride))

        return tfk.Sequential(block_parts, name=name)

    assert len(output_shape) == 3, 'Output shape needs to be (w, h, c), w = h'
    w, h, c = output_shape

    # TODO: generalize this to different padding types (only works for
    # SAME right now) as well as different sized images (this only really
    # works for nice powers of stride length), i.e 32 -> 16 -> 8 -> 4
    if padding != 'SAME':
        raise NotImplementedError(
            "convnet_transpose_model only supports padding='SAME'.")

    # `np.product` was removed in NumPy 2.0; `np.prod` is the supported
    # spelling. Cast to a plain int so layer arguments are Python ints.
    total_stride = int(np.prod(conv_strides))
    base_shape = (w // total_stride, h // total_stride, conv_filters[0])

    dense_activation = (
        layers.Activation(activation)
        if isinstance(activation, str)
        else activation())

    hidden_blocks = [
        conv_transpose_block(
            conv_filter,
            conv_kernel_size,
            conv_stride,
            activation,
            name=f'conv_transpose_block_{i}')
        for i, (conv_filter, conv_kernel_size, conv_stride) in
        enumerate(zip(conv_filters, conv_kernel_sizes, conv_strides))
    ]

    output_block = conv_transpose_block(
        n_filters=c,
        kernel_size=conv_kernel_sizes[-1],
        stride=1,
        block_activation=output_activation,
        name='conv_transpose_block_output',
    )

    model = PicklableSequential([
        tfkl.Dense(
            units=int(np.prod(base_shape)),
            kernel_regularizer=kernel_regularizer,
        ),
        dense_activation,
        tfkl.Reshape(target_shape=base_shape),
        *hidden_blocks,
        output_block,
    ], name=name)

    return model