def eval_trained_model()

in models/hific/evaluate.py [0:0]


def eval_trained_model(config_name,
                       ckpt_dir,
                       out_dir,
                       images_glob,
                       tfds_arguments: helpers.TFDSArguments,
                       max_images=None):
  """Evaluate a trained model.

  Restores a HiFiC model in evaluation mode and runs it on the input one
  image at a time. For each image it:
    * prints the PSNR and the real bits-per-pixel of the arithmetic-coded
      bitstring;
    * saves the input and the reconstruction as PNG files in `out_dir`.
  At the end it prints the mean of every collected metric.

  Args:
    config_name: Name of the config, resolved through `configs.get_config`.
    ckpt_dir: Checkpoint location handed to `restore_trained_model`.
    out_dir: Directory that receives the PNG files. It is created if it
      does not exist.
    images_glob: Glob of input images. It is forwarded to `build_input` and
      is also used to look up output file names.
    tfds_arguments: TFDS settings forwarded to `build_input`.
    max_images: If truthy, evaluation stops after this many images.
      Otherwise it runs until the input iterator is exhausted.
  """
  config = configs.get_config(config_name)
  hific_model = model.HiFiC(config, helpers.ModelMode.EVALUATION)

  # Note: Automatically uses the validation split for TFDS.
  input_dataset = hific_model.build_input(
      batch_size=1,
      crop_size=None,
      images_glob=images_glob,
      tfds_arguments=tfds_arguments)
  # Presumably a mapping from image index to file stem; missing indices fall
  # back to a generated name below.
  image_names = get_image_names(images_glob)
  next_batch = tf.data.make_one_shot_iterator(input_dataset).get_next()
  batched_input = next_batch['input_image']
  batched_output, bitstring = hific_model.build_model(**next_batch)

  def _first_as_uint8(batched):
    # batch_size=1 above, so take element 0, then round to 8-bit pixel values.
    return tf.cast(tf.round(batched[0, ...]), tf.uint8)

  input_uint8 = _first_as_uint8(batched_input)
  output_uint8 = _first_as_uint8(batched_output)

  os.makedirs(out_dir, exist_ok=True)

  metric_history = collections.defaultdict(list)

  with tf.Session() as sess:
    hific_model.restore_trained_model(sess, ckpt_dir)
    hific_model.prepare_for_arithmetic_coding(sess)

    i = 0
    # A falsy max_images (None or 0) means "no limit".
    while not (max_images and i == max_images):
      try:
        inp_np, otp_np, bitstring_np = sess.run(
            [input_uint8, output_uint8, bitstring])

        height, width, channels = inp_np.shape
        assert channels == 3
        bpp = get_arithmetic_coding_bpp(
            bitstring, bitstring_np, num_pixels=height * width)

        metrics = {'psnr': get_psnr(inp_np, otp_np),
                   'bpp_real': bpp}
        metrics_str = ' / '.join(
            [f'{metric}: {value:.5f}' for metric, value in metrics.items()])
        print(f'Image {i: 4d}: {metrics_str}, saving in {out_dir}...')

        for metric, value in metrics.items():
          metric_history[metric].append(value)

        # Save images.
        name = image_names.get(i, f'img_{i:010d}')
        Image.fromarray(inp_np).save(
            os.path.join(out_dir, f'{name}_inp.png'))
        Image.fromarray(otp_np).save(
            os.path.join(out_dir, f'{name}_otp_{bpp:.3f}.png'))
      except tf.errors.OutOfRangeError:
        # The one-shot iterator has run out of images.
        print('No more inputs.')
        break
      i += 1

  summary_lines = [f'{metric}: {np.mean(values)}'
                   for metric, values in metric_history.items()]
  print('\n'.join(summary_lines))
  print('Done!')