def main()

in tensorflow_graphics/projects/radiance_fields/nerf/eval.py [0:0]


def _render_image(model, img_rays, height, transform_matrix):
  """Renders one full image with the fine network, chunk by chunk.

  Args:
    model: NeRF model; its `inference` method is called once per chunk and the
      first returned value (the fine RGB output) is kept.
    img_rays: Rays covering the whole image, as returned by
      `perspective.random_patches`; split into `height` equal chunks along
      axis 1.
    height: Image height in pixels, used as the number of chunks.
    transform_matrix: Camera transformation matrix used to convert each chunk
      into ray origins and directions.

  Returns:
    The fine RGB outputs of all chunks concatenated along axis 0.
  """
  rendered_chunks = []
  # The whole image is not rendered in one call so that it fits into memory.
  for ray_chunk in tf.split(img_rays, height, axis=1):
    ray_chunk = utils.change_coordinate_system(ray_chunk,
                                               (0., 0., 0.),
                                               (1., -1., -1.))
    rays_org, rays_dir = utils.camera_rays_from_transformation_matrix(
        ray_chunk,
        transform_matrix)
    rgb_fine, *_ = model.inference(rays_org, rays_dir)
    rendered_chunks.append(rgb_fine)
  return tf.concat(rendered_chunks, axis=0)


def _ground_truth_rgb(image, white_background):
  """Turns an RGBA ground-truth image into a 3-channel target.

  Args:
    image: Ground-truth tensor whose last axis holds 3 RGB channels followed
      by 1 alpha channel.
    white_background: If True, composite RGB over a white background using
      alpha; otherwise drop the alpha channel.

  Returns:
    A tensor with 3 channels in its last axis, comparable to the rendering.
  """
  rgb, alpha = tf.split(image, [3, 1], axis=-1)
  if white_background:
    return rgb * alpha + 1 - alpha
  # NOTE(review): without a white background the RGB channels are used as-is;
  # confirm this matches how the model renders with white_background=False.
  return rgb


def _save_png(image_np, filename):
  """Saves a float RGB image (expected range [0, 1]) to `filename` as PNG."""
  # Clip first: out-of-range values would otherwise wrap around in uint8.
  pixels = (np.clip(image_np, 0., 1.) * 255).astype(np.uint8)
  with tf.io.gfile.GFile(filename, 'wb') as f:
    Image.fromarray(pixels).save(f, format='PNG')


def main(_):
  """Evaluates NeRF checkpoints on one split of a synthetic NeRF dataset.

  For every new checkpoint that appears in `FLAGS.checkpoint_dir`, this
  function:
    * renders every image of the split;
    * writes the renderings as PNG files to `FLAGS.output_dir`;
    * logs per-image and mean SSIM/PSNR;
    * writes the means and the first few renderings as summaries to
      `FLAGS.output_dir`.
  It stops after one evaluation when `FLAGS.single_eval` is set and otherwise
  keeps waiting for further checkpoints.
  """
  dataset, height, width = data_loaders.load_synthetic_nerf_dataset(
      dataset_dir=FLAGS.dataset_dir,
      dataset_name=FLAGS.dataset_name,
      split=FLAGS.split,
      scale=FLAGS.dataset_scale,
      batch_size=1,
      shuffle=False)

  # The model's background must agree with how the ground truth is
  # composited below, otherwise the metrics compare mismatched images.
  model = model_lib.NeRF(
      ray_samples_coarse=FLAGS.ray_samples_coarse,
      ray_samples_fine=FLAGS.ray_samples_fine,
      near=FLAGS.near,
      far=FLAGS.far,
      n_freq_posenc_xyz=FLAGS.n_freq_posenc_xyz,
      scene_bbox=tuple([float(x) for x in FLAGS.scene_bbox.split(',')]),
      n_freq_posenc_dir=FLAGS.n_freq_posenc_dir,
      n_filters=FLAGS.n_filters,
      white_background=FLAGS.white_background)
  model.init_coarse_and_fine_models()
  model.init_optimizer(learning_rate=FLAGS.learning_rate)
  model.init_checkpoint(checkpoint_dir=FLAGS.checkpoint_dir)

  if not tf.io.gfile.exists(FLAGS.output_dir):
    tf.io.gfile.makedirs(FLAGS.output_dir)
  summary_writer = tf.summary.create_file_writer(FLAGS.output_dir)

  # checkpoints_iterator yields the latest checkpoint each time it changes and
  # waits between polls, instead of busy-spinning while none is available.
  for current_evaluation, checkpoint_path in enumerate(
      tf.train.checkpoints_iterator(FLAGS.checkpoint_dir)):
    model.load_checkpoint(checkpoint_path)

    total_psnr = []
    total_ssim = []
    for image_counter, (image, focal, principal_point,
                        transform_matrix) in enumerate(dataset):
      img_rays, _ = perspective.random_patches(
          focal,
          principal_point,
          height,
          width,
          patch_height=height,
          patch_width=width,
          scale=1.0)
      final_image = _render_image(model, img_rays, height, transform_matrix)
      final_image_np = final_image.numpy()

      # [0] drops the batch axis (the dataset is loaded with batch_size=1).
      image_np = _ground_truth_rgb(image, FLAGS.white_background).numpy()[0]
      # NOTE(review): newer scikit-image versions replace `multichannel` with
      # `channel_axis`; verify against the installed version.
      ssim = metrics.structural_similarity(image_np,
                                           final_image_np,
                                           multichannel=True,
                                           data_range=1)
      psnr = metrics.peak_signal_noise_ratio(image_np,
                                             final_image_np,
                                             data_range=1)
      total_psnr.append(psnr)
      total_ssim.append(ssim)

      _save_png(final_image_np,
                os.path.join(FLAGS.output_dir,
                             '{0:05d}.png'.format(image_counter)))

      logging.info('Image %d: ssim %.3f / psnr: %.3f',
                   image_counter, ssim, psnr)

      # Show the first four renderings; their summary tags are 1-based.
      if image_counter < 4:
        with summary_writer.as_default():
          tf.summary.image('rgb_fine/{0}'.format(image_counter + 1),
                           tf.expand_dims(final_image, 0),
                           step=current_evaluation,
                           max_outputs=4)

    if total_psnr:
      mean_ssim = np.mean(total_ssim)
      mean_psnr = np.mean(total_psnr)
      with summary_writer.as_default():
        tf.summary.scalar('eval_ssim', mean_ssim, step=current_evaluation)
        tf.summary.scalar('eval_psnr', mean_psnr, step=current_evaluation)
      logging.info('ssim %.3f', mean_ssim)
      logging.info('psnr %.3f', mean_psnr)
    else:
      # An empty split would otherwise produce NaN means.
      logging.warning('No images evaluated for checkpoint %s',
                      checkpoint_path)

    if FLAGS.single_eval:
      break