def create_generator()

in tensorflow_graphics/projects/gan/architectures_progressive_gan.py [0:0]


def create_generator(latent_code_dimension: int = 128,
                     upsampling_blocks_num_channels: Sequence[int] = (512, 256,
                                                                      128, 64),
                     relu_leakiness: float = 0.2,
                     kernel_initializer: Optional[_KerasInitializer] = None,
                     use_pixel_normalization: bool = True,
                     use_batch_normalization: bool = False,
                     generate_intermediate_outputs: bool = False,
                     normalize_latent_code: bool = True,
                     name: str = 'progressive_gan_generator') -> tf.keras.Model:
  """Creates a Keras model for the generator network architecture.

  This architecture is implemented according to the paper "Progressive growing
  of GANs for Improved Quality, Stability, and Variation"
  https://arxiv.org/abs/1710.10196
  The intermediate outputs are optionally provided for the architecture of
  "MSG-GAN: Multi-Scale Gradient GAN for Stable Image Synthesis"
  https://arxiv.org/abs/1903.06048

  The latent code is mapped by a fan-in scaled dense layer to a 4x4 feature
  map with `latent_code_dimension` channels. This is followed by a 3x3
  convolution with `upsampling_blocks_num_channels[0]` channels. Each entry
  of `upsampling_blocks_num_channels` then adds one block: a two-by-two
  nearest neighbor upsampling followed by two 3x3 convolutions with that many
  channels. The final output therefore has a spatial resolution of
  4 * 2**len(upsampling_blocks_num_channels). Every dense/3x3 convolution
  layer in the body is followed by a leaky ReLU and the optional batch and
  pixel normalization layers.

  Args:
    latent_code_dimension: The number of dimensions in the latent code.
    upsampling_blocks_num_channels: The number of channels for each upsampling
      block. This argument also determines how many upsampling blocks are
      added. Must not be empty.
    relu_leakiness: Slope of the negative part of the leaky relu.
    kernel_initializer: Initializer of the kernel. If None, a TruncatedNormal
      initializer with mean 0.0 and stddev 1.0 is used.
    use_pixel_normalization: If pixel normalization layers should be inserted
      into the network.
    use_batch_normalization: If batch normalization layers should be inserted
      into the network.
    generate_intermediate_outputs: If true the model outputs a list of
      len(upsampling_blocks_num_channels) + 1 tf.Tensors with increasing
      resolution. The list starts with the 4x4 output and ends with the
      final resolution output.
    normalize_latent_code: If true the latent code is passed through a pixel
      normalization layer (over axis 1) before feeding it to the network.
    name: The name of the Keras model.

  Returns:
     The created generator keras model object.

  Raises:
    ValueError: If `upsampling_blocks_num_channels` is empty.
  """
  # len() instead of truthiness so that numpy arrays (which the indexing and
  # enumeration below accept) are not rejected by an ambiguous truth value.
  if len(upsampling_blocks_num_channels) == 0:  # pylint: disable=g-explicit-length-test
    raise ValueError('upsampling_blocks_num_channels must not be empty.')

  if kernel_initializer is None:
    kernel_initializer = tf.keras.initializers.TruncatedNormal(
        mean=0.0, stddev=1.0)

  def activate_and_normalize(tensor):
    """Applies the leaky ReLU followed by the optional normalization layers."""
    tensor = tf.keras.layers.LeakyReLU(alpha=relu_leakiness)(tensor)
    if use_batch_normalization:
      tensor = tf.keras.layers.BatchNormalization()(tensor)
    if use_pixel_normalization:
      tensor = keras_layers.PixelNormalization(axis=3)(tensor)
    return tensor

  def conv_3x3(tensor, filters):
    """Applies a stride-1, 'same'-padded, fan-in scaled 3x3 convolution."""
    return keras_layers.FanInScaledConv2D(
        filters=filters,
        kernel_size=3,
        strides=1,
        padding='same',
        kernel_initializer=kernel_initializer)(
            tensor)

  input_tensor = tf.keras.Input(shape=(latent_code_dimension,))
  if normalize_latent_code:
    maybe_normalized_input_tensor = keras_layers.PixelNormalization(axis=1)(
        input_tensor)
  else:
    maybe_normalized_input_tensor = input_tensor

  # Project the latent code onto the initial 4x4 feature map.
  # NOTE(review): the sqrt(2) / 4 multiplier appears to mirror the gain of the
  # first dense layer in the reference ProGAN code; confirm before changing.
  tensor = keras_layers.FanInScaledDense(
      multiplier=math.sqrt(2.0) / 4.0,
      units=4 * 4 * latent_code_dimension,
      kernel_initializer=kernel_initializer)(
          maybe_normalized_input_tensor)
  tensor = tf.keras.layers.Reshape(target_shape=(4, 4, latent_code_dimension))(
      tensor)
  tensor = activate_and_normalize(tensor)
  tensor = activate_and_normalize(
      conv_3x3(tensor, upsampling_blocks_num_channels[0]))

  outputs = []
  for index, channels in enumerate(upsampling_blocks_num_channels):
    if generate_intermediate_outputs:
      # Side output at the current resolution, taken before upsampling.
      outputs.append(
          to_rgb(
              input_tensor=tensor,
              kernel_initializer=kernel_initializer,
              name='side_output_%d_conv' % index))
    tensor = keras_layers.TwoByTwoNearestNeighborUpSampling()(tensor)
    for _ in range(2):
      tensor = activate_and_normalize(conv_3x3(tensor, channels))

  tensor = to_rgb(
      input_tensor=tensor,
      kernel_initializer=kernel_initializer,
      name='final_output')
  if generate_intermediate_outputs:
    outputs.append(tensor)
    return tf.keras.Model(inputs=input_tensor, outputs=outputs, name=name)
  return tf.keras.Model(inputs=input_tensor, outputs=tensor, name=name)