in data_loaders/generate_tfr/generate.py [0:0]
def parse_image(max_res):
    """Build a parser mapping one serialized example to a label plus an image pyramid.

    Args:
        max_res: side length the decoded image is resized to (``max_res x
            max_res``) before the pyramid is built. The pyramid has
            ``int(log2(max_res)) - 1`` levels, and none when that value is <= 0.
            NOTE(review): presumably expected to be a power of two -- confirm.

    Returns:
        A function ``_parse_image(example)`` returning the tuple
        ``(label, level_0, level_1, ...)``.
    """

    def _process_image(image):
        """Centre-crop, resize and cast ``image``, then build its pyramid.

        Level 0 is the resized ``max_res x max_res`` float32 image. Every
        later level applies ``downsample`` to the previous level (presumably
        halving the resolution -- confirm against ``downsample``). Each level
        is passed through ``x_to_uint8`` before being collected.

        Returns:
            A list with one entry per pyramid level, ordered from the
            ``max_res`` level downwards.
        """
        image = tf.cast(
            tf.image.resize_images(
                centre_crop(image), [max_res, max_res], method=_DOWNSAMPLING),
            'float32')
        num_levels = int(np.log2(max_res)) - 1
        pyramid = []
        for level in range(num_levels):
            if level > 0:
                # Later levels are derived from the previous (finer) level.
                image = downsample(image)
            pyramid.append(x_to_uint8(image))
        return pyramid

    def _parse_image(example):
        """Decode one serialized example into ``(label, *pyramid_levels)``.

        Reads ``'image/encoded'`` (a string) and ``'image/class/label'`` (one
        int64) from ``example``. The label is reshaped to a scalar, cast to
        int32 and shifted down by one; a missing label (default -1) therefore
        becomes -2. The image bytes are decoded with ``tf.image.decode_jpeg``
        using ``_NUM_CHANNELS`` channels.
        """
        parsed_features = tf.parse_single_example(example, {
            'image/encoded': tf.FixedLenFeature(
                [], dtype=tf.string, default_value=''),
            'image/class/label': tf.FixedLenFeature(
                [1], dtype=tf.int64, default_value=-1),
        })
        # NOTE(review): the -1 presumably converts 1-based class ids to
        # 0-based -- confirm against the TFRecord writer.
        label = tf.cast(
            tf.reshape(parsed_features['image/class/label'], shape=[]),
            dtype=tf.int32) - 1
        image = tf.image.decode_jpeg(
            parsed_features['image/encoded'], channels=_NUM_CHANNELS)
        return (label,) + tuple(_process_image(image))

    return _parse_image