in train.py [0:0]
def create_nerf(args, autodecoder_variables=None, ignore_optimizer=False):
    """Instantiate NeRF's MLP model."""

    grad_vars = []

    if autodecoder_variables is not None:
        grad_vars += autodecoder_variables

    embed_fn, input_ch = get_embedder(args.multires, args.i_embed)

    if args.ray_bending is not None and args.ray_bending != "None":
        ray_bender = ray_bending(
            input_ch, args.ray_bending_latent_size, args.ray_bending, embed_fn
        ).cuda()
        grad_vars += list(ray_bender.parameters())
    else:
        ray_bender = None
    if args.time_conditioned_baseline:
        if args.ray_bending == "simple_neural":
            raise RuntimeError("The naive baseline requires ray bending to be turned off")
        if args.offsets_loss_weight > 0. or args.divergence_loss_weight > 0. or args.rigidity_loss_weight > 0.:
            raise RuntimeError(
                "The naive baseline requires the regularization losses to be turned off, since they only work with ray bending"
            )

    input_ch_views = 0
    embeddirs_fn = None
    if args.use_viewdirs:
        embeddirs_fn, input_ch_views = get_embedder(args.multires_views, args.i_embed)
        if args.approx_nonrigid_viewdirs:
            # netchunk needs to be divisible by the number of samples per ray of both
            # the coarse and the fine NeRF
            def lcm(x, y):
                from math import gcd

                return x * y // gcd(x, y)

            needs_to_divide = lcm(args.N_samples, args.N_samples + args.N_importance)
            args.netchunk = int(args.netchunk / needs_to_divide) * needs_to_divide
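            # Illustrative example (values assumed, not necessarily the configured
            # defaults): with N_samples=64 and N_importance=128, the fine network
            # evaluates 64 + 128 = 192 samples per ray, lcm(64, 192) = 192, and a
            # netchunk of 65536 would be rounded down to 341 * 192 = 65472.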

    output_ch = 5 if args.N_importance > 0 else 4
    skips = [4]
    model = NeRF(
        D=args.netdepth,
        W=args.netwidth,
        input_ch=input_ch,
        output_ch=output_ch,
        skips=skips,
        input_ch_views=input_ch_views,
        use_viewdirs=args.use_viewdirs,
        ray_bender=ray_bender,
        ray_bending_latent_size=args.ray_bending_latent_size,
        embeddirs_fn=embeddirs_fn,
        num_ray_samples=args.N_samples,
        approx_nonrigid_viewdirs=args.approx_nonrigid_viewdirs,
        time_conditioned_baseline=args.time_conditioned_baseline,
    ).cuda()
    grad_vars += list(
        model.parameters()
    )  # model.parameters() does not contain ray_bender parameters

    model_fine = None
    if args.N_importance > 0:
        model_fine = NeRF(
            D=args.netdepth_fine,
            W=args.netwidth_fine,
            input_ch=input_ch,
            output_ch=output_ch,
            skips=skips,
            input_ch_views=input_ch_views,
            use_viewdirs=args.use_viewdirs,
            ray_bender=ray_bender,
            ray_bending_latent_size=args.ray_bending_latent_size,
            embeddirs_fn=embeddirs_fn,
            num_ray_samples=args.N_samples + args.N_importance,
            approx_nonrigid_viewdirs=args.approx_nonrigid_viewdirs,
            time_conditioned_baseline=args.time_conditioned_baseline,
        ).cuda()
        grad_vars += list(model_fine.parameters())
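        # The fine model reuses the same ray_bender instance as the coarse model,
        # so both sampling stages apply identical deformations; the ray bender's
        # parameters were already added to grad_vars once, above.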

    def network_query_fn(
        inputs,
        viewdirs,
        additional_pixel_information,
        network_fn,
        detailed_output=False,
    ):
        return run_network(
            inputs,
            viewdirs,
            additional_pixel_information,
            network_fn,
            embed_fn=embed_fn,
            embeddirs_fn=embeddirs_fn,
            netchunk=args.netchunk,
            detailed_output=detailed_output,
        )
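    # Sketch of how this closure is typically invoked from the rendering code
    # (variable names here are illustrative; the actual call site lives in the
    # ray-marching routine):
    #   raw = network_query_fn(pts, viewdirs, additional_pixel_information, model)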

    # Create optimizer
    # Note: the optimizer needs to be Adam; with other optimizers, one would first
    # have to check how to avoid incorrect DeepSDF-style autodecoder optimization
    # of the per-frame latent codes.
    if ignore_optimizer:
        optimizer = None
    else:
        optimizer = torch.optim.Adam(
            params=grad_vars, lr=args.lrate, betas=(0.9, 0.999)
        )
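    # At this point grad_vars holds the per-frame latent codes (if given), the
    # coarse and fine NeRF parameters, and the ray bender parameters, so a single
    # Adam instance optimizes all of them jointly.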

    start = 0
    logdir = os.path.join(args.rootdir, args.expname, "logs/")
    expname = args.expname

    ##########################
    # Load checkpoints
    if args.ft_path is not None and args.ft_path != "None":
        ckpts = [args.ft_path]
    else:
        ckpts = [
            os.path.join(logdir, f) for f in sorted(os.listdir(logdir)) if ".tar" in f
        ]
    print("Found ckpts", ckpts)
    if len(ckpts) > 0 and not args.no_reload:
        ckpt_path = ckpts[-1]
        print("Reloading from", ckpt_path)
        ckpt = torch.load(ckpt_path)

        start = ckpt["global_step"]
        if not ignore_optimizer:
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])

        # Load model
        model.load_state_dict(ckpt["network_fn_state_dict"])
        if model_fine is not None:
            model_fine.load_state_dict(ckpt["network_fine_state_dict"])
        if ray_bender is not None:
            ray_bender.load_state_dict(ckpt["ray_bender_state_dict"])
        if autodecoder_variables is not None:
            for latent, saved_latent in zip(
                autodecoder_variables, ckpt["ray_bending_latent_codes"]
            ):
                latent.data[:] = saved_latent[:].detach().clone()
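        # The latent codes are copied in place (via .data[:]) so that grad_vars and
        # the optimizer keep referencing the same tensor objects after reloading.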

    ##########################

    render_kwargs_train = {
        "network_query_fn": network_query_fn,
        "perturb": args.perturb,
        "N_importance": args.N_importance,
        "network_fine": model_fine,
        "N_samples": args.N_samples,
        "network_fn": model,
        "ray_bender": ray_bender,
        "use_viewdirs": args.use_viewdirs,
        "white_bkgd": False,
        "raw_noise_std": args.raw_noise_std,
    }

    # NDC only good for LLFF-style forward facing data
    # if args.dataset_type != 'llff' or args.no_ndc:
    #     print('Not ndc!')
    render_kwargs_train["ndc"] = False
    render_kwargs_train["lindisp"] = False

    render_kwargs_test = {k: render_kwargs_train[k] for k in render_kwargs_train}
    render_kwargs_test["perturb"] = False
    render_kwargs_test["raw_noise_std"] = 0.0

    return render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer
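
# Hedged usage sketch (not part of the original file): how the return values of
# create_nerf() are typically consumed by a training script. The variable names
# near, far, and ray_bending_latents_list are assumptions for illustration.
#
#   render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(
#       args, autodecoder_variables=ray_bending_latents_list
#   )
#   bds_dict = {"near": near, "far": far}
#   render_kwargs_train.update(bds_dict)
#   render_kwargs_test.update(bds_dict)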