in data_loaders/generate_tfr/generate.py [0:0]
def add_image(self, label, imgs, attr):
    """Serialize one image, at every level of detail (LOD), to the TFRecord shards.

    ``imgs[lod]`` is the image at level of detail ``lod``. Its shape must be
    ``(size, size, 3)`` with ``size = 2 ** (self.resolution_log2 - lod)``.
    LOD 0 is therefore the largest image, and each further LOD halves the
    side length.

    For each LOD, one ``tf.train.Example`` is built with these features:

    - ``shape``
    - ``data`` (the raw array bytes)
    - ``label``
    - ``attr``

    That example is written to the writer chosen by
    ``img_to_shard[self.cur_images]``. ``self.cur_images`` is advanced by one
    only after all LODs are written, so every LOD of this image uses the same
    running index for shard selection.

    Args:
        label: Value stored in the ``label`` feature via ``_int64_feature``.
        imgs: Sequence with one array per entry of ``self.tfr_writers``.
        attr: Value stored in the ``attr`` feature via ``_int64_feature``.

    Raises:
        ValueError: If ``imgs`` does not contain exactly one entry per LOD,
            or an entry does not have the expected ``(size, size, 3)`` shape.
    """
    # Explicit raises instead of ``assert``: asserts disappear under
    # ``python -O`` and malformed images would then be written silently.
    if len(imgs) != len(self.tfr_writers):
        raise ValueError(
            'expected %d LOD images, got %d'
            % (len(self.tfr_writers), len(imgs)))
    for lod, (writers, img_to_shard) in enumerate(self.tfr_writers):
        quant = imgs[lod]
        size = 2 ** (self.resolution_log2 - lod)
        expected_shape = (size, size, 3)
        if quant.shape != expected_shape:
            raise ValueError(
                'LOD %d: expected image shape %r, got %r'
                % (lod, expected_shape, quant.shape))
        ex = tf.train.Example(
            features=tf.train.Features(
                feature={
                    'shape': _int64_feature(quant.shape),
                    # tobytes() produces the same C-order bytes as the
                    # ndarray.tostring() alias, which has been deprecated
                    # since NumPy 1.19 and dropped in later releases. Only
                    # the shape is recorded here (no dtype), so readers must
                    # already know the element type.
                    'data': _bytes_feature(quant.tobytes()),
                    'label': _int64_feature(label),
                    'attr': _int64_feature(attr),
                }
            )
        )
        writers[img_to_shard[self.cur_images]].write(ex.SerializeToString())
    self.cur_images += 1