in tensorflow_graphics/projects/gan/architectures_progressive_gan.py [0:0]
def create_discriminator(
    downsampling_blocks_num_channels: Sequence[Sequence[int]] = ((64, 128),
                                                                 (128, 128),
                                                                 (256, 256),
                                                                 (512, 512)),
    relu_leakiness: float = 0.2,
    kernel_initializer: Optional[_KerasInitializer] = None,
    use_fan_in_scaled_kernels: bool = True,
    use_layer_normalization: bool = False,
    use_intermediate_inputs: bool = False,
    use_antialiased_bilinear_downsampling: bool = False,
    name: str = 'progressive_gan_discriminator'):
  """Creates a Keras model for the discriminator architecture.

  This architecture is implemented according to the paper "Progressive growing
  of GANs for Improved Quality, Stability, and Variation"
  https://arxiv.org/abs/1710.10196
  The intermediate outputs can optionally be given as input for the
  architecture of "MSG-GAN: Multi-Scale Gradient GAN for Stable Image
  Synthesis" https://arxiv.org/abs/1903.06048

  Args:
    downsampling_blocks_num_channels: The number of channels in the
      downsampling blocks. For each block, the number of channels of the first
      and second convolution are specified. Must contain at least one block.
    relu_leakiness: Slope of the negative part of the leaky relu.
    kernel_initializer: Initializer of the kernel. If None, a TruncatedNormal
      initializer with mean 0.0 and stddev 1.0 is used.
    use_fan_in_scaled_kernels: This rescales the kernels using the scale factor
      from the He initializer, which implements the equalized learning rate.
    use_layer_normalization: If layer normalization layers should be inserted
      into the network.
    use_intermediate_inputs: If true, the model expects a sequence of
      len(downsampling_blocks_num_channels) + 1 tf.Tensors with increasing
      resolution, starting with the starting_size up to the final resolution,
      as input.
    use_antialiased_bilinear_downsampling: If true, the downsampling operation
      is anti-aliased bilinear downsampling with a [1, 3, 3, 1] tent kernel. If
      false, standard bilinear downsampling, i.e. average pooling, is used
      ([1, 1] tent kernel).
    name: The name of the Keras model.

  Returns:
    The generated discriminator keras model.

  Raises:
    ValueError: If downsampling_blocks_num_channels is empty.
  """
  if not downsampling_blocks_num_channels:
    raise ValueError('downsampling_blocks_num_channels must contain at least '
                     'one block.')
  if kernel_initializer is None:
    kernel_initializer = tf.keras.initializers.TruncatedNormal(
        mean=0.0, stddev=1.0)

  def maybe_normalize(tensor):
    """Applies layer normalization if use_layer_normalization is set."""
    if use_layer_normalization:
      # A single group normalizes over all channels, i.e. layer normalization.
      tensor = tfa_normalizations.GroupNormalization(groups=1)(tensor)
    return tensor

  def conv_activation_block(tensor, filters, kernel_size, padding):
    """Applies conv (stride 1) -> leaky relu -> optional layer normalization."""
    tensor = create_conv_layer(
        use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
        filters=filters,
        kernel_size=kernel_size,
        strides=1,
        padding=padding,
        kernel_initializer=kernel_initializer)(
            tensor)
    tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor)
    return maybe_normalize(tensor)

  if use_intermediate_inputs:
    # One RGB input per resolution: inputs[0] is the lowest resolution and
    # inputs[-1] the highest (final) resolution.
    inputs = tuple(
        tf.keras.Input(shape=(None, None, 3))
        for _ in range(len(downsampling_blocks_num_channels) + 1))
    tensor = inputs[-1]
  else:
    inputs = tf.keras.Input(shape=(None, None, 3))
    tensor = inputs

  tensor = from_rgb(
      tensor,
      use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
      num_channels=downsampling_blocks_num_channels[0][0],
      kernel_initializer=kernel_initializer,
      relu_leakiness=relu_leakiness)
  tensor = maybe_normalize(tensor)

  for index, (channels_1, channels_2) in enumerate(
      downsampling_blocks_num_channels):
    tensor = conv_activation_block(tensor, channels_1, 3, 'same')
    tensor = conv_activation_block(tensor, channels_2, 3, 'same')
    if use_antialiased_bilinear_downsampling:
      tensor = keras_layers.Blur2D()(tensor)
    tensor = tf.keras.layers.AveragePooling2D()(tensor)
    if use_intermediate_inputs:
      # MSG-GAN: concatenate the raw RGB input of the next lower resolution.
      # Assumes consecutive inputs differ by the pooling factor of 2.
      tensor = tf.keras.layers.Concatenate()([inputs[-index - 2], tensor])

  final_num_channels = downsampling_blocks_num_channels[-1][1]
  tensor = conv_activation_block(tensor, final_num_channels, 3, 'same')
  # 'valid' 4x4 convolution: a 4x4 feature map is reduced to 1x1.
  tensor = conv_activation_block(tensor, final_num_channels, 4, 'valid')
  tensor = create_conv_layer(
      use_fan_in_scaled_kernel=use_fan_in_scaled_kernels,
      multiplier=1.0,
      filters=1,
      kernel_size=1,
      kernel_initializer=kernel_initializer)(
          tensor)
  # Flatten all non-batch dimensions into a single logit axis.
  tensor = tf.keras.layers.Reshape((-1,))(tensor)
  return tf.keras.Model(inputs=inputs, outputs=tensor, name=name)