in data_loaders/generate_tfr/generate.py [0:0]
def parse_celeba_image(max_res, transpose=False):
    """Build a parser that maps one serialized CelebA record to an image pyramid.

    Args:
        max_res: Resolution used to decide how many pyramid levels to emit;
            ``int(np.log2(max_res)) - 1`` levels are produced. Presumably a
            power of two -- TODO confirm with callers.
        transpose: If true, the decoded image axes are permuted with
            ``(1, 2, 0)`` (CHW -> HWC) before the pyramid is built.

    Returns:
        A function taking one serialized ``tf.train.Example`` string and
        returning the tuple ``(attr, level_0, level_1, ...)``, where ``attr``
        is the 40-entry int64 attribute vector and each ``level_k`` is the
        ``x_to_uint8`` output for pyramid level k.
    """

    def _process_image(img):
        """Return a list of ``int(np.log2(max_res)) - 1`` quantized levels.

        Level 0 is ``img`` itself (cast to float32); each later level is
        ``downsample`` applied to the previous float level (presumably halving
        the spatial size -- confirm against ``downsample``). Every level is
        passed through ``x_to_uint8`` before being collected.
        """
        current = tf.cast(img, 'float32')
        num_levels = int(np.log2(max_res)) - 1
        pyramid = []
        for level in range(num_levels):
            # The first level is kept at full resolution; later ones shrink
            # progressively from the previous float image, not the quantized one.
            if level > 0:
                current = downsample(current)
            pyramid.append(x_to_uint8(current))
        return pyramid

    def _parse_image(example):
        """Decode one serialized example into ``(attr, *pyramid_levels)``."""
        feature_spec = {
            'shape': tf.FixedLenFeature([3], tf.int64),
            'data': tf.FixedLenFeature([], tf.string),
            'attr': tf.FixedLenFeature([40], tf.int64),
        }
        record = tf.parse_single_example(example, features=feature_spec)

        # Pixel bytes come back as a flat uint8 vector; restore the stored
        # 3-element shape recorded alongside them.
        pixels = tf.decode_raw(record['data'], tf.uint8)
        img = tf.reshape(pixels, record['shape'])
        if transpose:
            img = tf.transpose(img, (1, 2, 0))  # CHW -> HWC
        return (record['attr'], *_process_image(img))

    return _parse_image