def build_input()

in models/hific/model.py [0:0]


  def build_input(self,
                  batch_size,
                  crop_size,
                  images_glob=None,
                  tfds_arguments: TFDSArguments = None):
    """Builds the input dataset, mapping each batch to a dict of model inputs.

    Args:
      batch_size: Number of images in each batch handed to the generator.
        When the discriminator is set up, the fetched batch is enlarged to
        hold the extra images needed for the discriminator steps.
      crop_size: Forwarded unchanged to `self._get_dataset`.
      images_glob: Optional glob of input images, forwarded to
        `self._get_dataset`.
      tfds_arguments: Optional TFDS arguments, forwarded to
        `self._get_dataset`.

    Returns:
      The dataset produced by `self._get_dataset`, mapped so that every
      element is a dict with key `input_image`. When the discriminator is set
      up, each dict additionally carries `input_images_d_steps`.

    Raises:
      ValueError: If neither `images_glob` nor `tfds_arguments` is truthy.
    """
    if not images_glob and not tfds_arguments:
      raise ValueError("Need images_glob or tfds_arguments!")

    if self._setup_discriminator:
      # GAN training unrolls N discriminator steps per generator step, so a
      # single fetch must contain N+1 batches: one for the G update and N for
      # the D updates.
      num_sub_batches = self._num_steps_disc + 1
      batch_size *= num_sub_batches
      sub_batch_size = batch_size // num_sub_batches
      # First chunk feeds G; the remaining N sub-batches feed the D steps.
      split_sizes = [sub_batch_size, sub_batch_size * self._num_steps_disc]

      def _to_inputs(batch):
        g_images, d_images = tf.split(batch, split_sizes)
        return {"input_image": g_images, "input_images_d_steps": d_images}
    else:
      def _to_inputs(batch):
        return {"input_image": batch}

    dataset = self._get_dataset(batch_size, crop_size,
                                images_glob, tfds_arguments)
    return dataset.map(_to_inputs)