in data_loaders/generate_tfr/generate.py [0:0]
def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):
    """Stream ImageNet TFRecord shards and re-export them via TFRecordExporter.

    Args:
        data_dir: directory containing the source ImageNet TFRecord shards.
        tfrecord_dir: output root; records are written under ``tfrecord_dir/<split>``.
        max_res: target image resolution; must be a power of two.
        split: dataset split, either ``'train'`` or ``'validation'``.
        write: when False, iterate the full dataset without writing records.
    """
    # Fixed per-split image counts and output shard counts.
    images_per_split = {
        'train': 1281167,
        'validation': 50000,
    }
    # One dict serves both roles (input-file shuffle buffer and output shards),
    # mirroring the original _NUM_FILES = _NUM_SHARDS aliasing.
    shards_per_split = {
        'train': 2000,
        'validation': 80,
    }
    resolution_log2 = int(np.log2(max_res))
    if max_res != 2 ** resolution_log2:
        error('Input image resolution must be a power-of-two')
    with tf.Session() as sess:
        # Source shard filename patterns differ between the two splits.
        pattern = ('train-*-of-01024' if split == 'train'
                   else 'validation-*-of-00128')
        file_ds = tf.data.Dataset.list_files(os.path.join(data_dir, pattern))
        file_ds = file_ds.shuffle(buffer_size=shards_per_split[split])
        # Interleave reads across shards, then shuffle individual records.
        record_ds = file_ds.apply(tf.contrib.data.parallel_interleave(
            tf.data.TFRecordDataset, cycle_length=_NUM_PARALLEL_FILE_READERS))
        record_ds = record_ds.shuffle(buffer_size=_SHUFFLE_BUFFER)
        record_ds = record_ds.map(
            parse_image(max_res), num_parallel_calls=_NUM_PARALLEL_MAP_CALLS)
        record_ds = record_ds.prefetch(1)
        iterator = record_ds.make_one_shot_iterator()
        # First tensor is the label; the rest are image tensors
        # (presumably one per resolution level — confirm against parse_image).
        label_op, *image_ops = iterator.get_next()
        sess.run(tf.global_variables_initializer())

        total_imgs = images_per_split[split]
        num_shards = shards_per_split[split]
        out_dir = os.path.join(tfrecord_dir, split)
        with TFRecordExporter(out_dir, resolution_log2, total_imgs, num_shards) as tfr:
            for _ in tqdm(range(total_imgs)):
                label, *imgs = sess.run([label_op, *image_ops])
                if write:
                    tfr.add_image(label, imgs, [])
            # NOTE(review): when write=False the exporter presumably counts
            # zero images, so this assert would fire — confirm intended.
            assert tfr.cur_images == total_imgs, (tfr.cur_images, total_imgs)