def free_viewpoint_rendering()

in free_viewpoint_rendering.py


def free_viewpoint_rendering(args):

    # memory vs. speed and quality
    frames_at_a_time = 10 # set to 1 to reduce memory requirements
    only_rgb = False # set to True to reduce memory requirements. Needs to be False for some scene editing to work.
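    # with only_rgb=False, render_convenient additionally returns per-sample details
    # (point positions, visibility weights, rigidity masks) that the correspondence
    # and rigidity visualizations further below rely on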

    # determine output name
    if args.camera_path == "spiral":
        output_name = args.deformations + "_" + args.camera_path
    elif args.camera_path == "fixed":
        output_name = (
            args.deformations + "_" + args.camera_path + "_" + str(args.fixed_view)
        )
    elif args.camera_path == "input_reconstruction":
        output_name = args.deformations + "_" + args.camera_path
    else:
        raise RuntimeError("invalid --camera_path argument")
        
    if args.forced_background_stabilization is not None:
        output_name += "_fbs_" + str(args.forced_background_stabilization)
    if args.motion_factor is not None:
        output_name += "_exaggeration_" + str(args.motion_factor)
    if args.foreground_removal is not None:
        output_name += "_removal_" + str(args.foreground_removal)
    if args.render_canonical:
        output_name += "_canonical"

    output_folder = os.path.join(args.input, "output", output_name)
    create_folder(output_folder)

    # load Nerf network
    (
        render_kwargs_train,
        render_kwargs_test,
        start,
        grad_vars,
        load_weights_into_network,
        checkpoint_dict,
        get_training_ray_bending_latents,
        load_llff_dataset,
        raw_render_path,
        render_convenient,
        convert_rgb_to_saveable,
        convert_disparity_to_saveable,
        convert_disparity_to_jet,
        convert_disparity_to_phong,
        store_ray_bending_mesh_visualization,
        to8b,
    ) = _setup_nonrigid_nerf_network(args.input)
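    # _setup_nonrigid_nerf_network restores the trained model from args.input and returns
    # closures (render_convenient, the convert_* helpers) that are already bound to the
    # loaded network, so the rendering below only needs poses, intrinsics, and latents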
    print("sucessfully loaded nerf network", flush=True)

    # load dataset
    ray_bending_latents = (
        get_training_ray_bending_latents()
    )  # shape: frames x latent_size
    (
        images,
        poses,
        all_rotations,
        all_translations,
        bds,
        render_poses,
        render_rotations,
        render_translations,
        i_train,
        i_val,
        i_test,
        near,
        far,
        dataset_extras,
    ) = load_llff_dataset(
        render_kwargs_train_=render_kwargs_train, render_kwargs_test_=render_kwargs_test
    )  # load dataset that this nerf was trained on
    print("sucessfully loaded dataset", flush=True)

    # determine subset
    if args.deformations == "train":
        indices = i_train
        poses = poses[i_train]
        ray_bending_latents = ray_bending_latents[i_train]
        images = images[i_train]
        print("rendering training set")
    elif args.deformations == "test":
        indices = i_test
        poses = poses[i_test]
        ray_bending_latents = ray_bending_latents[i_test]
        images = images[i_test]
        print("rendering test set")
    elif args.deformations == "all":
        print("rendering training and test set")
    else:
        raise RuntimeError("invalid --deformations argument")

    copy_over_groundtruth_images = False
    if copy_over_groundtruth_images:
        groundtruth_images_folder = os.path.join(output_folder, "groundtruth")
        create_folder(groundtruth_images_folder)
        for i, rgb in enumerate(images):
            rgb = convert_rgb_to_saveable(rgb)
            file_prefix = os.path.join(groundtruth_images_folder, str(i).zfill(6))
            imageio.imwrite(file_prefix + ".png", rgb)

    # determine camera poses and latent codes
    num_poses = poses.shape[0]
    intrinsics = dataset_extras["intrinsics"]
    if args.camera_path == "input_reconstruction":
        # poses already holds the reconstructed input camera poses
        intrinsics = [ intrinsics[dataset_extras["imageid_to_viewid"][i]] for i in range(num_poses) ]
    elif args.camera_path == "fixed":
        poses = torch.stack(
            [torch.Tensor(poses[args.fixed_view]) for _ in range(num_poses)], 0
        )  # N x 3 x 4
        intrinsics = [ intrinsics[dataset_extras["imageid_to_viewid"][args.fixed_view]] for _ in range(num_poses) ]
    elif args.camera_path == "spiral":
        # poses = np.stack(_spiral_poses(poses, bds, num_poses), axis=0)
        poses = []
        while len(poses) < num_poses:
            poses += [render_pose for render_pose in render_poses]
        poses = np.stack(poses, axis=0)[:num_poses]
        intrinsics = [ intrinsics[dataset_extras["imageid_to_viewid"][0]] for _ in range(num_poses) ]
    else:
        # poses has shape N x ... and ray_bending_latents has shape N x ...
        # Can design custom camera paths here.
        # poses is indexed with imageid
        # ray_bending_latents is indexed with timestepid
        # intrinsics is indexed with viewid
        # images is indexed with imageid
        raise RuntimeError("custom camera paths need to be implemented here")
        
        # example with time interpolation from a fixed camera view
        num_target_frames = 500
        latent_indices = np.linspace(0, ray_bending_latents.shape[0]-1, num=num_target_frames)
        start_indices = np.floor(latent_indices).astype(int)
        end_indices = np.ceil(latent_indices).astype(int)
        start_latents = ray_bending_latents[start_indices] # num_target_frames x latent_size
        end_latents = ray_bending_latents[end_indices] # num_target_frames x latent_size
        interpolation_factors = latent_indices - start_indices # shape: num_target_frames. should be in [0,1]
        interpolation_factors = torch.Tensor(interpolation_factors).reshape(-1,1) # num_target_frames x 1
        ray_bending_latents = end_latents * interpolation_factors + start_latents * (1.-interpolation_factors)
        
        fixed_camera = 0
        poses = torch.stack(
            [torch.Tensor(poses[fixed_camera]) for _ in range(num_target_frames)], 0
        )  # N x 3 x 4
        intrinsics = [ intrinsics[dataset_extras["imageid_to_viewid"][fixed_camera]] for _ in range(num_target_frames) ]
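        # to try this example, remove the RuntimeError above (or replace the block with
        # your own camera path) and make sure that poses, intrinsics, and
        # ray_bending_latents all end up with num_target_frames entries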
        
    latents = ray_bending_latents.detach().cuda()  # move all latent codes to the GPU once, before the rendering loop

    # rendering
    correspondence_rgbs = []
    rigidities = []
    rgbs = []
    disps = []

    num_output_frames = poses.shape[0]
    for start_index in range(0, num_output_frames, frames_at_a_time):

        end_index = min(start_index + frames_at_a_time, num_output_frames)
        print(
            "rendering "
            + str(start_index)
            + " to "
            + str(end_index)
            + " out of "
            + str(num_output_frames),
            flush=True,
        )

        subposes = poses[start_index:end_index]
        sublatents = [latents[i] for i in range(start_index, end_index)]

        # render
        returned = render_convenient(
            poses=subposes,
            ray_bending_latents=sublatents,
            intrinsics=intrinsics,
            with_ray_bending=not args.render_canonical,
            detailed_output=not only_rgb,
            rigidity_test_time_cutoff=args.forced_background_stabilization,
            motion_factor=args.motion_factor,
            foreground_removal=args.foreground_removal
        )
        if only_rgb:
            subrgbs, subdisps = returned
        else:
            subrgbs, subdisps, details_and_rest = returned
        print("finished rendering", flush=True)

        rgbs += [image for image in subrgbs]
        disps += [image for image in subdisps]
        if only_rgb:
            correspondence_rgbs += [ None for _ in subrgbs]
            rigidities += [ None for _ in subrgbs]
            continue
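            # the None placeholders keep these lists index-aligned with rgbs/disps,
            # so the zip() in the storing loop below stays in sync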

        # determine correspondences
        # details_and_rest: list, one entry per image. each image has first two dimensions height x width.
        min_point = np.array(
            checkpoint_dict["scripts_dict"]["min_nerf_volume_point"]
        ).reshape(1, 1, 3)
        max_point = np.array(
            checkpoint_dict["scripts_dict"]["max_nerf_volume_point"]
        ).reshape(1, 1, 3)
        for i, image_details in enumerate(details_and_rest):
            # visibility_weight is the weight of the influence that each sample has on the final rgb value. so they sum to at most 1.
            accumulated_visibility = torch.cumsum(
                torch.Tensor(image_details["fine_visibility_weights"]).cuda(), dim=-1
            )  # height x width x point samples
            median_indices = torch.min(torch.abs(accumulated_visibility - 0.5), dim=-1)[
                1
            ]  # height x width. visibility goes from 0 to 1. 0.5 is the median, so treat it as "most likely to be on the actually visible surface"
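            # hypothetical worked example: for visibility weights [0.1, 0.3, 0.4, 0.2] along a ray,
            # the cumulative sum is [0.1, 0.4, 0.8, 1.0] and |cumsum - 0.5| = [0.4, 0.1, 0.3, 0.5],
            # which is smallest at index 1, so that sample is treated as the visible surface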
            # visualize canonical correspondences as RGB
            height, width = median_indices.shape
            surface_pixels = (
                image_details["fine_input_pts"]
                .reshape(height * width, -1, 3)[
                    np.arange(height * width), median_indices.cpu().reshape(-1), :
                ]
                .reshape(height, width, 3)
            )  # height x width x 3. median_indices contains the index of one ray sample for each pixel. this ray sample is selected in this line of code.
            correspondence_rgb = (surface_pixels - min_point) / (max_point - min_point)
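            # min_point/max_point span the canonical volume, so this maps each surface
            # point's canonical xyz position linearly into the RGB cube [0,1]^3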
            number_of_small_rgb_voxels = 100  # break the canonical space into smaller voxels. each voxel covers the entire RGB space [0,1]^3. makes it easier to visualize small changes. leads to a 3D checkerboard pattern
            if number_of_small_rgb_voxels > 1:
                correspondence_rgb *= number_of_small_rgb_voxels
                correspondence_rgb = correspondence_rgb - correspondence_rgb.astype(int)
            correspondence_rgbs.append(correspondence_rgb)

            # visualize rigidity
            if "fine_rigidity_mask" in image_details:
                rigidity = (
                    image_details["fine_rigidity_mask"]
                    .reshape(height * width, -1)[
                        np.arange(height * width), median_indices.cpu().reshape(-1)
                    ]
                    .reshape(height, width)
                )  # height x width. values in [0,1]
                rigidities.append(rigidity)
            else:
                rigidities.append(None)

    rgbs = np.stack(rgbs, axis=0)
    disps = np.stack(disps, axis=0)
    if not only_rgb:
        correspondence_rgbs = np.stack(correspondence_rgbs, axis=0)  # with only_rgb, the list contains None placeholders and is left unstacked
    use_rigidity = rigidities[0] is not None

    # store results
    for i, (rgb, disp, correspondence_rgb, rigidity) in enumerate(
        zip(rgbs, disps, correspondence_rgbs, rigidities)
    ):
        print("storing image " + str(i) + " / " + str(rgbs.shape[0]), flush=True)
        rgb = convert_rgb_to_saveable(rgb)
        disp_saveable = convert_disparity_to_saveable(disp)
        disp_jet = convert_disparity_to_jet(disp)
        disp_phong = convert_disparity_to_phong(disp)
        if not only_rgb:
            correspondence_rgb = convert_rgb_to_saveable(correspondence_rgb)
        if use_rigidity:
            rigidity_saveable = convert_disparity_to_saveable(rigidity, normalize=False)
            rigidity_jet = convert_disparity_to_jet(rigidity, normalize=False)

        file_postfix = "_" + str(i).zfill(6) + ".png"
        imageio.imwrite(os.path.join(output_folder, "rgb" + file_postfix), rgb)
        if not only_rgb:
            imageio.imwrite(
                os.path.join(output_folder, "correspondences" + file_postfix),
                correspondence_rgb,
            )
        if use_rigidity:
            imageio.imwrite(
                os.path.join(output_folder, "rigidity" + file_postfix),
                rigidity_saveable,
            )
            imageio.imwrite(
                os.path.join(output_folder, "rigidity_jet" + file_postfix), rigidity_jet
            )
        imageio.imwrite(
            os.path.join(output_folder, "disp" + file_postfix), disp_saveable
        )
        imageio.imwrite(
            os.path.join(output_folder, "disp_jet" + file_postfix), disp_jet
        )
        imageio.imwrite(
            os.path.join(output_folder, "disp_phong" + file_postfix), disp_phong
        )

    # movies
    file_prefix = os.path.join(output_folder, "video_")
    try:
        print("storing RGB video...", flush=True)
        imageio.mimwrite(
            file_prefix + "rgb.mp4",
            convert_rgb_to_saveable(rgbs),
            fps=args.output_video_fps,
            quality=9,
        )
        if not only_rgb:
            print("storing correspondence RGB video...", flush=True)
            imageio.mimwrite(
                file_prefix + "correspondences.mp4",
                convert_rgb_to_saveable(correspondence_rgbs),
                fps=args.output_video_fps,
                quality=9,
            )
        print("storing disparity video...", flush=True)
        imageio.mimwrite(
            file_prefix + "disp.mp4",
            convert_disparity_to_saveable(disps),
            fps=args.output_video_fps,
            quality=9,
        )
        print("storing disparity jet video...", flush=True)
        imageio.mimwrite(
            file_prefix + "disp_jet.mp4",
            np.stack([convert_disparity_to_jet(disp) for disp in disps], axis=0),
            fps=args.output_video_fps,
            quality=9,
        )
        print("storing disparity phong video...", flush=True)
        imageio.mimwrite(
            file_prefix + "disp_phong.mp4",
            np.stack([convert_disparity_to_phong(disp) for disp in disps], axis=0),
            fps=args.output_video_fps,
            quality=9,
        )
        if use_rigidity:
            rigidities = np.stack(rigidities, axis=0)
            print("storing rigidity video...", flush=True)
            imageio.mimwrite(
                file_prefix + "rigidity.mp4",
                convert_disparity_to_saveable(rigidities, normalize=False),
                fps=args.output_video_fps,
                quality=9,
            )
            print("storing rigidity jet video...", flush=True)
            imageio.mimwrite(
                file_prefix + "rigidity_jet.mp4",
                np.stack(
                    [
                        convert_disparity_to_jet(rigidity, normalize=False)
                        for rigidity in rigidities
                    ],
                    axis=0,
                ),
                fps=args.output_video_fps,
                quality=9,
            )
    except Exception:
        print("imageio.mimwrite() failed. Maybe ffmpeg is not installed properly?", flush=True)
        
    # evaluation of background stability
    if args.camera_path == "fixed":
        # per-pixel standard deviation of the RGB values across time; for a fixed
        # view, a stable background should show (close to) zero deviation
        standard_deviations = np.std(rgbs, axis=0)
        averaged_standard_deviations = 10 * np.mean(standard_deviations, axis=-1)  # average over color channels, scaled by 10 to exaggerate small deviations

        from matplotlib import cm
        color_mapping = np.array([cm.jet(i)[:3] for i in range(256)])
        max_value = 1
        min_value = 0
        averaged_standard_deviations = np.clip(averaged_standard_deviations, a_max=max_value, a_min=min_value) / max_value  # cut off above max_value. result is normalized to [0,1]
        averaged_standard_deviations = (255. * averaged_standard_deviations).astype("uint8")  # now contains int in [0,255]
        original_shape = averaged_standard_deviations.shape
        averaged_standard_deviations = color_mapping[averaged_standard_deviations.flatten()]
        averaged_standard_deviations = averaged_standard_deviations.reshape(original_shape + (3,))

        imageio.imwrite(os.path.join(output_folder, "standard_deviations.png"), to8b(averaged_standard_deviations))

    # quantitative evaluation
    if args.camera_path == "input_reconstruction":
        try:
            from PerceptualSimilarity import lpips
            perceptual_metric = lpips.LPIPS(net="alex")
        except Exception:
            print("Perceptual LPIPS metric not found. Please see the README for installation instructions.")
            perceptual_metric = None

        create_error_maps = True  # whether to also write out per-image error visualizations in addition to the scores
    
        naive_error_folder = os.path.join(output_folder, "naive_errors/")
        create_folder(naive_error_folder)
        ssim_error_folder = os.path.join(output_folder, "ssim_errors/")
        create_folder(ssim_error_folder)
        
        to8b = lambda x : (255*np.clip(x,0,1)).astype(np.uint8)
        def visualize_with_jet_color_scheme(image):
            from matplotlib import cm
            color_mapping = np.array([ cm.jet(i)[:3] for i in range(256) ])
            max_value = 1.0
            min_value = 0.0
            intermediate = np.clip(image, a_max=max_value, a_min=min_value) / max_value # cut off above max_value. result is normalized to [0,1]
            intermediate = (255. * intermediate).astype('uint8') # now contains int in [0,255]
            original_shape = intermediate.shape
            intermediate = color_mapping[intermediate.flatten()]
            intermediate = intermediate.reshape(original_shape + (3,))
            return intermediate

        mask = None
        scores = {}
        from skimage.metrics import structural_similarity as ssim
        for i, (groundtruth, generated) in enumerate(zip(images, rgbs)):   

            if mask is None: # undistortion leads to masked-out black pixels in groundtruth
                mask = (np.sum(groundtruth, axis=-1) == 0.)
            groundtruth[mask] = 0.
            generated[mask] = 0.
            
            # PSNR
            mse = np.mean((groundtruth - generated) ** 2)
            psnr = -10. * np.log10(mse)
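            # with images in [0,1], the peak signal is 1, so PSNR = 10*log10(1/MSE) = -10*log10(MSE)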
            
            # SSIM
            # https://scikit-image.org/docs/dev/api/skimage.metrics.html#skimage.metrics.structural_similarity
            returned = ssim(groundtruth, generated, data_range=1.0, channel_axis=-1, gaussian_weights=True, sigma=1.5, use_sample_covariance=False, full=create_error_maps)  # channel_axis replaces the deprecated multichannel argument (scikit-image >= 0.19)
            if create_error_maps:
                ssim_error, ssim_error_image = returned
            else:
                ssim_error = returned
                
            # perceptual metric
            if perceptual_metric is None:
                lpips_score = 1.  # placeholder value when the LPIPS package is unavailable
            else:
                def numpy_to_pytorch(np_image):
                    torch_image = 2 * torch.from_numpy(np_image) - 1  # height x width x 3. must be in [-1,+1]
                    torch_image = torch_image.permute(2, 0, 1)  # 3 x height x width
                    return torch_image.unsqueeze(0)  # 1 x 3 x height x width
                lpips_score = perceptual_metric.forward(numpy_to_pytorch(groundtruth), numpy_to_pytorch(generated))
                lpips_score = float(lpips_score.detach().reshape(1).numpy()[0])

            scores[i] = {"psnr": psnr, "ssim": ssim_error, "lpips": lpips_score}
            
            if create_error_maps:
                # MSE-style
                error = np.linalg.norm(groundtruth - generated, axis=-1) / np.sqrt(3.)  # height x width, normalized by the maximum possible per-pixel RGB distance
                error *= 10.  # exaggerate the error for visibility
                error = np.clip(error, 0.0, 1.0)
                error = to8b(visualize_with_jet_color_scheme(error))  # height x width x 3. int values in [0,255]
                filename = os.path.join(naive_error_folder, 'error_{:03d}.png'.format(i))
                imageio.imwrite(filename, error)
                
                # SSIM
                filename = os.path.join(ssim_error_folder, 'error_{:03d}.png'.format(i))
                ssim_error_image = to8b(visualize_with_jet_color_scheme(1.-np.mean(ssim_error_image, axis=-1)))
                imageio.imwrite(filename, ssim_error_image)
        
        averaged_scores = {}
        averaged_scores["average_psnr"] = np.mean([ score["psnr"] for score in scores.values() ])
        averaged_scores["average_ssim"] = np.mean([ score["ssim"] for score in scores.values() ])
        averaged_scores["average_lpips"] = np.mean([ score["lpips"] for score in scores.values() ])
        
        print(averaged_scores, flush=True)
        
        scores.update(averaged_scores)
        
        import json
        with open(os.path.join(output_folder, "scores.json"), "w") as json_file:
            json.dump(scores, json_file, indent=4)
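
A minimal sketch of how this function might be invoked, assuming an argparse-based entry point. The flag names mirror the args.* attributes used above; the defaults, types, and help strings are assumptions for illustration:

if __name__ == "__main__":
    import argparse

    # hypothetical entry point; flag names are taken from the args.* accesses above
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", required=True, help="experiment folder of the trained network")
    parser.add_argument("--deformations", default="all", choices=["train", "test", "all"])
    parser.add_argument("--camera_path", default="input_reconstruction", choices=["spiral", "fixed", "input_reconstruction"])
    parser.add_argument("--fixed_view", type=int, default=0)
    parser.add_argument("--forced_background_stabilization", type=float, default=None)
    parser.add_argument("--motion_factor", type=float, default=None)
    parser.add_argument("--foreground_removal", type=float, default=None)
    parser.add_argument("--render_canonical", action="store_true")
    parser.add_argument("--output_video_fps", type=int, default=30)
    free_viewpoint_rendering(parser.parse_args())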