in tensorflow_graphics/projects/gan/architectures_progressive_gan.py [0:0]
def create_generator(latent_code_dimension: int = 128,
                     upsampling_blocks_num_channels: Sequence[int] = (512, 256,
                                                                      128, 64),
                     relu_leakiness: float = 0.2,
                     kernel_initializer: Optional[_KerasInitializer] = None,
                     use_pixel_normalization: bool = True,
                     use_batch_normalization: bool = False,
                     generate_intermediate_outputs: bool = False,
                     normalize_latent_code: bool = True,
                     name: str = 'progressive_gan_generator') -> tf.keras.Model:
  """Builds the progressive GAN generator as a Keras functional model.

  The layout follows the paper "Progressive growing of GANs for Improved
  Quality, Stability, and Variation" (https://arxiv.org/abs/1710.10196).
  Optionally the model also exposes one RGB side output per upsampling block,
  as needed by "MSG-GAN: Multi-Scale Gradient GAN for Stable Image Synthesis"
  (https://arxiv.org/abs/1903.06048).

  The latent code is projected by a dense layer, reshaped to a 4x4 feature map
  with latent_code_dimension channels and passed through one 3x3 convolution.
  Every upsampling block then applies a two-by-two nearest neighbor upsampling
  layer followed by two 3x3 convolutions. Each dense/convolution stage is
  followed by a leaky ReLU and the optional normalization layers.

  Args:
    latent_code_dimension: Size of the latent code fed to the model.
    upsampling_blocks_num_channels: Number of convolution filters of each
      upsampling block; its length is the number of upsampling blocks. The
      first entry is also used for the convolution on the initial 4x4 feature
      map, so the sequence must not be empty.
    relu_leakiness: Negative slope passed to every LeakyReLU layer.
    kernel_initializer: Kernel initializer handed to all dense, convolution
      and to_rgb layers. When None, TruncatedNormal(mean=0.0, stddev=1.0) is
      used.
    use_pixel_normalization: Whether a PixelNormalization layer (axis=3) is
      added after each activation.
    use_batch_normalization: Whether a BatchNormalization layer is added after
      each activation (before the pixel normalization, if both are enabled).
    generate_intermediate_outputs: If true, the model outputs a list holding
      one RGB side output taken before each upsampling block, followed by the
      final output. Otherwise the model has the final output only.
    normalize_latent_code: Whether the latent code is passed through a
      PixelNormalization layer (axis=1) before the first dense layer.
    name: Name given to the Keras model.

  Returns:
    The generator as a tf.keras.Model taking the latent code as input.
  """
  if kernel_initializer is None:
    kernel_initializer = tf.keras.initializers.TruncatedNormal(
        mean=0.0, stddev=1.0)

  def conv3x3(features, num_filters):
    """Applies a stride-1, same-padded 3x3 fan-in scaled convolution."""
    conv_layer = keras_layers.FanInScaledConv2D(
        filters=num_filters,
        kernel_size=3,
        strides=1,
        padding='same',
        kernel_initializer=kernel_initializer)
    return conv_layer(features)

  def activate_and_normalize(features):
    """Applies the leaky ReLU and then the enabled normalization layers."""
    features = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(features)
    if use_batch_normalization:
      features = tf.keras.layers.BatchNormalization()(features)
    if use_pixel_normalization:
      features = keras_layers.PixelNormalization(axis=3)(features)
    return features

  latent_input = tf.keras.Input(shape=(latent_code_dimension,))
  latent_code = latent_input
  if normalize_latent_code:
    latent_code = keras_layers.PixelNormalization(axis=1)(latent_input)

  # Project the latent code to the initial 4x4 feature map.
  projection = keras_layers.FanInScaledDense(
      multiplier=math.sqrt(2.0) / 4.0,
      units=4 * 4 * latent_code_dimension,
      kernel_initializer=kernel_initializer)
  features = projection(latent_code)
  features = tf.keras.layers.Reshape(
      target_shape=(4, 4, latent_code_dimension))(features)
  features = activate_and_normalize(features)

  features = conv3x3(features, upsampling_blocks_num_channels[0])
  features = activate_and_normalize(features)

  side_outputs = []
  for block_index, num_channels in enumerate(upsampling_blocks_num_channels):
    if generate_intermediate_outputs:
      # The side output is taken at the resolution entering this block.
      side_image = to_rgb(
          input_tensor=features,
          kernel_initializer=kernel_initializer,
          name='side_output_%d_conv' % block_index)
      side_outputs.append(side_image)

    features = keras_layers.TwoByTwoNearestNeighborUpSampling()(features)
    for _ in range(2):
      features = conv3x3(features, num_channels)
      features = activate_and_normalize(features)

  final_image = to_rgb(
      input_tensor=features,
      kernel_initializer=kernel_initializer,
      name='final_output')

  if generate_intermediate_outputs:
    model_outputs = side_outputs + [final_image]
  else:
    model_outputs = final_image
  return tf.keras.Model(inputs=latent_input, outputs=model_outputs, name=name)