def dump_celebahq()

in data_loaders/generate_tfr/generate.py [0:0]


def dump_celebahq(data_dir, tfrecord_dir, max_res, split, write):
    """Stream CelebA-HQ records out of ``data_dir`` and re-export them sharded.

    Source records are located with ``get_tfr_files`` (when ``split`` is
    truthy) or ``get_tfr_file`` (otherwise), decoded with the function
    returned by ``parse_celeba_image`` and pulled one at a time through a
    TF session. Each pulled record is handed to a ``TFRecordExporter``
    rooted at ``os.path.join(tfrecord_dir, <split name>)``.

    Args:
        data_dir: Location passed to ``get_tfr_file`` / ``get_tfr_files``.
        tfrecord_dir: Parent directory of the per-split output path.
        max_res: Image resolution; must be an exact power of two.
        split: ``'train'`` or ``'validation'`` to export just that split from
            its own set of input files. Any falsy value reads a single input
            file and exports ``'validation'`` followed by ``'train'`` from
            the same iterator, so the first 3000 records become validation
            and the following 27000 become train.
        write: When falsy, records are still read from the session but are
            not added to the exporter, and the final count check is skipped.

    Raises:
        KeyError: If ``split`` is truthy but has no entry in the size tables.
        AssertionError: If ``write`` is set and the exporter's ``cur_images``
            does not equal the expected number of images for the split.

    NOTE(review): a non power-of-two ``max_res`` is reported through the
    module's ``error`` helper; whether that aborts execution depends on that
    helper, which is defined elsewhere -- confirm.
    """
    # Fixed sizes of the two supported splits and the shard count for each.
    image_counts = {'train': 27000, 'validation': 3000}
    shard_counts = {'train': 120, 'validation': 40}

    resolution_log2 = int(np.log2(max_res))
    if 2 ** resolution_log2 != max_res:
        error('Input image resolution must be a power-of-two')

    with tf.Session() as sess:
        print("Reading data from ", data_dir)
        if not split:
            # Single unsplit input file.
            source_file = get_tfr_file(data_dir, "", resolution_log2)
            dset = tf.data.TFRecordDataset(source_file, compression_type='')
            transpose = True
        else:
            # One split spread over several files, read in parallel.
            source_files = get_tfr_files(data_dir, split, resolution_log2)
            file_dset = tf.data.Dataset.list_files(source_files)
            interleave = tf.contrib.data.parallel_interleave(
                tf.data.TFRecordDataset,
                cycle_length=_NUM_PARALLEL_FILE_READERS)
            dset = file_dset.apply(interleave)
            transpose = False

        decode = parse_celeba_image(max_res, transpose)
        dset = dset.map(decode, num_parallel_calls=_NUM_PARALLEL_MAP_CALLS)
        dset = dset.prefetch(1)
        next_element = dset.make_one_shot_iterator().get_next()
        # First component is handled as the attribute tensor, the remaining
        # ones as image tensors (presumably one per resolution -- confirm
        # against parse_celeba_image).
        attr_op, *img_ops = next_element
        fetches = [attr_op, *img_ops]
        sess.run(tf.global_variables_initializer())

        # A single iterator feeds every split below, so with an unsplit
        # input the order of this list decides which records go where.
        split_names = [split] if split else ["validation", "train"]
        for split_name in split_names:
            expected = image_counts[split_name]
            num_shards = shard_counts[split_name]
            out_path = os.path.join(tfrecord_dir, split_name)
            with TFRecordExporter(out_path, resolution_log2, expected, num_shards) as tfr:
                for _ in tqdm(range(expected)):
                    attr, *imgs = sess.run(fetches)
                    if not write:
                        # Dry run: consume the record without exporting it.
                        continue
                    tfr.add_image(0, imgs, attr)
                if write:
                    assert tfr.cur_images == expected, (
                        tfr.cur_images, expected)