in data_loaders/generate_tfr/generate.py [0:0]
import os

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# get_tfr_files, get_tfr_file, parse_celeba_image, TFRecordExporter, error,
# _NUM_PARALLEL_FILE_READERS and _NUM_PARALLEL_MAP_CALLS are module-level
# helpers defined elsewhere in this file.


def dump_celebahq(data_dir, tfrecord_dir, max_res, split, write):
    # CelebA-HQ's 30,000 images are split 27,000 / 3,000 between train
    # and validation.
    _NUM_IMAGES = {
        'train': 27000,
        'validation': 3000,
    }
    _NUM_SHARDS = {
        'train': 120,
        'validation': 40,
    }
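    # At these settings a train shard holds 27000 / 120 = 225 images and a
    # validation shard holds 3000 / 40 = 75.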

    resolution_log2 = int(np.log2(max_res))
    if max_res != 2 ** resolution_log2:
        error('Input image resolution must be a power-of-two')
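    # Example: max_res=256 gives resolution_log2=8 and passes the check;
    # max_res=300 (log2 ~ 8.23, truncated to 8, and 2**8 = 256 != 300) fails.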

    with tf.Session() as sess:
        print("Reading data from ", data_dir)
        if split:
            # Read one split from its sharded TFRecord files, interleaving
            # several parallel readers.
            tfr_files = get_tfr_files(data_dir, split, int(np.log2(max_res)))
            files = tf.data.Dataset.list_files(tfr_files)
            dset = files.apply(tf.contrib.data.parallel_interleave(
                tf.data.TFRecordDataset, cycle_length=_NUM_PARALLEL_FILE_READERS))
            transpose = False
        else:
            # Read the single unsplit TFRecord file; the parse function is
            # told to transpose these images.
            tfr_file = get_tfr_file(data_dir, "", int(np.log2(max_res)))
            dset = tf.data.TFRecordDataset(tfr_file, compression_type='')
            transpose = True
        parse_fn = parse_celeba_image(max_res, transpose)
        dset = dset.map(parse_fn, num_parallel_calls=_NUM_PARALLEL_MAP_CALLS)
        dset = dset.prefetch(1)
        # One shared iterator feeds every split below; each sess.run() pulls
        # the next example off the stream.
        iterator = dset.make_one_shot_iterator()
        _attr, *_imgs = iterator.get_next()
        sess.run(tf.global_variables_initializer())

        # With no explicit split, regenerate both; the shared iterator streams
        # the first 3,000 examples into the validation shards and the
        # remaining 27,000 into train.
        splits = [split] if split else ["validation", "train"]
        for split in splits:
            total_imgs = _NUM_IMAGES[split]
            shards = _NUM_SHARDS[split]
            with TFRecordExporter(os.path.join(tfrecord_dir, split),
                                  resolution_log2, total_imgs, shards) as tfr:
                for _ in tqdm(range(total_imgs)):
                    attr, *imgs = sess.run([_attr, *_imgs])
                    if write:
                        tfr.add_image(0, imgs, attr)
                if write:
                    # Sanity check: the exporter recorded exactly the
                    # expected number of images.
                    assert tfr.cur_images == total_imgs, (
                        tfr.cur_images, total_imgs)
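

# A minimal sketch of how this function might be driven from the command
# line; every flag name and default below is an illustrative assumption,
# not this file's actual entry point.
if __name__ == '__main__':
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('--data_dir', required=True,
                        help='directory holding the source TFRecords')
    parser.add_argument('--tfrecord_dir', required=True,
                        help='output directory for the re-sharded records')
    parser.add_argument('--max_res', type=int, default=256,
                        help='image resolution; must be a power of two')
    parser.add_argument('--split', default='',
                        help="'train', 'validation', or '' for both")
    parser.add_argument('--write', action='store_true',
                        help='actually write images (otherwise a dry run)')
    args = parser.parse_args()

    dump_celebahq(args.data_dir, args.tfrecord_dir, args.max_res,
                  args.split, args.write)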