def dump_imagenet()

in data_loaders/generate_tfr/generate.py [0:0]
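
Converts the official ImageNet TFRecord shards of one split into the multi-resolution TFRecord format written by TFRecordExporter: every example is streamed through a tf.data pipeline, decoded into a label plus image tensors, and re-exported under tfrecord_dir/&lt;split&gt;. The function relies on module-level names defined elsewhere in generate.py; a minimal sketch of what it expects in scope is shown below (the constant values are illustrative placeholders, not the repository's actual settings).

import os

import numpy as np
import tensorflow as tf
from tqdm import tqdm

# TFRecordExporter and parse_image are helpers defined elsewhere in generate.py.

_NUM_PARALLEL_FILE_READERS = 16   # placeholder value
_NUM_PARALLEL_MAP_CALLS = 16      # placeholder value
_SHUFFLE_BUFFER = 10000           # placeholder value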


def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write):
    _NUM_IMAGES = {
        'train': 1281167,
        'validation': 50000,
    }

    _NUM_FILES = _NUM_SHARDS = {
        'train': 2000,
        'validation': 80,
    }
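    # max_res must be a power of two so resolution_log2 is exact; error() is
    # assumed to be a helper defined elsewhere in generate.py that aborts with
    # the given message.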
    resolution_log2 = int(np.log2(max_res))
    if max_res != 2 ** resolution_log2:
        error('Input image resolution must be a power-of-two')

    with tf.Session() as sess:
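        # List the original ImageNet TFRecord shards for this split
        # (1024 files for train, 128 for validation, per the glob patterns below).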
        is_training = (split == 'train')
        if is_training:
            files = tf.data.Dataset.list_files(
                os.path.join(data_dir, 'train-*-of-01024'))
        else:
            files = tf.data.Dataset.list_files(
                os.path.join(data_dir, 'validation-*-of-00128'))

        files = files.shuffle(buffer_size=_NUM_FILES[split])

        dataset = files.apply(tf.contrib.data.parallel_interleave(
            tf.data.TFRecordDataset, cycle_length=_NUM_PARALLEL_FILE_READERS))

        dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
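        # parse_image(max_res), defined elsewhere in generate.py, builds the
        # per-example parse fn: each record becomes a label plus the image
        # tensors that are unpacked and re-exported below.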
        parse_fn = parse_image(max_res)
        dataset = dataset.map(
            parse_fn, num_parallel_calls=_NUM_PARALLEL_MAP_CALLS)
        dataset = dataset.prefetch(1)
        iterator = dataset.make_one_shot_iterator()
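        # One-shot iterator: the pipeline is walked exactly once, one example per sess.run.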

        _label, *_imgs = iterator.get_next()

        sess.run(tf.global_variables_initializer())

        total_imgs = _NUM_IMAGES[split]
        shards = _NUM_SHARDS[split]
        tfrecord_dir = os.path.join(tfrecord_dir, split)
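        # Spread the converted examples over the configured number of shards
        # inside the per-split output directory.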
        with TFRecordExporter(tfrecord_dir, resolution_log2, total_imgs, shards) as tfr:
            for _ in tqdm(range(total_imgs)):
                label, *imgs = sess.run([_label, *_imgs])
                if write:
                    tfr.add_image(label, imgs, [])
            assert tfr.cur_images == total_imgs, (tfr.cur_images, total_imgs)
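
A hypothetical invocation (not part of generate.py), assuming the original ImageNet shards live under /data/imagenet-tfr and both splits should be converted at a maximum resolution of 256:

if __name__ == '__main__':
    # Hypothetical driver; the repository's actual CLI wiring may differ.
    for split in ('train', 'validation'):
        dump_imagenet(
            data_dir='/data/imagenet-tfr',           # holds train-*-of-01024 / validation-*-of-00128
            tfrecord_dir='/data/imagenet-multires',  # output root; a per-split subdirectory is created
            max_res=256,                             # must be a power of two
            split=split,
            write=True,                              # when False the loop only decodes examples without writing them
        )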