# create_nerf() — excerpt from train.py [0:0]


def create_nerf(args, autodecoder_variables=None, ignore_optimizer=False):
    """Instantiate NeRF's MLP model(s), the optional ray bender and the optimizer.

    Builds the coarse NeRF, plus a fine NeRF when ``args.N_importance > 0``.
    When ``args.ray_bending`` is set, it also builds a ray-bending network that
    both NeRFs share. Every trainable parameter is collected into
    ``grad_vars`` and, unless ``ignore_optimizer`` is set, an Adam optimizer is
    created over them. Unless ``args.no_reload`` is set, state is then restored
    from ``args.ft_path`` or, failing that, from the last ``*.tar`` file in the
    experiment's log directory.

    Args:
        args: Parsed configuration namespace. When ``args.use_viewdirs`` and
            ``args.approx_nonrigid_viewdirs`` are both set, ``args.netchunk``
            is rewritten in place (see below).
        autodecoder_variables: Optional list of latent codes. They are
            optimized together with the networks and restored from the
            checkpoint's ``"ray_bending_latent_codes"`` entry.
        ignore_optimizer: If True, no optimizer is created or restored, and
            ``None`` is returned in its place.

    Returns:
        Tuple ``(render_kwargs_train, render_kwargs_test, start, grad_vars,
        optimizer)``. ``start`` is the ``global_step`` stored in the loaded
        checkpoint, or 0 when nothing was loaded.

    Raises:
        RuntimeError: If the time-conditioned baseline is combined with
            "simple_neural" ray bending or with a positive offsets,
            divergence or rigidity loss weight.
    """
    grad_vars = []

    if autodecoder_variables is not None:
        grad_vars += autodecoder_variables

    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

    if args.ray_bending is not None and args.ray_bending != "None":
        ray_bender = ray_bending(
            input_ch, args.ray_bending_latent_size, args.ray_bending, embed_fn
        ).cuda()
        grad_vars += list(ray_bender.parameters())
    else:
        ray_bender = None

    if args.time_conditioned_baseline:
        if args.ray_bending == "simple_neural":
            raise RuntimeError("Naive Baseline requires to turn off ray bending")
        if (
            args.offsets_loss_weight > 0.0
            or args.divergence_loss_weight > 0.0
            or args.rigidity_loss_weight > 0.0
        ):
            raise RuntimeError(
                "Naive Baseline requires to turn off regularization losses since they only work with ray bending"
            )

    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
        if args.approx_nonrigid_viewdirs:
            # netchunk must be divisible by the number of samples per ray of
            # both the coarse and the fine NeRF. Round it down to a multiple of
            # their least common multiple, but never below one multiple: a
            # chunk size of 0 would be unusable.
            from math import gcd

            n_coarse = args.N_samples
            n_fine = args.N_samples + args.N_importance
            needs_to_divide = n_coarse * n_fine // gcd(n_coarse, n_fine)
            args.netchunk = max(
                (int(args.netchunk) // needs_to_divide) * needs_to_divide,
                needs_to_divide,
            )

    output_ch = 5 if args.N_importance > 0 else 4
    skips = [4]
    model = NeRF(
        D=args.netdepth,
        W=args.netwidth,
        input_ch=input_ch,
        output_ch=output_ch,
        skips=skips,
        input_ch_views=input_ch_views,
        use_viewdirs=args.use_viewdirs,
        ray_bender=ray_bender,
        ray_bending_latent_size=args.ray_bending_latent_size,
        embeddirs_fn=embeddirs_fn,
        num_ray_samples=args.N_samples,
        approx_nonrigid_viewdirs=args.approx_nonrigid_viewdirs,
        time_conditioned_baseline=args.time_conditioned_baseline,
    ).cuda()
    # Per the original note, model.parameters() does not include the
    # ray_bender's parameters, so they are not double-counted here.
    grad_vars += list(model.parameters())

    model_fine = None
    if args.N_importance > 0:
        model_fine = NeRF(
            D=args.netdepth_fine,
            W=args.netwidth_fine,
            input_ch=input_ch,
            output_ch=output_ch,
            skips=skips,
            input_ch_views=input_ch_views,
            use_viewdirs=args.use_viewdirs,
            ray_bender=ray_bender,
            ray_bending_latent_size=args.ray_bending_latent_size,
            embeddirs_fn=embeddirs_fn,
            num_ray_samples=args.N_samples + args.N_importance,
            approx_nonrigid_viewdirs=args.approx_nonrigid_viewdirs,
            time_conditioned_baseline=args.time_conditioned_baseline,
        ).cuda()
        grad_vars += list(model_fine.parameters())

    def network_query_fn(
        inputs,
        viewdirs,
        additional_pixel_information,
        network_fn,
        detailed_output=False,
    ):
        # Bind the embedders and chunk size so that callers only pass
        # per-query data.
        return run_network(
            inputs,
            viewdirs,
            additional_pixel_information,
            network_fn,
            embed_fn=embed_fn,
            embeddirs_fn=embeddirs_fn,
            netchunk=args.netchunk,
            detailed_output=detailed_output,
        )

    # Create optimizer.
    # Note: this needs to be Adam. Otherwise, check how to avoid a wrong
    # DeepSDF-style autodecoder optimization of the per-frame latent codes.
    if ignore_optimizer:
        optimizer = None
    else:
        optimizer = torch.optim.Adam(
            params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)
        )

    start = 0
    logdir = os.path.join(args.rootdir, args.expname, "logs/")

    # Load checkpoints.
    if args.ft_path is not None and args.ft_path != "None":
        ckpts = [args.ft_path]
    elif os.path.isdir(logdir):
        # The last entry after sorting is treated as the newest checkpoint.
        # This assumes the file names sort chronologically (e.g. zero-padded
        # step numbers); TODO confirm against the checkpoint writer.
        ckpts = [
            os.path.join(logdir, f)
            for f in sorted(os.listdir(logdir))
            if f.endswith(".tar")
        ]
    else:
        # A fresh experiment has no log directory yet, so start from scratch.
        ckpts = []

    print("Found ckpts", ckpts)
    if len(ckpts) > 0 and not args.no_reload:
        ckpt_path = ckpts[-1]
        print("Reloading from", ckpt_path)
        # NOTE(review): torch.load unpickles arbitrary objects. Only load
        # checkpoints from trusted sources.
        ckpt = torch.load(ckpt_path)

        start = ckpt["global_step"]
        if optimizer is not None:
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])

        # Load model weights.
        model.load_state_dict(ckpt["network_fn_state_dict"])
        if model_fine is not None:
            model_fine.load_state_dict(ckpt["network_fine_state_dict"])
        if ray_bender is not None:
            ray_bender.load_state_dict(ckpt["ray_bender_state_dict"])
        if autodecoder_variables is not None:
            saved_latents = ckpt["ray_bending_latent_codes"]
            # zip() silently stops at the shorter sequence, so report a
            # mismatch instead of restoring only part of the codes unnoticed.
            if len(saved_latents) != len(autodecoder_variables):
                print(
                    "Warning: checkpoint contains",
                    len(saved_latents),
                    "latent codes but",
                    len(autodecoder_variables),
                    "were provided; only the overlapping ones are restored.",
                )
            for latent, saved_latent in zip(autodecoder_variables, saved_latents):
                # Copy in place so the tensors already registered with the
                # optimizer keep their identity.
                latent.data[:] = saved_latent[:].detach().clone()

    render_kwargs_train = {
        "network_query_fn": network_query_fn,
        "perturb": args.perturb,
        "N_importance": args.N_importance,
        "network_fine": model_fine,
        "N_samples": args.N_samples,
        "network_fn": model,
        "ray_bender": ray_bender,
        "use_viewdirs": args.use_viewdirs,
        "white_bkgd": False,
        "raw_noise_std": args.raw_noise_std,
    }

    # NDC is only suited to LLFF-style forward-facing data, so it is always
    # disabled here, as is sampling linearly in inverse depth.
    render_kwargs_train["ndc"] = False
    render_kwargs_train["lindisp"] = False

    # Evaluation uses the same settings, but with deterministic sampling and
    # no noise on the raw outputs.
    render_kwargs_test = dict(render_kwargs_train)
    render_kwargs_test["perturb"] = False
    render_kwargs_test["raw_noise_std"] = 0.0

    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer