# _setup_nonrigid_nerf_network()
#
# in free_viewpoint_rendering.py [0:0]


def _setup_nonrigid_nerf_network(results_folder, checkpoint="latest"):
    """Load a trained non-rigid NeRF and return it with convenience helpers.

    ``<results_folder>/backup/`` is appended to ``sys.path`` while the network
    is being set up so that ``train``, ``run_nerf_helpers`` and ``load_llff``
    can be imported. The entry is removed again before returning, also when
    the setup fails.

    Args:
        results_folder: Experiment folder. ``logs/args.txt`` (parsed with the
            config parser from ``train``) and ``logs/<checkpoint>.tar`` are
            read from it.
        checkpoint: Name of the checkpoint file in ``logs`` without the
            ``.tar`` extension.

    Returns:
        A 16-tuple, in this order:
        ``render_kwargs_train``, ``render_kwargs_test``, ``start``,
        ``grad_vars`` (all as returned by ``create_nerf``; the checkpoint
        weights are loaded through ``render_kwargs_train``),
        ``load_weights_into_network``, ``checkpoint_dict`` (the loaded
        checkpoint), ``get_training_ray_bending_latents``,
        ``load_llff_dataset``, ``raw_render_path`` (``train.render_path``),
        ``render_convenient``, ``convert_rgb_to_saveable``,
        ``convert_disparity_to_saveable``, ``convert_disparity_to_jet``,
        ``convert_disparity_to_phong``,
        ``store_ray_bending_mesh_visualization`` and ``to8b``.
    """
    import sys

    extra_sys_folder = os.path.join(results_folder, "backup/")
    sys.path.append(extra_sys_folder)
    try:
        return _build_nonrigid_nerf_helpers(results_folder, checkpoint)
    finally:
        # Do not leave the extra import path behind, even if the setup raised.
        sys.path.remove(extra_sys_folder)


