in tensorflow_graphics/projects/radiance_fields/nerf/eval.py [0:0]
def main(_):
  """Evaluates NeRF checkpoints on one split of a synthetic NeRF dataset.

  The dataset and the model are configured from command-line flags. The
  function then polls `FLAGS.checkpoint_dir`; every time a checkpoint that has
  not been evaluated yet shows up, it is loaded, every image of the dataset is
  rendered, and SSIM / PSNR against the ground truth are computed. Rendered
  images are written as PNG files to `FLAGS.output_dir`, and the mean metrics
  (plus the first few renderings) are written as TensorBoard summaries.

  The loop runs until the process is stopped, unless `FLAGS.single_eval` is
  set, in which case it returns after the first evaluated checkpoint.

  Args:
    _: Unused positional arguments supplied by the application runner.
  """
  # Function-scope import: only the polling loop below needs it.
  import time  # pylint: disable=g-import-not-at-top

  # Delay between two looks at the checkpoint directory. Without it the loop
  # spins at full speed while waiting for the trainer to write a checkpoint.
  checkpoint_poll_secs = 10.0

  dataset, height, width = data_loaders.load_synthetic_nerf_dataset(
      dataset_dir=FLAGS.dataset_dir,
      dataset_name=FLAGS.dataset_name,
      split=FLAGS.split,
      scale=FLAGS.dataset_scale,
      batch_size=1,
      shuffle=False)

  model = model_lib.NeRF(
      ray_samples_coarse=FLAGS.ray_samples_coarse,
      ray_samples_fine=FLAGS.ray_samples_fine,
      near=FLAGS.near,
      far=FLAGS.far,
      n_freq_posenc_xyz=FLAGS.n_freq_posenc_xyz,
      scene_bbox=tuple([float(x) for x in FLAGS.scene_bbox.split(',')]),
      n_freq_posenc_dir=FLAGS.n_freq_posenc_dir,
      n_filters=FLAGS.n_filters,
      # Follow the same flag that drives the ground-truth compositing below,
      # so rendering and reference agree on the background.
      white_background=FLAGS.white_background)
  model.init_coarse_and_fine_models()
  model.init_optimizer(learning_rate=FLAGS.learning_rate)
  model.init_checkpoint(checkpoint_dir=FLAGS.checkpoint_dir)

  if not tf.io.gfile.exists(FLAGS.output_dir):
    tf.io.gfile.makedirs(FLAGS.output_dir)
  summary_writer = tf.summary.create_file_writer(FLAGS.output_dir)
  # ----------------------------------------------------------------------------
  current_evaluation = 0
  current_checkpoint = ''
  while True:
    latest_checkpoint = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
    # Nothing to do while no checkpoint exists yet, or while the newest one has
    # already been evaluated: wait a bit and look again.
    if latest_checkpoint is None or latest_checkpoint == current_checkpoint:
      time.sleep(checkpoint_poll_secs)
      continue
    current_checkpoint = latest_checkpoint
    model.load_checkpoint(current_checkpoint)

    total_psnr = []
    total_ssim = []
    image_counter = 0
    for image, focal, principal_point, transform_matrix in dataset:
      # A "patch" as large as the whole image yields the rays of every pixel.
      img_rays, _ = perspective.random_patches(
          focal,
          principal_point,
          height,
          width,
          patch_height=height,
          patch_width=width,
          scale=1.0)
      # Batchify the image to fit into memory: the rays are cut into `height`
      # chunks along axis 1 (presumably one image row each) and rendered
      # chunk by chunk.
      batch_rays = tf.split(img_rays, height, axis=1)
      output = []
      for random_rays in batch_rays:
        # NOTE(review): presumably flips the y and z axes to match the camera
        # convention of the dataset -- confirm against
        # `utils.change_coordinate_system`.
        random_rays = utils.change_coordinate_system(random_rays,
                                                     (0., 0., 0.),
                                                     (1., -1., -1.))
        rays_org, rays_dir = utils.camera_rays_from_transformation_matrix(
            random_rays,
            transform_matrix)
        # Only the first output (the fine rendering) is kept.
        rgb_fine, *_ = model.inference(rays_org, rays_dir)
        output.append(rgb_fine)
      final_image = tf.concat(output, axis=0)
      final_image_np = final_image.numpy()

      # The ground truth carries four channels; split color from alpha.
      image_rgb_no_alpha, image_a = tf.split(image, [3, 1], axis=-1)
      if FLAGS.white_background:
        # Composite the color over a white background using the alpha channel.
        image = image_rgb_no_alpha * image_a + 1 - image_a
      else:
        # Drop alpha so the reference has the same number of channels as the
        # rendering; otherwise the metrics below receive mismatched shapes.
        image = image_rgb_no_alpha
      # The dataset is loaded with batch_size=1; remove that leading axis.
      image_np = image.numpy()[0]

      ssim = metrics.structural_similarity(image_np,
                                           final_image_np,
                                           multichannel=True,
                                           data_range=1)
      psnr = metrics.peak_signal_noise_ratio(image_np,
                                             final_image_np,
                                             data_range=1)
      total_psnr.append(psnr)
      total_ssim.append(ssim)

      filename = os.path.join(FLAGS.output_dir,
                              '{0:05d}.png'.format(image_counter))
      # Clip before quantizing: values slightly outside [0, 1] would otherwise
      # wrap around when cast to uint8.
      img_to_save = Image.fromarray(
          (np.clip(final_image_np, 0.0, 1.0) * 255).astype(np.uint8))
      with tf.io.gfile.GFile(filename, 'wb') as f:
        img_to_save.save(f)
      logging.info('Image %d: ssim %.3f / psnr: %.3f',
                   image_counter, ssim, psnr)
      image_counter += 1
      # Show some images (the counter was already incremented, so this logs the
      # first four renderings under the tags 1 to 4).
      if image_counter < 5:
        with summary_writer.as_default():
          tf.summary.image('rgb_fine/{0}'.format(image_counter),
                           tf.expand_dims(final_image, 0),
                           step=current_evaluation,
                           max_outputs=4)

    # Mean metrics over the whole dataset for this checkpoint.
    with summary_writer.as_default():
      tf.summary.scalar('eval_ssim', np.mean(total_ssim),
                        step=current_evaluation)
      tf.summary.scalar('eval_psnr', np.mean(total_psnr),
                        step=current_evaluation)
    logging.info('ssim %.3f', np.mean(total_ssim))
    logging.info('psnr %.3f', np.mean(total_psnr))
    current_evaluation += 1
    if FLAGS.single_eval:
      break