in models/hific/model.py [0:0]
def build_input(self,
                batch_size,
                crop_size,
                images_glob=None,
                tfds_arguments: TFDSArguments = None):
  """Build input dataset."""
  if not (images_glob or tfds_arguments):
    raise ValueError("Need images_glob or tfds_arguments!")

  if self._setup_discriminator:
    # Unroll dataset for GAN training. If we unroll for N steps,
    # we want to fetch (N+1) batches for every step, where 1 batch
    # will be used for G training, and the remaining N batches for D training.
    batch_size *= (self._num_steps_disc + 1)

  if self._setup_discriminator:
    # Split the (N+1) sub-batches into the two arguments expected by build_model.
    def _batch_to_dict(batch):
      num_sub_batches = self._num_steps_disc + 1
      sub_batch_size = batch_size // num_sub_batches
      splits = [sub_batch_size, sub_batch_size * self._num_steps_disc]
      input_image, input_images_d_steps = tf.split(batch, splits)
      return dict(input_image=input_image,
                  input_images_d_steps=input_images_d_steps)
  else:
    def _batch_to_dict(batch):
      return dict(input_image=batch)

  dataset = self._get_dataset(batch_size, crop_size,
                              images_glob, tfds_arguments)
  return dataset.map(_batch_to_dict)
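
# --- Illustrative sketch (not part of model.py): how the sub-batch logic in
# build_input behaves for one concrete, assumed configuration. The numbers
# (2 discriminator steps, batch size 8, 256x256 crops) are hypothetical; only
# the arithmetic and the tf.split call mirror _batch_to_dict above.
import tensorflow as tf

num_steps_disc = 2            # stands in for self._num_steps_disc
requested_batch_size = 8      # batch_size passed to build_input
crop_size = 256

# build_input enlarges the batch: 8 * (2 + 1) = 24 images per dataset element.
batch_size = requested_batch_size * (num_steps_disc + 1)
batch = tf.zeros([batch_size, crop_size, crop_size, 3])  # stand-in for one batched element

# Same split as _batch_to_dict: 8 images for the G step, 16 for the 2 D steps.
sub_batch_size = batch_size // (num_steps_disc + 1)
splits = [sub_batch_size, sub_batch_size * num_steps_disc]
input_image, input_images_d_steps = tf.split(batch, splits)
assert input_image.shape[0] == 8
assert input_images_d_steps.shape[0] == 16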