def _build_nonrigid_nerf_helpers(results_folder, checkpoint):
    """Create the network and the helper closures returned by the setup.

    Private helper of ``_setup_nonrigid_nerf_network``. It expects ``train``,
    ``run_nerf_helpers`` and ``load_llff`` to be importable when called and
    returns the 16-tuple documented on the entry point.
    """
    from train import (
        config_parser,
        create_nerf,
        render_path,
        get_parallelized_render_function,
        _get_multi_view_helper_mappings,
    )
    from run_nerf_helpers import (
        determine_nerf_volume_extent,
        to8b,
        visualize_disparity_with_jet_color_scheme,
        visualize_disparity_with_blinn_phong,
        visualize_ray_bending,
    )
    from load_llff import load_llff_data

    args = config_parser().parse_args(
        ["--config", os.path.join(results_folder, "logs", "args.txt")]
    )
    print(args, flush=True)

    # The optimizer slot of the result is not used anywhere in this setup.
    render_kwargs_train, render_kwargs_test, start, grad_vars, _optimizer = create_nerf(
        args, autodecoder_variables=None, ignore_optimizer=True
    )

    def load_weights_into_network(render_kwargs_train, checkpoint=None, path=None):
        """Load checkpoint weights into the networks of ``render_kwargs_train``.

        Exactly one of ``checkpoint`` (a name inside ``<results_folder>/logs``,
        without ``.tar``) and ``path`` (a full file path) must be given.

        Returns:
            The dictionary loaded from the checkpoint file.

        Raises:
            RuntimeError: If both or neither of ``checkpoint`` and ``path``
                are given.
        """
        if path is not None and checkpoint is not None:
            raise RuntimeError("trying to load weights from two sources")
        if checkpoint is not None:
            path = os.path.join(results_folder, "logs", checkpoint + ".tar")
        if path is None:
            raise RuntimeError(
                "need either a checkpoint name or a path to load weights from"
            )

        # NOTE: torch.load unpickles the file; only load checkpoints you trust.
        loaded = torch.load(path)

        render_kwargs_train["network_fn"].load_state_dict(
            loaded["network_fn_state_dict"]
        )
        # The fine network and the ray bender are optional.
        for key in ("network_fine", "ray_bender"):
            if render_kwargs_train[key] is not None:
                render_kwargs_train[key].load_state_dict(
                    loaded[key + "_state_dict"]
                )
        return loaded

    checkpoint_dict = load_weights_into_network(
        render_kwargs_train, checkpoint=checkpoint
    )

    def get_training_ray_bending_latents(checkpoint="latest"):
        """Return the ``ray_bending_latent_codes`` stored in a checkpoint."""
        checkpoint_file = os.path.join(results_folder, "logs", checkpoint + ".tar")
        return torch.load(checkpoint_file)[
            "ray_bending_latent_codes"
        ]  # shape: frames x latent_size

    def load_llff_dataset(
        render_kwargs_train_=None,
        render_kwargs_test_=None,
        return_nerf_volume_extent=False,
    ):
        """Load the LLFF dataset that the training configuration points to.

        Args:
            render_kwargs_train_: Optional dict; its ``near``/``far`` entries
                are updated in place.
            render_kwargs_test_: Optional dict; its ``near``/``far`` entries
                are updated in place.
            return_nerf_volume_extent: If true, the result of
                ``determine_nerf_volume_extent`` is stored in ``extras`` under
                ``min_nerf_volume_point`` and ``max_nerf_volume_point``.

        Returns:
            The tuple ``(images, poses, all_rotations, all_translations, bds,
            render_poses, render_rotations, render_translations, i_train,
            i_val, i_test, near, far, extras)``. ``i_val`` is the same array
            as ``i_test``.
        """
        # The test indices proposed by the loader are deliberately ignored;
        # the split is derived from the block sizes below.
        images, poses, bds, render_poses, _ = load_llff_data(
            args.datadir,
            factor=args.factor,
            recenter=True,
            bd_factor=args.bd_factor,
            spherify=args.spherify,
        )
        extras = _get_multi_view_helper_mappings(images.shape[0], args.datadir)

        # poses
        poses = poses[:, :3, :4]  # N x 3 x 4
        all_rotations = poses[:, :3, :3]  # N x 3 x 3
        all_translations = poses[:, :3, 3]  # N x 3

        render_poses = render_poses[:, :3, :4]
        render_rotations = render_poses[:, :3, :3]
        render_translations = render_poses[:, :3, 3]

        # splits: timesteps alternate between a training block and a test block
        i_test = np.array([], dtype=int)
        if args.test_block_size > 0 and args.train_block_size > 0:
            print(
                "splitting timesteps into training ({}) and test ({}) blocks".format(
                    args.train_block_size, args.test_block_size
                )
            )
            num_timesteps = len(extras["raw_timesteps"])
            block_size = args.train_block_size + args.test_block_size
            test_blocks = [
                np.arange(
                    min(num_timesteps, block_start + args.train_block_size),
                    min(num_timesteps, block_start + block_size),
                )
                for block_start in np.arange(0, num_timesteps, block_size)
            ]
            test_timesteps = (
                np.concatenate(test_blocks)
                if test_blocks
                else np.array([], dtype=int)
            )
            timestep_ids = np.asarray(list(extras["imageid_to_timestepid"]))
            i_test = np.flatnonzero(np.isin(timestep_ids, test_timesteps))

        i_val = i_test
        i_train = np.setdiff1d(np.arange(int(images.shape[0])), i_test)

        # near, far: always taken from the depth bounds of the dataset
        near = np.min(bds) * 0.9
        far = np.max(bds) * 1.0
        bds_dict = {
            "near": near,
            "far": far,
        }
        if render_kwargs_train_ is not None:
            render_kwargs_train_.update(bds_dict)
        if render_kwargs_test_ is not None:
            render_kwargs_test_.update(bds_dict)

        intrinsics = checkpoint_dict["intrinsics"]
        if return_nerf_volume_extent:
            # Build the render function from the loaded networks, with the
            # same keyword arguments that render_convenient uses.
            parallel_render = get_parallelized_render_function(
                coarse_model=render_kwargs_test["network_fn"],
                fine_model=render_kwargs_test["network_fine"],
                ray_bender=render_kwargs_test["ray_bender"],
            )
            per_image_intrinsics = [
                intrinsics[extras["imageid_to_viewid"][imageid]]
                for imageid in range(poses.shape[0])
            ]
            min_point, max_point = determine_nerf_volume_extent(
                parallel_render,
                poses,
                per_image_intrinsics,
                render_kwargs_test,
                args,
            )
            extras["min_nerf_volume_point"] = min_point.detach()
            extras["max_nerf_volume_point"] = max_point.detach()

        extras["intrinsics"] = intrinsics

        return (
            images,
            poses,
            all_rotations,
            all_translations,
            bds,
            render_poses,
            render_rotations,
            render_translations,
            i_train,
            i_val,
            i_test,
            near,
            far,
            extras,
        )

    raw_render_path = render_path

    def render_convenient(
        rotations=None,
        translations=None,
        poses=None,
        detailed_output=None,
        ray_bending_latents=None,
        render_factor=None,
        with_ray_bending=None,
        custom_checkpoint_dict=None,
        intrinsics=None,
        chunk=None,
        custom_ray_params=None,
        custom_render_kwargs_test=None,
        rigidity_test_time_cutoff=None,
        motion_factor=None,
        foreground_removal=None
    ):
        """Render the given camera poses with the loaded network.

        Args:
            rotations: N x 3 x 3; must be given together with
                ``translations`` and is mutually exclusive with ``poses``.
            translations: N x 3 (or N x 3 x 1 or N x 1 x 3 or 3).
            poses: N x 3 x 4. A fifth column (hwf, as loaded by load_llff) is
                dropped.
            detailed_output: Also return the third output of ``render_path``.
                Defaults to False.
            ray_bending_latents: List of latent codes or an array of shape
                N x latent_size.
            render_factor: Forwarded to ``render_path``. Defaults to 0.
            with_ray_bending: If False, the ray bender is detached from the
                networks for this call. Defaults to True.
            custom_checkpoint_dict: Checkpoint to take the default intrinsics
                from instead of the one loaded during setup.
            intrinsics: List of dicts with entries height, width, center_x,
                center_y, focal_x, focal_y. Defaults to the first intrinsics
                entry of the checkpoint, repeated for every pose.
            chunk: Chunk size forwarded to ``render_path``. Defaults to
                ``args.chunk``.
            custom_ray_params: Unused; kept so that existing callers keep
                working.
            custom_render_kwargs_test: Render kwargs to use instead of the
                ones created during setup.
            rigidity_test_time_cutoff: Temporary value for the ray bender's
                ``rigidity_test_time_cutoff`` (forced background
                stabilization).
            motion_factor: Temporary value for the ray bender's
                ``test_time_scaling`` (motion exaggeration/dampening).
            foreground_removal: Temporary value for the networks'
                ``test_time_nonrigid_object_removal_threshold``.

        Returns:
            ``(rgbs, disps)`` or, with ``detailed_output``,
            ``(rgbs, disps, details_and_rest)``. rgbs is
            N x height x width x 3 with RGB values in [0,1], disps is
            N x height x width.

        Raises:
            RuntimeError: If the pose arguments are inconsistent, or if ray
                bending is requested but there is no ray bender.
        """
        # poses
        if poses is None:
            if rotations is None or translations is None:
                raise RuntimeError(
                    "need either poses or both rotations and translations"
                )
            rotations = torch.Tensor(rotations).reshape(-1, 3, 3)
            translations = torch.Tensor(translations).reshape(-1, 3, 1)
            poses = torch.cat([rotations, translations], -1)  # N x 3 x 4
        elif rotations is not None or translations is not None:
            raise RuntimeError(
                "poses cannot be combined with rotations or translations"
            )
        if len(poses.shape) > 3:
            raise RuntimeError("poses must have at most three dimensions")
        if poses.shape[-1] == 5:
            # The standard poses loaded by load_llff carry hwf in the last
            # column; it is ignored later on, so drop it here for simplicity.
            poses = poses[..., :4]
        poses = torch.Tensor(poses).cuda().reshape(-1, 3, 4)

        # other parameters/arguments
        checkpoint_dict_ = (
            checkpoint_dict
            if custom_checkpoint_dict is None
            else custom_checkpoint_dict
        )
        render_kwargs_test_ = (
            render_kwargs_test
            if custom_render_kwargs_test is None
            else custom_render_kwargs_test
        )
        if intrinsics is None:
            intrinsics = [checkpoint_dict_["intrinsics"][0]] * len(poses)
        if chunk is None:
            chunk = args.chunk
        if render_factor is None:
            render_factor = 0
        if detailed_output is None:
            detailed_output = False
        if with_ray_bending is None:
            with_ray_bending = True

        coarse_model = render_kwargs_test_["network_fn"]
        fine_model = render_kwargs_test_.get("network_fine")  # may be None
        networks = [net for net in (coarse_model, fine_model) if net is not None]

        # Temporary attribute overrides as (object, attribute, value) triples.
        if with_ray_bending:
            ray_bender = render_kwargs_test_["ray_bender"]
            if ray_bender is None:
                raise RuntimeError(
                    "ray bending requested, but the network has no ray bender"
                )
            overrides = [
                # forced background stabilization
                (ray_bender, "rigidity_test_time_cutoff", rigidity_test_time_cutoff),
                # motion exaggeration/dampening
                (ray_bender, "test_time_scaling", motion_factor),
            ]
            # foreground removal
            overrides += [
                (net, "test_time_nonrigid_object_removal_threshold", foreground_removal)
                for net in networks
            ]
        else:
            # The networks hold their ray bender wrapped in a 1-tuple.
            overrides = [(net, "ray_bender", (None,)) for net in networks]

        backups = [
            (target, name, getattr(target, name)) for target, name, _ in overrides
        ]
        backup_ray_bender = render_kwargs_test_["ray_bender"]

        try:
            for target, name, value in overrides:
                setattr(target, name, value)
            if not with_ray_bending:
                render_kwargs_test_["ray_bender"] = None

            parallel_render = get_parallelized_render_function(
                coarse_model=coarse_model,
                fine_model=fine_model,
                ray_bender=render_kwargs_test_["ray_bender"],
            )
            with torch.no_grad():
                returned_outputs = render_path(
                    poses,
                    intrinsics,
                    chunk,
                    render_kwargs_test_,
                    render_factor=render_factor,
                    detailed_output=detailed_output,
                    ray_bending_latents=ray_bending_latents,
                    parallelized_render_function=parallel_render,
                )
        finally:
            # Undo the temporary overrides, also when rendering failed.
            for target, name, value in backups:
                setattr(target, name, value)
            render_kwargs_test_["ray_bender"] = backup_ray_bender

        if detailed_output:
            rgbs, disps, details_and_rest = returned_outputs
            return rgbs, disps, details_and_rest
        rgbs, disps = returned_outputs
        return rgbs, disps

    def _prepare_disparity(disparity, normalize):
        """Return a copy of ``disparity``, divided by its maximum if requested."""
        if normalize:
            max_disparity = np.max(disparity)
            # An all-zero map would otherwise be divided by zero.
            if max_disparity != 0:
                return disparity / max_disparity
        return disparity.copy()

    def convert_rgb_to_saveable(rgb):
        """Convert float RGB values in [0,1] to int values in [0,255]."""
        return to8b(rgb)

    def convert_disparity_to_saveable(disparity, normalize=True):
        """Convert one height x width disparity map to int values in [0,255].

        The result can be saved via
        ``imageio.imwrite(filename, convert_disparity_to_saveable(disparity))``.
        """
        return to8b(_prepare_disparity(disparity, normalize))  # height x width

    def convert_disparity_to_jet(disparity, normalize=True):
        """Color-code a disparity map with the jet color scheme.

        Returns height x width x 3 int values in [0,255].
        """
        return to8b(
            visualize_disparity_with_jet_color_scheme(
                _prepare_disparity(disparity, normalize)
            )
        )

    def convert_disparity_to_phong(disparity, normalize=True):
        """Visualize a disparity map with Blinn-Phong shading.

        Returns height x width x 3 int values in [0,255].
        """
        return to8b(
            visualize_disparity_with_blinn_phong(
                _prepare_disparity(disparity, normalize)
            )
        )

    def store_ray_bending_mesh_visualization(
        initial_input_pts, input_pts, filename_prefix, subsampled_target=None
    ):
        """Forward to ``visualize_ray_bending`` and return its result.

        ``initial_input_pts`` and ``input_pts`` both have shape
        rays x samples_per_ray x 3.
        """
        return visualize_ray_bending(
            initial_input_pts,
            input_pts,
            filename_prefix,
            subsampled_target=subsampled_target,
        )

    return (
        render_kwargs_train,
        render_kwargs_test,
        start,
        grad_vars,
        load_weights_into_network,
        checkpoint_dict,
        get_training_ray_bending_latents,
        load_llff_dataset,
        raw_render_path,
        render_convenient,
        convert_rgb_to_saveable,
        convert_disparity_to_saveable,
        convert_disparity_to_jet,
        convert_disparity_to_phong,
        store_ray_bending_mesh_visualization,
        to8b,
    )