in data_loaders/get_data.py [0:0]
def input_fn(tfr_file, shards, rank, pmap, fmap, n_batch, resolution, rnd_crop, is_training):
    """Build a TF1 one-shot iterator over batched examples from TFRecord files.

    Pipeline: list files -> (optionally) shard across workers -> (training only)
    shuffle files and records -> interleave TFRecord reads -> repeat forever ->
    parse/crop -> batch -> prefetch.

    Args:
        tfr_file: Glob pattern for the TFRecord file(s) to read.
        shards: Total number of workers the file list is sharded across.
        rank: This worker's shard index in [0, shards).
        pmap: Number of parallel calls for the record-parsing map.
        fmap: Cycle length (number of files read concurrently) for interleave.
        n_batch: Batch size.
        resolution: Image resolution passed through to `parse_tfrecord_tf`.
        rnd_crop: Random-crop flag passed through to `parse_tfrecord_tf`.
        is_training: If True, enable file- and record-level shuffling.

    Returns:
        A TF1 one-shot iterator yielding batches of parsed records.
        NOTE(review): uses `tf.contrib.data.parallel_interleave` and
        `make_one_shot_iterator`, so this requires TensorFlow 1.x.
    """
    files = tf.data.Dataset.list_files(tfr_file)
    if ('lsun' not in tfr_file) or is_training:
        # For 'lsun' validation there is only one shard and each machine goes
        # over the full dataset; otherwise each worker reads a subset of files.
        files = files.shard(shards, rank)
    if is_training:
        # Shuffle the order of files within this worker's shard.
        files = files.shuffle(buffer_size=_FILES_SHUFFLE)
    # Read several TFRecord files concurrently, interleaving their records.
    dset = files.apply(tf.contrib.data.parallel_interleave(
        tf.data.TFRecordDataset, cycle_length=fmap))
    if is_training:
        # Record-level shuffle; buffer scales with batch size.
        dset = dset.shuffle(buffer_size=n_batch * _SHUFFLE_FACTOR)
    # Repeat indefinitely (for eval too — callers bound the number of steps).
    dset = dset.repeat()
    dset = dset.map(lambda x: parse_tfrecord_tf(
        x, resolution, rnd_crop), num_parallel_calls=pmap)
    dset = dset.batch(n_batch)
    # Overlap producing the next batch with consuming the current one.
    dset = dset.prefetch(1)
    itr = dset.make_one_shot_iterator()
    return itr