def convnet()

in MTRF/algorithms/softlearning/preprocessors/convnet_preprocessor.py [0:0]


def convnet(input_shape,
            output_size,
            conv_filters=(64, 64, 64),
            conv_kernel_sizes=(3, 3, 3),
            conv_strides=(2, 2, 2),
            use_global_average_pool=False,
            normalization_type=None,
            normalization_kwargs=None,
            downsampling_type='conv',
            name='convnet',
            *args,
            **kwargs):
    """Build a stacked-convolution image preprocessor model.

    Each stage applies a linear-activation ``Conv2D`` with ``"SAME"``
    padding, an optional normalization layer, a ``LeakyReLU``, and (in
    ``'pool'`` mode) an average-pooling layer. The final feature map is
    either globally average-pooled or flattened.

    Args:
        input_shape: Shape passed to ``layers.Input`` (dtype is float32).
        output_size: Not read by this function; kept in the signature for
            interface compatibility.
        conv_filters: Number of filters for each convolution stage.
        conv_kernel_sizes: Kernel size for each convolution stage.
        conv_strides: Downsampling factor for each stage. Used as the
            convolution stride in ``'conv'`` mode, or as the pool size and
            pool stride in ``'pool'`` mode.
        use_global_average_pool: If true, finish with
            ``GlobalAveragePooling2D``; otherwise finish with ``Flatten``.
        normalization_type: One of ``None``, ``'batch'``, ``'layer'``,
            ``'group'`` or ``'instance'``. ``'weight'`` is recognized but
            not implemented.
        normalization_kwargs: Keyword arguments for the normalization
            layer constructor. ``None`` (the default) means no extra
            arguments.
        downsampling_type: ``'conv'`` to downsample with strided
            convolutions, ``'pool'`` to use stride-1 convolutions followed
            by average pooling.
        name: Name given to the resulting model.
        *args: Forwarded to every ``Conv2D`` constructor.
            NOTE(review): these are passed positionally alongside the
            keyword ``filters``/``kernel_size`` arguments, so a non-empty
            ``args`` presumably collides with them -- confirm before use.
        **kwargs: Extra keyword arguments forwarded to every ``Conv2D``.

    Returns:
        The model created by ``models.Model(img_input, x, name=name)``.
        Its summary is printed as a side effect.

    Raises:
        AssertionError: If ``downsampling_type`` is not ``'pool'`` or
            ``'conv'``, or if ``normalization_type`` is unrecognized (the
            latter is only checked once a convolution stage is built).
        NotImplementedError: If ``normalization_type`` is ``'weight'``.

    Note:
        The three per-stage sequences are combined with ``zip``, so the
        number of stages is the length of the shortest one.
    """
    assert downsampling_type in ('pool', 'conv'), downsampling_type

    # Resolve the sentinel to a fresh dict so no mutable object is shared
    # between calls through the default argument.
    if normalization_kwargs is None:
        normalization_kwargs = {}

    # Loop-invariant: decides whether the convolution itself downsamples.
    downsample_with_conv = downsampling_type == 'conv'

    img_input = layers.Input(shape=input_shape, dtype=tf.float32)
    x = img_input

    for (conv_filter, conv_kernel_size, conv_stride) in zip(
            conv_filters, conv_kernel_sizes, conv_strides):
        x = layers.Conv2D(
            filters=conv_filter,
            kernel_size=conv_kernel_size,
            strides=(conv_stride if downsample_with_conv else 1),
            padding="SAME",
            # Linear here: the non-linearity is applied after normalization.
            activation='linear',
            *args,
            **kwargs
        )(x)

        if normalization_type == 'batch':
            x = layers.BatchNormalization(**normalization_kwargs)(x)
        elif normalization_type == 'layer':
            x = LayerNormalization(**normalization_kwargs)(x)
        elif normalization_type == 'group':
            x = GroupNormalization(**normalization_kwargs)(x)
        elif normalization_type == 'instance':
            x = InstanceNormalization(**normalization_kwargs)(x)
        elif normalization_type == 'weight':
            raise NotImplementedError(normalization_type)
        else:
            assert normalization_type is None, normalization_type

        x = layers.LeakyReLU()(x)

        # In 'pool' mode the convolution ran at stride 1, so downsample here.
        if not downsample_with_conv and conv_stride > 1:
            x = tf.keras.layers.AvgPool2D(
                pool_size=conv_stride, strides=conv_stride
            )(x)

    if use_global_average_pool:
        x = layers.GlobalAveragePooling2D(name='average_pool')(x)
    else:
        x = tf.keras.layers.Flatten()(x)

    model = models.Model(img_input, x, name=name)
    model.summary()
    return model