# data_loaders/generate_tfr/generate.py
def __init__(self, tfrecord_dir, resolution_log2, expected_images, shards, print_progress=True, progress_interval=10):
    """Set up sharded TFRecord writers for a multi-resolution image dataset.

    Creates `tfrecord_dir` if needed, then for every level of detail
    (resolution_log2 down to 2) builds `shards` uncompressed TFRecord
    writers and a random image-index -> shard assignment.

    Args:
        tfrecord_dir: Output directory; its basename is also used as the
            filename prefix for every .tfrecords file.
        resolution_log2: log2 of the full image resolution; one writer set
            is created per LOD in range(resolution_log2 - 1).
        expected_images: Total number of images that will be added.
        shards: Number of shard files per LOD.
        print_progress: If True, print status messages.
        progress_interval: Reporting interval (stored for later use).
    """
    self.tfrecord_dir = tfrecord_dir
    # Record files are named "<dir>/<dirname>-r..-s-....tfrecords".
    self.tfr_prefix = os.path.join(
        self.tfrecord_dir, os.path.basename(self.tfrecord_dir))
    self.resolution_log2 = resolution_log2
    self.expected_images = expected_images
    self.cur_images = 0          # images written so far
    self.shape = None            # set when the first image arrives
    self.tfr_writers = []        # one (writers, img_to_shard) pair per LOD
    self.print_progress = print_progress
    self.progress_interval = progress_interval
    if self.print_progress:
        print('Creating dataset "%s"' % tfrecord_dir)
    if not os.path.isdir(self.tfrecord_dir):
        os.makedirs(self.tfrecord_dir)
    assert os.path.isdir(self.tfrecord_dir)
    # TF1 API: explicit no-compression options shared by all writers.
    tfr_opt = tf.python_io.TFRecordOptions(
        tf.python_io.TFRecordCompressionType.NONE)
    for lod in range(self.resolution_log2 - 1):
        # Randomly partition image indices into `shards` groups.
        p_shard = np.array_split(
            np.random.permutation(expected_images), shards)
        # BUGFIX: np.int was removed in NumPy 1.24; the builtin int is
        # exactly what that alias meant.
        img_to_shard = np.zeros(expected_images, dtype=int)
        writers = []
        for shard in range(shards):
            img_to_shard[p_shard[shard]] = shard
            tfr_file = self.tfr_prefix + \
                '-r%02d-s-%04d-of-%04d.tfrecords' % (
                    self.resolution_log2 - lod, shard, shards)
            writers.append(tf.python_io.TFRecordWriter(tfr_file, tfr_opt))
        counts = np.unique(img_to_shard, return_counts=True)[1]
        # Every shard must have received at least one image index.
        assert len(counts) == shards
        print("Smallest and largest shards have size",
              np.min(counts), np.max(counts))
        self.tfr_writers.append((writers, img_to_shard))