def build_model()

in models/hific/model.py [0:0]


  def build_model(self, input_image, input_images_d_steps=None):
    """Build model and losses and train_ops.

    Args:
      input_image: A single (B, H, W, C) image, in [0, 255]
      input_images_d_steps: If training a discriminator, this is expected to
        be a (B*N, H, W, C) stack of images, where N=number of sub batches.
        See build_input. It is split into `self.num_steps_disc` sub-batches,
        one per discriminator step. Must be None in evaluation mode.

    Returns:
      In evaluation mode, a tuple (reconstruction, bitstrings), where
      reconstruction is clipped to [0, 255]. Otherwise None; when
      `self.training` is set, the generator train op is stored in
      `self._train_op`.

    Raises:
      ValueError: If `input_images_d_steps` is given in evaluation mode.
    """
    # Validate arguments before touching the tensor. Otherwise, in evaluation
    # mode, `set_shape` / `tf.split` below could fail first with a less
    # specific error and hide this message. Use an explicit `is not None`
    # check: a tf.Tensor cannot be used as a Python bool in graph mode.
    if self.evaluation and input_images_d_steps is not None:
      raise ValueError("Only need input_image for eval! {}".format(
          input_images_d_steps))

    if input_images_d_steps is None:
      input_images_d_steps = []
    else:
      input_images_d_steps.set_shape(
          self.input_spec["input_images_d_steps"].shape)
      # One sub-batch per discriminator step.
      input_images_d_steps = tf.split(input_images_d_steps, self.num_steps_disc)

    input_image.set_shape(self.input_spec["input_image"].shape)

    self.build_transforms()

    # The LPIPS perceptual loss is only set up when training.
    if self.training:
      self._lpips_loss = LPIPSLoss(self._lpips_weight_path)
      self._lpips_loss_weight = self._config.loss_config.lpips_weight

    if self._setup_discriminator:
      self.build_discriminator()

    # Global step needs to be created for train, val and eval.
    global_step = tf.train.get_or_create_global_step()

    # Compute output graph (the generator path, fed with `input_image`).
    nodes_gen, bpp_pair, bitstrings = (
        self._compute_compression_graph(input_image))

    if self.evaluation:
      tf.logging.info("Evaluation mode: build_model done.")
      reconstruction = tf.clip_by_value(nodes_gen.reconstruction, 0, 255.)
      return reconstruction, bitstrings

    # Build a separate compression graph for every discriminator sub-batch.
    # Summaries are only created for the generator path above.
    nodes_disc = []  # list of Nodes, one for every sub-batch of disc
    for i, sub_batch in enumerate(input_images_d_steps):
      with tf.name_scope("sub_batch_disc_{}".format(i)):
        nodes, _, _ = self._compute_compression_graph(
            sub_batch, create_summaries=False)
        nodes_disc.append(nodes)

    if self._auto_encoder_ckpt_path:
      self._prepare_auto_encoder_restore()

    # The following is inspired by compare_gan/gans/modular_gan.py:
    # Say we want to train the discriminator for D steps for every 1 step
    # of generator training. We follow the unroll_graph=True option:
    # the features given to the model_fn are split into D + 1 sub-batches.
    # The code then creates D train_ops for the discriminator, each feeding
    # a different sub-batch of features into the discriminator.
    # The train_op for the generator then depends on all these D train_ops
    # and uses the last (D+1 th) sub-batch.
    # Note that the graph is only created once.

    d_train_ops = []
    if self._setup_discriminator:
      tf.logging.info("Unrolling graph for discriminator")
      self._global_step_disc = tf.get_variable(
          "global_step_disc", [], dtype=global_step.dtype, trainable=False)
      with tf.name_scope("steps"):
        tf.summary.scalar("global_step", global_step)
        tf.summary.scalar("global_step_disc", self._global_step_disc)

      # Create optimizer once, and then call minimize on it multiple times
      # within self._train_discriminator.
      disc_optimizer = self._make_discriminator_optimizer(
          self._global_step_disc)
      for i, nodes in enumerate(nodes_disc):
        with tf.name_scope("train_disc_{}".format(i + 1)):
          # Each discriminator step depends on all previously created steps,
          # so the steps run sequentially. The control inputs are captured
          # when the context is entered, before the append below.
          with tf.control_dependencies(d_train_ops):
            d_train_ops.append(
                self._train_discriminator(
                    nodes, disc_optimizer, create_summaries=(i == 0)))

    # Depend on `d_train_ops`, which ensures all `self.num_steps_disc` steps of
    # the discriminator will run before the generator training op.
    with tf.control_dependencies(d_train_ops):
      train_op = self._train_generator(nodes_gen, bpp_pair, global_step)

    if self.training:
      self._train_op = train_op