in models/hific/evaluate.py [0:0]
def eval_trained_model(config_name,
                       ckpt_dir,
                       out_dir,
                       images_glob,
                       tfds_arguments: helpers.TFDSArguments,
                       max_images=None):
    """Run a trained HiFiC model over a dataset and report per-image metrics.

    For every image the input and the model output are fetched from the
    session, PSNR and the bits-per-pixel of the bitstring are printed, and
    both images are written to `out_dir` as PNG files. After the loop the
    mean of each metric over all processed images is printed.

    Args:
      config_name: Name handed to `configs.get_config` to obtain the config.
      ckpt_dir: Passed to `hific.restore_trained_model` to restore weights.
      out_dir: Directory the PNGs are saved in; created if it does not exist.
      images_glob: Forwarded to `hific.build_input` and `get_image_names`.
      tfds_arguments: Forwarded to `hific.build_input`.
      max_images: Stop after this many images. A falsy value (`None` or `0`)
        means no limit, i.e. run until the dataset is exhausted.
    """
    config = configs.get_config(config_name)
    hific = model.HiFiC(config, helpers.ModelMode.EVALUATION)

    # Note: Automatically uses the validation split for TFDS.
    dataset = hific.build_input(
        batch_size=1,
        crop_size=None,
        images_glob=images_glob,
        tfds_arguments=tfds_arguments)
    image_names = get_image_names(images_glob)

    batch = tf.data.make_one_shot_iterator(dataset).get_next()
    raw_input = batch['input_image']
    raw_output, bitstring = hific.build_model(**batch)

    # Drop the leading (batch) axis, then round and cast to uint8 so the
    # fetched arrays can be handed to `Image.fromarray` directly.
    input_uint8 = tf.cast(tf.round(raw_input[0, ...]), tf.uint8)
    output_uint8 = tf.cast(tf.round(raw_output[0, ...]), tf.uint8)
    fetches = [input_uint8, output_uint8, bitstring]

    os.makedirs(out_dir, exist_ok=True)

    accumulated_metrics = collections.defaultdict(list)

    with tf.Session() as sess:
        hific.restore_trained_model(sess, ckpt_dir)
        hific.prepare_for_arithmetic_coding(sess)

        for i in itertools.count():
            if max_images and i == max_images:
                break
            try:
                # Raises OutOfRangeError once the one-shot iterator is
                # exhausted; that is how the loop normally terminates.
                inp_np, otp_np, bitstring_np = sess.run(fetches)

                height, width, channels = inp_np.shape
                assert channels == 3
                bpp = get_arithmetic_coding_bpp(
                    bitstring, bitstring_np, num_pixels=height * width)
                psnr = get_psnr(inp_np, otp_np)
                metrics = {'psnr': psnr, 'bpp_real': bpp}

                formatted = []
                for metric, value in metrics.items():
                    formatted.append(f'{metric}: {value:.5f}')
                metrics_str = ' / '.join(formatted)
                print(f'Image {i: 4d}: {metrics_str}, saving in {out_dir}...')

                for metric, value in metrics.items():
                    accumulated_metrics[metric].append(value)

                # Save images. Fall back to a zero-padded index when no name
                # is known for this position.
                name = image_names.get(i, f'img_{i:010d}')
                inp_path = os.path.join(out_dir, f'{name}_inp.png')
                Image.fromarray(inp_np).save(inp_path)
                otp_path = os.path.join(out_dir, f'{name}_otp_{bpp:.3f}.png')
                Image.fromarray(otp_np).save(otp_path)
            except tf.errors.OutOfRangeError:
                print('No more inputs.')
                break

    # One line per metric with its mean over all processed images.
    summary_lines = [f'{metric}: {np.mean(values)}'
                     for metric, values in accumulated_metrics.items()]
    print('\n'.join(summary_lines))
    print('Done!')