def convnet_transpose_model()

in MTRF/algorithms/softlearning/models/convnet.py


# Imports assumed from the surrounding module; NORMALIZATION_TYPES,
# POOLING_TYPES, and PicklableSequential are module-level helpers defined
# elsewhere in convnet.py / softlearning.
import numpy as np
import tensorflow as tf

tfk = tf.keras
tfkl = layers = tf.keras.layers  # the original file references both aliases


def convnet_transpose_model(
        output_shape=(32, 32, 3),
        conv_filters=(64, 64, 64),
        conv_kernel_sizes=(3, 3, 3),
        conv_strides=(2, 2, 2),
        padding="SAME",
        normalization_type=None,
        normalization_kwargs=None,
        downsampling_type='conv',
        activation=layers.LeakyReLU,
        output_activation='tanh',
        kernel_regularizer=None,
        name='convnet_transpose',
        *args,
        **kwargs):
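    """Build a transposed-convolution decoder as a Keras Sequential model.

    A Dense layer projects the flat input to a small base grid, which a
    stack of Conv2DTranspose blocks upsamples back to `output_shape`; a
    final stride-1 block maps the features to the output channels.
    """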

    normalization_layer = NORMALIZATION_TYPES[normalization_type]
    # A `None` default avoids the shared-mutable-dict pitfall; materialize
    # an empty dict here instead.
    normalization_kwargs = normalization_kwargs or {}
    # kernel_regularizer = REGULARIZERS[kernel_regularizer_type]

    def conv_transpose_block(n_filters,
                             kernel_size,
                             stride,
                             block_activation,
                             name='conv_transpose_block'):
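        # Block layout: Conv2DTranspose -> optional normalization ->
        # activation -> optional pooling. With downsampling_type='conv' the
        # transposed conv itself carries the stride and no pooling is added.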
        conv_stride = stride if downsampling_type == 'conv' else 1
        block_parts = [
            tfkl.Conv2DTranspose(
                filters=n_filters,
                kernel_size=kernel_size,
                padding=padding,
                strides=conv_stride,
                activation='linear',
                kernel_regularizer=kernel_regularizer,
            )
        ]
        if normalization_layer is not None:
            block_parts += [normalization_layer(**normalization_kwargs)]

        block_parts += [(layers.Activation(block_activation, name=block_activation)
                         if isinstance(block_activation, str)
                         else block_activation())]

        if downsampling_type in POOLING_TYPES:
            block_parts += [
                POOLING_TYPES[downsampling_type](
                    pool_size=stride, strides=stride
                )
            ]
        block = tfk.Sequential(block_parts, name=name)
        return block

    assert len(output_shape) == 3, 'Output shape needs to be (w, h, c), w = h'
    w, h, c = output_shape

    # TODO: generalize this to different padding types (only works for
    # SAME right now) as well as differently sized images (this only really
    # works when the width is a clean power of the stride length,
    # e.g. 32 -> 16 -> 8 -> 4).
    if padding != 'SAME':
        raise NotImplementedError

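    # With 'SAME' padding, each stride-s Conv2DTranspose multiplies the
    # spatial dims by s, so the base grid is the target size divided by the
    # product of the strides, e.g. strides (2, 2, 2) and w = 32 give 4x4.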
    base_w = w // np.prod(conv_strides)
    base_h = h // np.prod(conv_strides)
    base_shape = (base_w, base_h, conv_filters[0])

    model = PicklableSequential([
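        # Project the flat input to the flattened base grid, then reshape it
        # into a (base_w, base_h, conv_filters[0]) feature map.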
        tfkl.Dense(
            units=np.prod(base_shape),
            kernel_regularizer=kernel_regularizer,
        ),
        (layers.Activation(activation)
         if isinstance(activation, str)
         else activation()),
        tfkl.Reshape(target_shape=base_shape),
        *[
            conv_transpose_block(
                conv_filter,
                conv_kernel_size,
                conv_stride,
                activation,
                name=f'conv_transpose_block_{i}')
            for i, (conv_filter, conv_kernel_size, conv_stride) in
            enumerate(zip(conv_filters, conv_kernel_sizes, conv_strides))
        ],
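        # Final stride-1 block maps the features to `c` output channels and
        # applies `output_activation` (tanh by default).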
        conv_transpose_block(
            n_filters=c,
            kernel_size=conv_kernel_sizes[-1],
            stride=1,
            block_activation=output_activation,
            name='conv_transpose_block_output',
        ),
    ], name=name)
    return model
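

Usage, as a minimal sketch: the Sequential is built lazily, so calling the
model on a batch of flat latent vectors materializes the weights and yields
images of `output_shape`. The latent size (16) and batch size (2) are
arbitrary example values here, and `PicklableSequential` is assumed to
behave like `tf.keras.Sequential`.

import numpy as np
import tensorflow as tf

decoder = convnet_transpose_model(
    output_shape=(32, 32, 3),
    conv_filters=(64, 64, 64),
    conv_strides=(2, 2, 2),
)
latents = tf.random.normal((2, 16))  # (batch, latent_dim)
images = decoder(latents)            # 4x4 base grid -> 8 -> 16 -> 32
assert images.shape == (2, 32, 32, 3)
assert np.all(np.abs(images.numpy()) <= 1.0)  # tanh bounds pixels to [-1, 1]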