def parse_celeba_image()

in data_loaders/generate_tfr/generate.py [0:0]


def parse_celeba_image(max_res, transpose=False):
    """Build a parser that decodes one serialized example into attrs + image levels.

    The returned callable takes a serialized ``tf.train.Example`` string
    tensor, presumably one CelebA record (name-based; confirm against the writer).
    It returns a tuple ``(attr, level_0, level_1, ...)``:

    * ``attr`` is the 40-element int64 ``'attr'`` feature, passed through as-is.
    * ``level_0`` is the decoded image, cast to float32 and then quantized
      with ``x_to_uint8``.
    * Each later level is the previous float image passed through
      ``downsample`` before being quantized. ``downsample`` presumably halves
      the spatial size; verify against its definition.

    Args:
        max_res: Controls how many levels are produced. The count is
            ``int(np.log2(max_res)) - 1``, e.g. 7 levels for ``max_res=256``.
        transpose: If true, the raw image is treated as CHW and transposed
            to HWC before building the levels.

    Returns:
        The ``_parse_image`` closure, suitable for a ``Dataset.map`` call
        (TF1-style graph API).
    """

    def _build_pyramid(image):
        # Keep a float copy so every downsample step works on unquantized data;
        # only the emitted level is converted back via x_to_uint8.
        current = tf.cast(image, 'float32')
        num_levels = int(np.log2(max_res)) - 1
        pyramid = []
        for level in range(num_levels):
            if level > 0:
                current = downsample(current)
            pyramid.append(x_to_uint8(current))
        return pyramid

    def _parse_image(example):
        # 'shape' holds 3 dims describing how the flat 'data' bytes are laid out.
        parsed_features = tf.parse_single_example(example, features={
            'shape': tf.FixedLenFeature([3], tf.int64),
            'data': tf.FixedLenFeature([], tf.string),
            'attr': tf.FixedLenFeature([40], tf.int64)})
        raw_pixels = tf.decode_raw(parsed_features['data'], tf.uint8)
        image = tf.reshape(raw_pixels, parsed_features['shape'])
        if transpose:
            image = tf.transpose(image, (1, 2, 0))  # CHW -> HWC
        return (parsed_features['attr'],) + tuple(_build_pyramid(image))

    return _parse_image