def parse_image()

in data_loaders/generate_tfr/generate.py [0:0]


def parse_image(max_res):
    """Return a parser mapping a serialized tf.Example to ``(label, *levels)``.

    Args:
        max_res: side length, in pixels, that the centre-cropped image is
            resized to before building the pyramid. The number of pyramid
            levels is ``int(log2(max_res)) - 1`` (so ``max_res < 4`` yields
            no levels). Presumably expected to be a power of two -- TODO
            confirm with callers.

    Returns:
        A function taking one serialized example. It returns a tuple whose
        first element is the int32 scalar label (stored label minus one).
        The remaining elements are the quantized pyramid levels, from
        ``max_res`` downwards.
    """

    def _process_image(img):
        # Square the image via the centre crop, then scale it to the top
        # pyramid resolution.
        resized = tf.image.resize_images(
            centre_crop(img), [max_res, max_res], method=_DOWNSAMPLING)
        current = tf.cast(resized, 'float32')
        num_levels = int(np.log2(max_res)) - 1

        pyramid = []
        for level in range(num_levels):
            # Level 0 is the full-resolution image; every later level is
            # derived from the previous one by `downsample`.
            if level > 0:
                current = downsample(current)
            pyramid.append(x_to_uint8(current))
        return pyramid

    def _parse_image(example):
        spec = {
            'image/encoded': tf.FixedLenFeature(
                [], dtype=tf.string, default_value=''),
            'image/class/label': tf.FixedLenFeature(
                [1], dtype=tf.int64, default_value=-1),
        }
        fields = tf.parse_single_example(example, spec)
        encoded = fields['image/encoded']
        raw_label = fields['image/class/label']

        # The label is stored with shape [1]. Flatten it to a scalar and
        # shift it down by one (presumably 1-based class ids -> 0-based;
        # TODO confirm against the dataset writer).
        label = tf.cast(tf.reshape(raw_label, shape=[]), dtype=tf.int32) - 1
        decoded = tf.image.decode_jpeg(encoded, channels=_NUM_CHANNELS)
        return (label,) + tuple(_process_image(decoded))

    return _parse_image