data_loaders/generate_tfr/generate.py (263 lines of code) (raw):

# Copyright 2017 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """ Generate CelebA-HQ and Imagenet datasets For CelebA-HQ, first create original tfrecords file using https://github.com/tkarras/progressive_growing_of_gans/blob/master/dataset_tool.py For Imagenet, first create original tfrecords file using https://github.com/tensorflow/models/blob/master/research/inception/inception/data/build_imagenet_data.py Then, use this script to get our tfr file from those records. """ from __future__ import absolute_import from __future__ import division from __future__ import print_function import os import tensorflow as tf import numpy as np from tqdm import tqdm from typing import Iterable _NUM_CHANNELS = 3 _NUM_PARALLEL_FILE_READERS = 32 _NUM_PARALLEL_MAP_CALLS = 32 _DOWNSAMPLING = tf.image.ResizeMethod.BILINEAR _SHUFFLE_BUFFER = 1024 def _int64_feature(value): if not isinstance(value, Iterable): value = [value] return tf.train.Feature(int64_list=tf.train.Int64List(value=value)) def _bytes_feature(value): return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) def error(msg): print('Error: ' + msg) exit(1) def x_to_uint8(x): return tf.cast(tf.clip_by_value(tf.floor(x), 0, 255), 'uint8') def centre_crop(img): h, w = tf.shape(img)[0], tf.shape(img)[1] min_side = tf.minimum(h, w) h_offset = (h - min_side) // 2 w_offset = (w - min_side) // 2 return tf.image.crop_to_bounding_box(img, h_offset, w_offset, min_side, min_side) def downsample(img): return (img[0::2, 0::2, :] + img[0::2, 1::2, :] + img[1::2, 0::2, :] + img[1::2, 1::2, :]) * 0.25 def parse_image(max_res): def _process_image(img): img = centre_crop(img) img = tf.image.resize_images( img, [max_res, max_res], method=_DOWNSAMPLING) img = tf.cast(img, 'float32') resolution_log2 = int(np.log2(max_res)) q_imgs = [] for lod in range(resolution_log2 - 1): if lod: img = downsample(img) quant = x_to_uint8(img) q_imgs.append(quant) return q_imgs def _parse_image(example): feature_map = { 'image/encoded': tf.FixedLenFeature([], dtype=tf.string, default_value=''), 'image/class/label': tf.FixedLenFeature([1], dtype=tf.int64, default_value=-1) } features = tf.parse_single_example(example, feature_map) img, label = features['image/encoded'], features['image/class/label'] label = tf.cast(tf.reshape(label, shape=[]), dtype=tf.int32) - 1 img = tf.image.decode_jpeg(img, channels=_NUM_CHANNELS) imgs = _process_image(img) parsed = (label, *imgs) return parsed return _parse_image def parse_celeba_image(max_res, transpose=False): def _process_image(img): img = tf.cast(img, 'float32') resolution_log2 = int(np.log2(max_res)) q_imgs = [] for lod in range(resolution_log2 - 1): if lod: img = downsample(img) quant = x_to_uint8(img) q_imgs.append(quant) return q_imgs def _parse_image(example): features = tf.parse_single_example(example, features={ 'shape': tf.FixedLenFeature([3], tf.int64), 'data': tf.FixedLenFeature([], tf.string), 'attr': tf.FixedLenFeature([40], tf.int64)}) shape = features['shape'] data = features['data'] attr = features['attr'] data = tf.decode_raw(data, tf.uint8) img = tf.reshape(data, shape) if transpose: img = tf.transpose(img, (1, 2, 0)) # CHW -> HWC imgs = _process_image(img) parsed = (attr, *imgs) return parsed return _parse_image def get_tfr_files(data_dir, split, lgres): data_dir = os.path.join(data_dir, split) tfr_prefix = os.path.join(data_dir, os.path.basename(data_dir)) tfr_files = tfr_prefix + '-r%02d-s-*-of-*.tfrecords' % (lgres) return tfr_files def get_tfr_file(data_dir, split, lgres): if split: data_dir = os.path.join(data_dir, split) tfr_prefix = os.path.join(data_dir, os.path.basename(data_dir)) tfr_file = tfr_prefix + '-r%02d.tfrecords' % (lgres) return tfr_file def dump_celebahq(data_dir, tfrecord_dir, max_res, split, write): _NUM_IMAGES = { 'train': 27000, 'validation': 3000, } _NUM_SHARDS = { 'train': 120, 'validation': 40, } resolution_log2 = int(np.log2(max_res)) if max_res != 2 ** resolution_log2: error('Input image resolution must be a power-of-two') with tf.Session() as sess: print("Reading data from ", data_dir) if split: tfr_files = get_tfr_files(data_dir, split, int(np.log2(max_res))) files = tf.data.Dataset.list_files(tfr_files) dset = files.apply(tf.contrib.data.parallel_interleave( tf.data.TFRecordDataset, cycle_length=_NUM_PARALLEL_FILE_READERS)) transpose = False else: tfr_file = get_tfr_file(data_dir, "", int(np.log2(max_res))) dset = tf.data.TFRecordDataset(tfr_file, compression_type='') transpose = True parse_fn = parse_celeba_image(max_res, transpose) dset = dset.map(parse_fn, num_parallel_calls=_NUM_PARALLEL_MAP_CALLS) dset = dset.prefetch(1) iterator = dset.make_one_shot_iterator() _attr, *_imgs = iterator.get_next() sess.run(tf.global_variables_initializer()) splits = [split] if split else ["validation", "train"] for split in splits: total_imgs = _NUM_IMAGES[split] shards = _NUM_SHARDS[split] with TFRecordExporter(os.path.join(tfrecord_dir, split), resolution_log2, total_imgs, shards) as tfr: for _ in tqdm(range(total_imgs)): attr, *imgs = sess.run([_attr, *_imgs]) if write: tfr.add_image(0, imgs, attr) if write: assert tfr.cur_images == total_imgs, ( tfr.cur_images, total_imgs) #attr, *imgs = sess.run([_attr, *_imgs]) def dump_imagenet(data_dir, tfrecord_dir, max_res, split, write): _NUM_IMAGES = { 'train': 1281167, 'validation': 50000, } _NUM_FILES = _NUM_SHARDS = { 'train': 2000, 'validation': 80, } resolution_log2 = int(np.log2(max_res)) if max_res != 2 ** resolution_log2: error('Input image resolution must be a power-of-two') with tf.Session() as sess: is_training = (split == 'train') if is_training: files = tf.data.Dataset.list_files( os.path.join(data_dir, 'train-*-of-01024')) else: files = tf.data.Dataset.list_files( os.path.join(data_dir, 'validation-*-of-00128')) files = files.shuffle(buffer_size=_NUM_FILES[split]) dataset = files.apply(tf.contrib.data.parallel_interleave( tf.data.TFRecordDataset, cycle_length=_NUM_PARALLEL_FILE_READERS)) dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER) parse_fn = parse_image(max_res) dataset = dataset.map( parse_fn, num_parallel_calls=_NUM_PARALLEL_MAP_CALLS) dataset = dataset.prefetch(1) iterator = dataset.make_one_shot_iterator() _label, *_imgs = iterator.get_next() sess.run(tf.global_variables_initializer()) total_imgs = _NUM_IMAGES[split] shards = _NUM_SHARDS[split] tfrecord_dir = os.path.join(tfrecord_dir, split) with TFRecordExporter(tfrecord_dir, resolution_log2, total_imgs, shards) as tfr: for _ in tqdm(range(total_imgs)): label, *imgs = sess.run([_label, *_imgs]) if write: tfr.add_image(label, imgs, []) assert tfr.cur_images == total_imgs, (tfr.cur_images, total_imgs) #label, *imgs = sess.run([_label, *_imgs]) class TFRecordExporter: def __init__(self, tfrecord_dir, resolution_log2, expected_images, shards, print_progress=True, progress_interval=10): self.tfrecord_dir = tfrecord_dir self.tfr_prefix = os.path.join( self.tfrecord_dir, os.path.basename(self.tfrecord_dir)) self.resolution_log2 = resolution_log2 self.expected_images = expected_images self.cur_images = 0 self.shape = None self.tfr_writers = [] self.print_progress = print_progress self.progress_interval = progress_interval if self.print_progress: print('Creating dataset "%s"' % tfrecord_dir) if not os.path.isdir(self.tfrecord_dir): os.makedirs(self.tfrecord_dir) assert (os.path.isdir(self.tfrecord_dir)) tfr_opt = tf.python_io.TFRecordOptions( tf.python_io.TFRecordCompressionType.NONE) for lod in range(self.resolution_log2 - 1): p_shard = np.array_split( np.random.permutation(expected_images), shards) img_to_shard = np.zeros(expected_images, dtype=np.int) writers = [] for shard in range(shards): img_to_shard[p_shard[shard]] = shard tfr_file = self.tfr_prefix + \ '-r%02d-s-%04d-of-%04d.tfrecords' % ( self.resolution_log2 - lod, shard, shards) writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt)) #print(np.unique(img_to_shard, return_counts=True)) counts = np.unique(img_to_shard, return_counts=True)[1] assert len(counts) == shards print("Smallest and largest shards have size", np.min(counts), np.max(counts)) self.tfr_writers.append((writers, img_to_shard)) def close(self): if self.print_progress: print('%-40s\r' % 'Flushing data...', end='', flush=True) for (writers, _) in self.tfr_writers: for writer in writers: writer.close() self.tfr_writers = [] if self.print_progress: print('%-40s\r' % '', end='', flush=True) print('Added %d images.' % self.cur_images) def add_image(self, label, imgs, attr): assert len(imgs) == len(self.tfr_writers) # if self.print_progress and self.cur_images % self.progress_interval == 0: # print('%d / %d\r' % (self.cur_images, self.expected_images), end='', flush=True) for lod, (writers, img_to_shard) in enumerate(self.tfr_writers): quant = imgs[lod] size = 2 ** (self.resolution_log2 - lod) assert quant.shape == (size, size, 3), quant.shape ex = tf.train.Example( features=tf.train.Features( feature={ 'shape': _int64_feature(quant.shape), 'data': _bytes_feature(quant.tostring()), 'label': _int64_feature(label), 'attr': _int64_feature(attr) } ) ) writers[img_to_shard[self.cur_images]].write( ex.SerializeToString()) self.cur_images += 1 # def add_labels(self, labels): # if self.print_progress: # print('%-40s\r' % 'Saving labels...', end='', flush=True) # assert labels.shape[0] == self.cur_images # with open(self.tfr_prefix + '-rxx.labels', 'wb') as f: # np.save(f, labels.astype(np.float32)) def __enter__(self): return self def __exit__(self, *args): self.close() if __name__ == "__main__": import argparse parser = argparse.ArgumentParser() parser.add_argument("--data_dir", type=str, required=True) parser.add_argument("--max_res", type=int, default=256, help="Image size") parser.add_argument("--tfrecord_dir", type=str, required=True, help='place to dump') parser.add_argument("--write", action='store_true', help="Whether to write") hps = parser.parse_args() # So error if typo #dump_imagenet(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'validation', hps.write) #dump_imagenet(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'train', hps.write) dump_celebahq(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'validation', hps.write) dump_celebahq(hps.data_dir, hps.tfrecord_dir, hps.max_res, 'train', hps.write)