# free_viewpoint_rendering.py
def free_viewpoint_rendering(args):
# memory vs. speed and quality
frames_at_a_time = 10 # set to 1 to reduce memory requirements
only_rgb = False # set to True to reduce memory requirements. Needs to be False for some scene editing to work.
# determine output name
if args.camera_path == "spiral":
output_name = args.deformations + "_" + args.camera_path
elif args.camera_path == "fixed":
output_name = (
args.deformations + "_" + args.camera_path + "_" + str(args.fixed_view)
)
elif args.camera_path == "input_reconstruction":
output_name = args.deformations + "_" + args.camera_path
else:
raise RuntimeError("invalid --camera_path argument")
if args.forced_background_stabilization is not None:
output_name += "_fbs_" + str(args.forced_background_stabilization)
if args.motion_factor is not None:
output_name += "_exaggeration_" + str(args.motion_factor)
if args.foreground_removal is not None:
output_name += "_removal_" + str(args.foreground_removal)
if args.render_canonical:
output_name += "_canonical"
output_folder = os.path.join(args.input, "output", output_name)
create_folder(output_folder)
# load Nerf network
(
render_kwargs_train,
render_kwargs_test,
start,
grad_vars,
load_weights_into_network,
checkpoint_dict,
get_training_ray_bending_latents,
load_llff_dataset,
raw_render_path,
render_convenient,
convert_rgb_to_saveable,
convert_disparity_to_saveable,
convert_disparity_to_jet,
convert_disparity_to_phong,
store_ray_bending_mesh_visualization,
to8b,
) = _setup_nonrigid_nerf_network(args.input)
print("sucessfully loaded nerf network", flush=True)
# load dataset
ray_bending_latents = (
get_training_ray_bending_latents()
) # shape: frames x latent_size
(
images,
poses,
all_rotations,
all_translations,
bds,
render_poses,
render_rotations,
render_translations,
i_train,
i_val,
i_test,
near,
far,
dataset_extras,
) = load_llff_dataset(
render_kwargs_train_=render_kwargs_train, render_kwargs_test_=render_kwargs_test
) # load dataset that this nerf was trained on
print("sucessfully loaded dataset", flush=True)
# determine subset
if args.deformations == "train":
indices = i_train
poses = poses[i_train]
ray_bending_latents = ray_bending_latents[i_train]
images = images[i_train]
print("rendering training set")
elif args.deformations == "test":
indices = i_test
poses = poses[i_test]
ray_bending_latents = ray_bending_latents[i_test]
images = images[i_test]
print("rendering test set")
elif args.deformations == "all":
print("rendering training and test set")
else:
raise RuntimeError("invalid --deformations argument")
copy_over_groundtruth_images = False
if copy_over_groundtruth_images:
groundtruth_images_folder = os.path.join(output_folder, "groundtruth")
create_folder(groundtruth_images_folder)
for i, rgb in enumerate(images):
rgb = convert_rgb_to_saveable(rgb)
file_prefix = os.path.join(groundtruth_images_folder, str(i).zfill(6))
imageio.imwrite(file_prefix + ".png", rgb)
# determine camera poses and latent codes
num_poses = poses.shape[0]
intrinsics = dataset_extras["intrinsics"]
if args.camera_path == "input_reconstruction":
        # keep the reconstructed input poses as-is
intrinsics = [ intrinsics[dataset_extras["imageid_to_viewid"][i]] for i in range(num_poses) ]
elif args.camera_path == "fixed":
poses = torch.stack(
[torch.Tensor(poses[args.fixed_view]) for _ in range(num_poses)], 0
) # N x 3 x 4
intrinsics = [ intrinsics[dataset_extras["imageid_to_viewid"][args.fixed_view]] for _ in range(num_poses) ]
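        # all output frames now share the pose and intrinsics of view args.fixed_view while
        # the per-frame latent codes still vary, i.e. the scene deforms under a static camera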
elif args.camera_path == "spiral":
poses = []
while len(poses) < num_poses:
poses += [render_pose for render_pose in render_poses]
poses = np.stack(poses, axis=0)[:num_poses]
intrinsics = [ intrinsics[dataset_extras["imageid_to_viewid"][0]] for _ in range(num_poses) ]
else:
# poses has shape N x ... and ray_bending_latents has shape N x ...
# Can design custom camera paths here.
# poses is indexed with imageid
# ray_bending_latents is indexed with timestepid
# intrinsics is indexed with viewid
# images is indexed with imageid
        raise RuntimeError("invalid --camera_path argument")
        # unreachable template code below: example with time interpolation from a fixed camera view
num_target_frames = 500
latent_indices = np.linspace(0, ray_bending_latents.shape[0]-1, num=num_target_frames)
        start_indices = np.floor(latent_indices).astype(int)
        end_indices = np.ceil(latent_indices).astype(int)
start_latents = ray_bending_latents[start_indices] # num_target_frames x latent_size
end_latents = ray_bending_latents[end_indices] # num_target_frames x latent_size
interpolation_factors = latent_indices - start_indices # shape: num_target_frames. should be in [0,1]
interpolation_factors = torch.Tensor(interpolation_factors).reshape(-1,1) # num_target_frames x 1
ray_bending_latents = end_latents * interpolation_factors + start_latents * (1.-interpolation_factors)
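        # worked example (hypothetical sizes): with 3 training latents and num_target_frames = 5,
        # latent_indices = [0.0, 0.5, 1.0, 1.5, 2.0], so e.g. the second output frame blends
        # latents 0 and 1 with weight 0.5 each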
fixed_camera = 0
poses = torch.stack(
[torch.Tensor(poses[fixed_camera]) for _ in range(num_target_frames)], 0
) # N x 3 x 4
intrinsics = [ intrinsics[dataset_extras["imageid_to_viewid"][fixed_camera]] for _ in range(num_target_frames) ]
    latents = ray_bending_latents.detach().cuda()
# rendering
correspondence_rgbs = []
rigidities = []
rgbs = []
disps = []
num_output_frames = poses.shape[0]
for start_index in range(0, num_output_frames, frames_at_a_time):
end_index = np.min([start_index + frames_at_a_time, num_output_frames])
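        # e.g. with frames_at_a_time = 10 and 25 output frames, the loop renders
        # the chunks [0, 10), [10, 20), and [20, 25)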
        print(
            f"rendering {start_index} to {end_index} out of {num_output_frames}",
            flush=True,
        )
subposes = poses[start_index:end_index]
sublatents = [latents[i] for i in range(start_index, end_index)]
# render
returned = render_convenient(
poses=subposes,
ray_bending_latents=sublatents,
intrinsics=intrinsics,
with_ray_bending=not args.render_canonical,
detailed_output=not only_rgb,
rigidity_test_time_cutoff=args.forced_background_stabilization,
motion_factor=args.motion_factor,
foreground_removal=args.foreground_removal
)
if only_rgb:
subrgbs, subdisps = returned
else:
subrgbs, subdisps, details_and_rest = returned
print("finished rendering", flush=True)
rgbs += [image for image in subrgbs]
disps += [image for image in subdisps]
if only_rgb:
correspondence_rgbs += [ None for _ in subrgbs]
rigidities += [ None for _ in subrgbs]
continue
# determine correspondences
# details_and_rest: list, one entry per image. each image has first two dimensions height x width.
min_point = np.array(
checkpoint_dict["scripts_dict"]["min_nerf_volume_point"]
).reshape(1, 1, 3)
max_point = np.array(
checkpoint_dict["scripts_dict"]["max_nerf_volume_point"]
).reshape(1, 1, 3)
for i, image_details in enumerate(details_and_rest):
# visibility_weight is the weight of the influence that each sample has on the final rgb value. so they sum to at most 1.
accumulated_visibility = torch.cumsum(
torch.Tensor(image_details["fine_visibility_weights"]).cuda(), dim=-1
) # height x width x point samples
            median_indices = torch.argmin(
                torch.abs(accumulated_visibility - 0.5), dim=-1
            )  # height x width. visibility goes from 0 to 1. 0.5 is the median, so treat it as "most likely to be on the actually visible surface"
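            # worked example: for visibility weights [0.1, 0.2, 0.3, 0.3, 0.1] along a ray, the
            # cumulative sums are [0.1, 0.3, 0.6, 0.9, 1.0]; the value closest to 0.5 is 0.6 at
            # index 2, so the third sample is treated as the visible surface point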
# visualize canonical correspondences as RGB
height, width = median_indices.shape
surface_pixels = (
image_details["fine_input_pts"]
.reshape(height * width, -1, 3)[
np.arange(height * width), median_indices.cpu().reshape(-1), :
]
.reshape(height, width, 3)
) # height x width x 3. median_indices contains the index of one ray sample for each pixel. this ray sample is selected in this line of code.
correspondence_rgb = (surface_pixels - min_point) / (max_point - min_point)
number_of_small_rgb_voxels = 100 # break the canonical space into smaller voxels. each voxel covers the entire RGB space [0,1]^3. makes it easier to visualize small changes. leads to a 3D checkerboard pattern
if number_of_small_rgb_voxels > 1:
correspondence_rgb *= number_of_small_rgb_voxels
correspondence_rgb = correspondence_rgb - correspondence_rgb.astype(int)
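                # e.g. with 100 voxels, canonical coordinates 0.123 and 0.127 map to fractional
                # parts 0.3 and 0.7, so nearby canonical points receive clearly distinct colors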
correspondence_rgbs.append(correspondence_rgb)
# visualize rigidity
if "fine_rigidity_mask" in image_details:
rigidity = (
image_details["fine_rigidity_mask"]
.reshape(height * width, -1)[
np.arange(height * width), median_indices.cpu().reshape(-1)
]
.reshape(height, width)
) # height x width. values in [0,1]
rigidities.append(rigidity)
else:
rigidities.append(None)
rgbs = np.stack(rgbs, axis=0)
disps = np.stack(disps, axis=0)
correspondence_rgbs = np.stack(correspondence_rgbs, axis=0)
use_rigidity = rigidities[0] is not None
# store results
for i, (rgb, disp, correspondence_rgb, rigidity) in enumerate(
zip(rgbs, disps, correspondence_rgbs, rigidities)
):
print("storing image " + str(i) + " / " + str(rgbs.shape[0]), flush=True)
rgb = convert_rgb_to_saveable(rgb)
disp_saveable = convert_disparity_to_saveable(disp)
disp_jet = convert_disparity_to_jet(disp)
disp_phong = convert_disparity_to_phong(disp)
if not only_rgb:
correspondence_rgb = convert_rgb_to_saveable(correspondence_rgb)
if use_rigidity:
rigidity_saveable = convert_disparity_to_saveable(rigidity, normalize=False)
rigidity_jet = convert_disparity_to_jet(rigidity, normalize=False)
file_postfix = "_" + str(i).zfill(6) + ".png"
imageio.imwrite(os.path.join(output_folder, "rgb" + file_postfix), rgb)
if not only_rgb:
imageio.imwrite(
os.path.join(output_folder, "correspondences" + file_postfix),
correspondence_rgb,
)
if use_rigidity:
imageio.imwrite(
os.path.join(output_folder, "rigidity" + file_postfix),
rigidity_saveable,
)
imageio.imwrite(
os.path.join(output_folder, "rigidity_jet" + file_postfix), rigidity_jet
)
imageio.imwrite(
os.path.join(output_folder, "disp" + file_postfix), disp_saveable
)
imageio.imwrite(
os.path.join(output_folder, "disp_jet" + file_postfix), disp_jet
)
imageio.imwrite(
os.path.join(output_folder, "disp_phong" + file_postfix), disp_phong
)
# movies
file_prefix = os.path.join(output_folder, "video_")
try:
print("storing RGB video...", flush=True)
imageio.mimwrite(
file_prefix + "rgb.mp4",
convert_rgb_to_saveable(rgbs),
fps=args.output_video_fps,
quality=9,
)
if not only_rgb:
print("storing correspondence RGB video...", flush=True)
imageio.mimwrite(
file_prefix + "correspondences.mp4",
convert_rgb_to_saveable(correspondence_rgbs),
fps=args.output_video_fps,
quality=9,
)
print("storing disparity video...", flush=True)
imageio.mimwrite(
file_prefix + "disp.mp4",
convert_disparity_to_saveable(disps),
fps=args.output_video_fps,
quality=9,
)
print("storing disparity jet video...", flush=True)
imageio.mimwrite(
file_prefix + "disp_jet.mp4",
np.stack([convert_disparity_to_jet(disp) for disp in disps], axis=0),
fps=args.output_video_fps,
quality=9,
)
print("storing disparity phong video...", flush=True)
imageio.mimwrite(
file_prefix + "disp_phong.mp4",
np.stack([convert_disparity_to_phong(disp) for disp in disps], axis=0),
fps=args.output_video_fps,
quality=9,
)
if use_rigidity:
rigidities = np.stack(rigidities, axis=0)
print("storing rigidity video...", flush=True)
imageio.mimwrite(
file_prefix + "rigidity.mp4",
convert_disparity_to_saveable(rigidities, normalize=False),
fps=args.output_video_fps,
quality=9,
)
print("storing rigidity jet video...", flush=True)
imageio.mimwrite(
file_prefix + "rigidity_jet.mp4",
np.stack(
[
convert_disparity_to_jet(rigidity, normalize=False)
for rigidity in rigidities
],
axis=0,
),
fps=args.output_video_fps,
quality=9,
)
    except Exception as exception:
        print("imageio.mimwrite() failed. Maybe ffmpeg is not installed properly? " + str(exception))
# evaluation of background stability
if args.camera_path == "fixed":
        standard_deviations = np.std(rgbs, axis=0)
        averaged_standard_deviations = 10 * np.mean(standard_deviations, axis=-1)  # exaggerate by 10x for visibility
        from matplotlib import cm
        color_mapping = np.array([cm.jet(i)[:3] for i in range(256)])
        max_value = 1
        min_value = 0
        averaged_standard_deviations = np.clip(averaged_standard_deviations, a_max=max_value, a_min=min_value) / max_value  # cut off above max_value. result is normalized to [0,1]
        averaged_standard_deviations = (255.0 * averaged_standard_deviations).astype("uint8")  # now contains int in [0,255]
        original_shape = averaged_standard_deviations.shape
        averaged_standard_deviations = color_mapping[averaged_standard_deviations.flatten()]
        averaged_standard_deviations = averaged_standard_deviations.reshape(original_shape + (3,))
        imageio.imwrite(os.path.join(output_folder, "standard_deviations.png"), to8b(averaged_standard_deviations))
# quantitative evaluation
if args.camera_path == "input_reconstruction":
try:
            from PerceptualSimilarity import lpips
            perceptual_metric = lpips.LPIPS(net="alex")
        except Exception:
            print("Perceptual LPIPS metric not found. Please see the README for installation instructions")
            perceptual_metric = None
create_error_maps = True # whether to write out error images instead of just computing scores
naive_error_folder = os.path.join(output_folder, "naive_errors/")
create_folder(naive_error_folder)
ssim_error_folder = os.path.join(output_folder, "ssim_errors/")
create_folder(ssim_error_folder)
        def to8b(x):  # note: shadows the to8b helper returned by _setup_nonrigid_nerf_network above
            return (255 * np.clip(x, 0, 1)).astype(np.uint8)
def visualize_with_jet_color_scheme(image):
from matplotlib import cm
color_mapping = np.array([ cm.jet(i)[:3] for i in range(256) ])
max_value = 1.0
min_value = 0.0
intermediate = np.clip(image, a_max=max_value, a_min=min_value) / max_value # cut off above max_value. result is normalized to [0,1]
intermediate = (255. * intermediate).astype('uint8') # now contains int in [0,255]
original_shape = intermediate.shape
intermediate = color_mapping[intermediate.flatten()]
intermediate = intermediate.reshape(original_shape + (3,))
return intermediate
mask = None
scores = {}
from skimage.metrics import structural_similarity as ssim
for i, (groundtruth, generated) in enumerate(zip(images, rgbs)):
if mask is None: # undistortion leads to masked-out black pixels in groundtruth
mask = (np.sum(groundtruth, axis=-1) == 0.)
groundtruth[mask] = 0.
generated[mask] = 0.
# PSNR
mse = np.mean((groundtruth - generated) ** 2)
psnr = -10. * np.log10(mse)
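            # PSNR = -10 * log10(MSE) because the images live in [0,1] (peak value 1);
            # e.g. an MSE of 1e-3 corresponds to 30 dB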
# SSIM
# https://scikit-image.org/docs/dev/api/skimage.metrics.html#skimage.metrics.structural_similarity
            returned = ssim(groundtruth, generated, data_range=1.0, channel_axis=-1, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, full=create_error_maps)  # channel_axis=-1 replaces the deprecated multichannel=True (scikit-image >= 0.19)
if create_error_maps:
ssim_error, ssim_error_image = returned
else:
ssim_error = returned
# perceptual metric
if perceptual_metric is None:
                lpips = 1.0  # sentinel value stored when the LPIPS metric is unavailable
else:
def numpy_to_pytorch(np_image):
torch_image = 2 * torch.from_numpy(np_image) - 1 # height x width x 3. must be in [-1,+1]
torch_image = torch_image.permute(2, 0, 1) # 3 x height x width
return torch_image.unsqueeze(0) # 1 x 3 x height x width
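                # numpy_to_pytorch maps [0,1] pixels to the [-1,+1], 1 x 3 x height x width
                # format that the LPIPS network expects; forward() returns a one-element
                # tensor that is unpacked into a float below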
lpips = perceptual_metric.forward(numpy_to_pytorch(groundtruth), numpy_to_pytorch(generated))
lpips = float(lpips.detach().reshape(1).numpy()[0])
scores[i] = {"psnr": psnr, "ssim": ssim_error, "lpips": lpips}
if create_error_maps:
# MSE-style
                error = np.linalg.norm(groundtruth - generated, axis=-1) / np.sqrt(1 + 1 + 1)  # height x width. dividing by sqrt(3) normalizes the per-pixel RGB error to [0,1]
                error *= 10.0  # exaggerate error
error = np.clip(error, 0.0, 1.0)
error = to8b(visualize_with_jet_color_scheme(error)) # height x width x 3. int values in [0,255]
filename = os.path.join(naive_error_folder, 'error_{:03d}.png'.format(i))
imageio.imwrite(filename, error)
# SSIM
filename = os.path.join(ssim_error_folder, 'error_{:03d}.png'.format(i))
ssim_error_image = to8b(visualize_with_jet_color_scheme(1.-np.mean(ssim_error_image, axis=-1)))
imageio.imwrite(filename, ssim_error_image)
averaged_scores = {}
averaged_scores["average_psnr"] = np.mean([ score["psnr"] for score in scores.values() ])
averaged_scores["average_ssim"] = np.mean([ score["ssim"] for score in scores.values() ])
averaged_scores["average_lpips"] = np.mean([ score["lpips"] for score in scores.values() ])
print(averaged_scores, flush=True)
scores.update(averaged_scores)
import json
with open(os.path.join(output_folder, "scores.json"), "w") as json_file:
json.dump(scores, json_file, indent=4)
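

# A minimal invocation sketch (hypothetical flag spellings; they mirror the args
# attributes used above, but the actual argparse setup lives elsewhere in this
# repository):
#
#   python free_viewpoint_rendering.py --input ./logs/experiment \
#       --deformations train --camera_path fixed --fixed_view 0 \
#       --output_video_fps 30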