in train.py
def main_function(args):
# miscellaneous initial stuff
global DEBUG
DEBUG = args.debug
torch.autograd.set_detect_anomaly(args.debug)
if args.seed >= 0:
np.random.seed(args.seed)
# Load data
if args.dataset_type == "llff":
#images, poses, bds, render_poses, i_test = load_llff_data_multi_view(
images, poses, bds, render_poses, i_test = load_llff_data(
args.datadir,
factor=args.factor,
recenter=True,
bd_factor=args.bd_factor,
spherify=args.spherify,
)
dataset_extras = _get_multi_view_helper_mappings(images.shape[0], args.datadir)
intrinsics, image_folder = get_full_resolution_intrinsics(args, dataset_extras)
hwf = poses[0, :3, -1]
poses = poses[:, :3, :4]
print("Loaded llff", images.shape, render_poses.shape, hwf, args.datadir)
# height and width are taken from the loaded images; focal_x/focal_y fall back to hwf if the intrinsics do not provide them, otherwise they (and the principal point) are downscaled by args.factor
# do not reuse this loop and the next one in smallscripts; rely on the stored/saved version of "intrinsics" instead
for camera in intrinsics.values(): # downscale according to args.factor
camera["height"] = images.shape[1]
camera["width"] = images.shape[2]
if camera["focal_x"] is None:
camera["focal_x"] = hwf[2]
else:
camera["focal_x"] /= args.factor
if camera["focal_y"] is None:
camera["focal_y"] = hwf[2]
else:
camera["focal_y"] /= args.factor
camera["center_x"] /= args.factor
camera["center_y"] /= args.factor
# modify "intrinsics" mapping to use viewid instead of raw_view
for raw_view in list(intrinsics.keys()):
viewid = dataset_extras["rawview_to_viewid"][raw_view]
new_entry = intrinsics[raw_view]
del intrinsics[raw_view]
intrinsics[viewid] = new_entry
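# from here on, "intrinsics" is keyed by view id, matching the dataset_extras["imageid_to_viewid"] lookups used below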
# optionally split the timesteps into alternating training and test blocks (controlled by args.train_block_size and args.test_block_size)
i_test = [] # [i_test]
if args.test_block_size > 0 and args.train_block_size > 0:
print(
"splitting timesteps into training ("
+ str(args.train_block_size)
+ ") and test ("
+ str(args.test_block_size)
+ ") blocks"
)
num_timesteps = len(dataset_extras["raw_timesteps"])
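# example: with train_block_size=4 and test_block_size=1, timesteps 0-3 are used for training, timestep 4 for testing, 5-8 for training, 9 for testing, and so on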
test_timesteps = np.concatenate(
[
np.arange(
min(num_timesteps, blocks_start + args.train_block_size),
min(
num_timesteps,
blocks_start + args.train_block_size + args.test_block_size,
),
)
for blocks_start in np.arange(
0, num_timesteps, args.train_block_size + args.test_block_size
)
]
)
i_test = [
imageid
for imageid, timestep in enumerate(
dataset_extras["imageid_to_timestepid"]
)
if timestep in test_timesteps
]
i_test = np.array(i_test)
i_val = i_test
i_train = np.array(
[
i
for i in np.arange(int(images.shape[0]))
if (i not in i_test and i not in i_val)
]
)
print("DEFINING BOUNDS")
# if args.no_ndc:
near = np.min(bds) * 0.9
far = np.max(bds) * 1.0
# else:
# near = 0.
# far = 1.
print("NEAR FAR", near, far)
else:
print("Unknown dataset type", args.dataset_type, "exiting")
return
if args.render_test:
render_poses = np.array(poses[i_test])
# Create log dir and copy the config file
logdir = os.path.join(args.rootdir, args.expname, "logs/")
expname = args.expname
os.makedirs(logdir, exist_ok=True)
f = os.path.join(logdir, "args.txt")
with open(f, "w") as file:
for arg in sorted(vars(args)):
attr = getattr(args, arg)
file.write("{} = {}\n".format(arg, attr))
if args.config is not None:
f = os.path.join(logdir, "config.txt")
with open(f, "w") as file:
with open(args.config, "r") as config_file:
    file.write(config_file.read())
# create autodecoder variables as pytorch tensors
ray_bending_latents_list = [
torch.zeros(args.ray_bending_latent_size).cuda()
for _ in range(len(dataset_extras["raw_timesteps"]))
]
for latent in ray_bending_latents_list:
latent.requires_grad = True
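# one latent code per raw time step; they require gradients and are passed to create_nerf below so they can be optimized jointly with the network weights (DeepSDF-style auto-decoding)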
# Create nerf model
render_kwargs_train, render_kwargs_test, start, grad_vars, optimizer = create_nerf(
args, autodecoder_variables=ray_bending_latents_list
)
print("start: " + str(start) + " args.N_iters: " + str(args.N_iters), flush=True)
global_step = start
bds_dict = {
"near": near,
"far": far,
}
render_kwargs_train.update(bds_dict)
render_kwargs_test.update(bds_dict)
scripts_dict = {"near": near, "far": far, "image_folder": image_folder}
coarse_model = render_kwargs_train["network_fn"]
fine_model = render_kwargs_train["network_fine"]
ray_bender = render_kwargs_train["ray_bender"]
parallel_training = get_parallelized_training_function(
coarse_model=coarse_model,
latents=ray_bending_latents_list,
fine_model=fine_model,
ray_bender=ray_bender,
)
parallel_render = get_parallelized_render_function(
coarse_model=coarse_model, fine_model=fine_model, ray_bender=ray_bender
) # only used by render_path() at test time, not for training/optimization
min_point, max_point = determine_nerf_volume_extent(
parallel_render, poses, [ intrinsics[dataset_extras["imageid_to_viewid"][imageid]] for imageid in range(poses.shape[0]) ], render_kwargs_train, args
)
scripts_dict["min_nerf_volume_point"] = min_point.detach().cpu().numpy().tolist()
scripts_dict["max_nerf_volume_point"] = max_point.detach().cpu().numpy().tolist()
# Move testing data to GPU
render_poses = torch.Tensor(render_poses).cuda()
# Prepare raybatch tensor if batching random rays
N_rand = args.N_rand
# For random ray batching
print("get rays")
rays = np.stack([get_rays_np(p, intrinsics[dataset_extras["imageid_to_viewid"][imageid]]) for imageid, p in enumerate(poses[:,:3,:4])], 0) # [N, ro+rd, H, W, 3]
print("done, concats")
# attach index information (index among all images in dataset, x and y coordinate)
image_indices, y_coordinates, x_coordinates = np.meshgrid(
np.arange(images.shape[0]), np.arange(intrinsics[0]["height"]), np.arange(intrinsics[0]["width"]), indexing="ij"
) # keep consistent with code in get_rays and get_rays_np. (0,0,0) is coordinate of the top-left corner of the first image, i.e. of [0,0,0]. each array has shape images x height x width
additional_indices = np.stack(
[image_indices, x_coordinates, y_coordinates], axis=-1
) # N x height x width x 3 (image, x, y)
rays_rgb = np.concatenate(
[rays, images[:, None], additional_indices[:, None]], 1
) # [N, ro+rd+rgb+ind, H, W, 3]
rays_rgb = np.transpose(rays_rgb, [0, 2, 3, 1, 4]) # [N, H, W, ro+rd+rgb+ind, 3]
# use all images
# keep shape N x H x W x ro+rd+rgb+ind x 3
rays_rgb = rays_rgb.astype(np.float32)
print(rays_rgb.shape)
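# rays_rgb stays in CPU memory as one big numpy array; each training iteration indexes a random batch from it and moves only that batch to the GPU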
# Move training data to GPU
poses = torch.Tensor(poses).cuda()
# N_iters = 200000 + 1
N_iters = args.N_iters + 1
print("TRAIN views are", i_train)
print("TEST views are", i_test)
print("VAL views are", i_val)
print("Begin", flush=True)
# Summary writers
# writer = SummaryWriter(os.path.join(logdir, 'summaries', expname))
start = start + 1
for i in trange(start, N_iters):
time0 = time.time()
optimizer.zero_grad()
# reset autodecoder gradients to avoid wrong DeepSDF-style optimization. Note: this is only guaranteed to work if the optimizer is Adam
for latent in ray_bending_latents_list:
latent.grad = None
# Sample random ray batch
# Random over all images
# use numpy random to sample N_rand random image IDs, x and y coordinates
image_indices = np.random.randint(images.shape[0], size=args.N_rand)
x_coordinates = np.random.randint(intrinsics[0]["width"], size=args.N_rand)
y_coordinates = np.random.randint(intrinsics[0]["height"], size=args.N_rand)
# index rays_rgb with those values
batch = rays_rgb[
image_indices, y_coordinates, x_coordinates
] # batch x ro+rd+rgb+ind x 3
# push to cuda, create batch_rays, target_s, batch_pixel_indices
batch_pixel_indices = (
torch.Tensor(
np.stack([image_indices, x_coordinates, y_coordinates], axis=-1)
)
.cuda()
.long()
) # batch x 3
batch = torch.transpose(torch.Tensor(batch).cuda(), 0, 1) # 4 x batch x 3
batch_rays, target_s = batch[:2], batch[2]
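# after the transpose, batch[0]/batch[1] are the ray origins/directions, batch[2] the target rgb, and batch[3] the (image, x, y) indices (redundant here, since batch_pixel_indices was built directly from the sampled indices)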
losses = parallel_training(
args,
batch_rays[0],
batch_rays[1],
i,
render_kwargs_train,
target_s,
global_step,
start,
dataset_extras,
batch_pixel_indices,
)
# losses will have shape N_rays
all_test_images_indicator = torch.zeros(images.shape[0], dtype=torch.long).cuda()
all_test_images_indicator[i_test] = 1
all_training_images_indicator = torch.zeros(
images.shape[0], dtype=torch.long
).cuda()
all_training_images_indicator[i_train] = 1
# index with image IDs of the N_rays rays to determine weights
current_test_images_indicator = all_test_images_indicator[
image_indices
] # N_rays
current_training_images_indicator = all_training_images_indicator[
image_indices
] # N_rays
# first, rays from test images (if any were sampled): mask the per-ray losses with the indicator, take the mean, and backpropagate with retain_graph=True. afterwards, set the gradients of the NeRF and ray bender weights to None so that only the per-timestep latent codes keep gradients from test rays
if ray_bender is not None and torch.sum(current_test_images_indicator) > 0:
masked_loss = current_test_images_indicator * losses # N_rays
masked_loss = torch.mean(masked_loss)
masked_loss.backward(retain_graph=True)
for weights in (
list(coarse_model.parameters())
+ list([] if fine_model is None else fine_model.parameters())
+ list([] if ray_bender is None else ray_bender.parameters())
):
weights.grad = None
# next, rays from training images (always present): mask the per-ray losses with the indicator, take the mean, and backpropagate WITHOUT retain_graph
masked_loss = current_training_images_indicator * losses # N_rays
masked_loss = torch.mean(masked_loss)
masked_loss.backward(retain_graph=False)
optimizer.step()
if DEBUG:
if torch.isnan(losses).any() or torch.isinf(losses).any():
raise RuntimeError(str(losses))
if torch.isnan(target_s).any() or torch.isinf(target_s).any():
raise RuntimeError(str(torch.sum(target_s)) + " " + str(target_s))
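# additionally compute the global gradient norm (L2 over all trainable parameters) as a debugging aid against exploding gradients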
norm_type = 2.0
total_gradient_norm = 0
for p in (
list(coarse_model.parameters())
+ list([] if fine_model is None else fine_model.parameters())
+ list([] if ray_bender is None else ray_bender.parameters())
+ list(ray_bending_latents_list)
):
if p.requires_grad and p.grad is not None:
param_norm = p.grad.data.norm(norm_type)
total_gradient_norm += param_norm.item() ** norm_type
total_gradient_norm = total_gradient_norm ** (1.0 / norm_type)
print(total_gradient_norm, flush=True)
# NOTE: IMPORTANT!
### update learning rate ###
decay_rate = 0.1
decay_steps = args.lrate_decay
new_lrate = args.lrate * (decay_rate ** (global_step / decay_steps))
warming_up = 1000
if (
global_step < warming_up
): # in case images are very dark or very bright, need to keep network from initially building up so much momentum that it kills the gradient
new_lrate /= 20.0 * (-(global_step - warming_up) / warming_up) + 1.0
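# the warm-up divisor is 21.0 at global_step 0 and decays linearly towards 1.0 as global_step approaches warming_up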
for param_group in optimizer.param_groups:
param_group["lr"] = new_lrate
################################
dt = time.time() - time0
log_string = (
"Step: "
+ str(global_step)
+ ", total loss: "
+ str(losses.mean().cpu().detach().numpy())
)
if "img_loss0" in locals():
log_string += ", coarse loss: " + str(
img_loss0.mean().cpu().detach().numpy()
)
if "img_loss" in locals():
log_string += ", fine loss: " + str(img_loss.mean().cpu().detach().numpy())
if "offsets_loss" in locals():
log_string += ", offsets: " + str(
offsets_loss.mean().cpu().detach().numpy()
)
if "divergence_loss" in locals():
log_string += ", div: " + str(divergence_loss.mean().cpu().detach().numpy())
log_string += ", time: " + str(dt)
print(log_string, flush=True)
# Rest is logging
if i % args.i_weights == 0:
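# checkpointing: latest.tar is refreshed every args.i_weights iterations; on multiples of 50000 iterations the checkpoint is additionally written as a numbered snapshot and then copied over latest.tar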
all_latents = torch.zeros(0)
for l in ray_bending_latents_list:
all_latents = torch.cat([all_latents, l.cpu().unsqueeze(0)], 0)
if i % 50000 == 0:
store_extra = True
path = os.path.join(logdir, "{:06d}.tar".format(i))
else:
store_extra = False
path = os.path.join(logdir, "latest.tar")
torch.save(
{
"global_step": global_step,
"network_fn_state_dict": render_kwargs_train[
"network_fn"
].state_dict(),
"network_fine_state_dict": None
if render_kwargs_train["network_fine"] is None
else render_kwargs_train["network_fine"].state_dict(),
"ray_bender_state_dict": None
if render_kwargs_train["ray_bender"] is None
else render_kwargs_train["ray_bender"].state_dict(),
"optimizer_state_dict": optimizer.state_dict(),
"ray_bending_latent_codes": all_latents, # shape: frames x latent_size
"intrinsics": intrinsics,
"scripts_dict": scripts_dict,
"dataset_extras": dataset_extras,
},
path,
)
del all_latents
if store_extra:
shutil.copyfile(path, os.path.join(logdir, "latest.tar"))
print("Saved checkpoints at", path)
if i % args.i_video == 0 and i > 0:
# Turn on testing mode
print("rendering test set...", flush=True)
if len(render_poses) > 0 and len(i_test) > 0 and not dataset_extras["is_multiview"]:
with torch.no_grad():
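# choose the latent codes for rendering: with args.render_test each test pose uses the latent of its own time step, otherwise the latent of the first test image's time step is repeated for all spiral render_poses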
if args.render_test:
rendering_latents = [
ray_bending_latents_list[
dataset_extras["imageid_to_timestepid"][i]
]
for i in i_test
]
else:
rendering_latents = [
ray_bending_latents_list[
dataset_extras["imageid_to_timestepid"][i_test[0]]
]
for _ in range(len(render_poses))
]
rgbs, disps = render_path(
render_poses,
[intrinsics[0] for _ in range(len(render_poses))],
args.chunk,
render_kwargs_test,
ray_bending_latents=rendering_latents,
parallelized_render_function=parallel_render,
)
print("Done, saving", rgbs.shape, disps.shape)
moviebase = os.path.join(logdir, "{}_spiral_{:06d}_".format(expname, i))
try:
imageio.mimwrite(
moviebase + "rgb.mp4", to8b(rgbs), fps=30, quality=8
)
imageio.mimwrite(
moviebase + "disp.mp4",
to8b(disps / np.max(disps)),
fps=30,
quality=8,
)
imageio.mimwrite(
moviebase + "disp_jet.mp4",
to8b(
np.stack(
[
visualize_disparity_with_jet_color_scheme(
disp / np.max(disp)
)
for disp in disps
],
axis=0,
)
),
fps=30,
quality=8,
)
imageio.mimwrite(
moviebase + "disp_phong.mp4",
to8b(
np.stack(
[
visualize_disparity_with_blinn_phong(
disp / np.max(disp)
)
for disp in disps
],
axis=0,
)
),
fps=30,
quality=8,
)
except Exception:
print(
"imageio.mimwrite() failed. maybe ffmpeg is not installed properly?"
)
if i >= N_iters + 1 - args.i_video:
print("rendering full training set...", flush=True)
with torch.no_grad():
rgbs, disps = render_path(
poses[i_train],
[intrinsics[dataset_extras["imageid_to_viewid"][imageid]] for imageid in i_train],
args.chunk,
render_kwargs_test,
ray_bending_latents=[
ray_bending_latents_list[
dataset_extras["imageid_to_timestepid"][i]
]
for i in i_train
],
parallelized_render_function=parallel_render,
)
print("Done, saving", rgbs.shape, disps.shape)
moviebase = os.path.join(
logdir, "{}_training_{:06d}_".format(expname, i)
)
try:
imageio.mimwrite(
moviebase + "rgb.mp4", to8b(rgbs), fps=30, quality=8
)
imageio.mimwrite(
moviebase + "disp.mp4",
to8b(disps / np.max(disps)),
fps=30,
quality=8,
)
imageio.mimwrite(
moviebase + "disp_jet.mp4",
to8b(
np.stack(
[
visualize_disparity_with_jet_color_scheme(
disp / np.max(disp)
)
for disp in disps
],
axis=0,
)
),
fps=30,
quality=8,
)
imageio.mimwrite(
moviebase + "disp_phong.mp4",
to8b(
np.stack(
[
visualize_disparity_with_blinn_phong(
disp / np.max(disp)
)
for disp in disps
],
axis=0,
)
),
fps=30,
quality=8,
)
except Exception:
print(
"imageio.mimwrite() failed. maybe ffmpeg is not installed properly?"
)
if i % args.i_testset == 0 and i > 0:
trainsubsavedir = os.path.join(logdir, "trainsubset_{:06d}".format(i))
os.makedirs(trainsubsavedir, exist_ok=True)
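# outside of the final args.i_video iterations, only a subset of the training images (roughly as many as there are test images) is rendered here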
i_train_sub = i_train
if i < N_iters + 1 - args.i_video:
i_train_sub = i_train_sub[
:: np.maximum(1, int((len(i_train_sub) / len(i_test)) + 0.5))
]
print("i_train_sub poses shape", poses[i_train_sub].shape)
with torch.no_grad():
render_path(
poses[i_train_sub],
[intrinsics[dataset_extras["imageid_to_viewid"][imageid]] for imageid in i_train_sub],
args.chunk,
render_kwargs_test,
gt_imgs=images[i_train_sub],
savedir=trainsubsavedir,
detailed_output=True,
ray_bending_latents=[
ray_bending_latents_list[
dataset_extras["imageid_to_timestepid"][i]
]
for i in i_train_sub
],
parallelized_render_function=parallel_render,
)
print("Saved some training images")
if len(i_test) > 0:
testsavedir = os.path.join(logdir, "testset_{:06d}".format(i))
os.makedirs(testsavedir, exist_ok=True)
print("test poses shape", poses[i_test].shape)
with torch.no_grad():
render_path(
poses[i_test],
[intrinsics[dataset_extras["imageid_to_viewid"][imageid]] for imageid in i_test],
args.chunk,
render_kwargs_test,
gt_imgs=images[i_test],
savedir=testsavedir,
detailed_output=True,
ray_bending_latents=[
ray_bending_latents_list[
dataset_extras["imageid_to_timestepid"][i]
]
for i in i_test
],
parallelized_render_function=parallel_render,
)
print("Saved test set")
if i % args.i_print == 0:
if "psnr" in locals():
tqdm.write(
f"[TRAIN] Iter: {i} Loss: {losses.mean().item()} PSNR: {psnr.item()}"
)
else:
tqdm.write(f"[TRAIN] Iter: {i} Loss: {losses.mean().item()}")
"""
print(expname, i, psnr.numpy(), loss.numpy(), global_step.numpy())
print('iter time {:.05f}'.format(dt))
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_print):
tf.contrib.summary.scalar('loss', loss)
tf.contrib.summary.scalar('psnr', psnr)
tf.contrib.summary.histogram('tran', trans)
if args.N_importance > 0:
tf.contrib.summary.scalar('psnr0', psnr0)
if i%args.i_img==0:
# Log a rendered validation view to Tensorboard
img_i=np.random.choice(i_val)
target = images[img_i]
pose = poses[img_i, :3,:4]
with torch.no_grad():
rgb, disp, acc, extras = render(H, W, focal, chunk=args.chunk, c2w=pose,
**render_kwargs_test)
psnr = mse2psnr(img2mse(rgb, target))
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
tf.contrib.summary.image('rgb', to8b(rgb)[tf.newaxis])
tf.contrib.summary.image('disp', disp[tf.newaxis,...,tf.newaxis])
tf.contrib.summary.image('acc', acc[tf.newaxis,...,tf.newaxis])
tf.contrib.summary.scalar('psnr_holdout', psnr)
tf.contrib.summary.image('rgb_holdout', target[tf.newaxis])
if args.N_importance > 0:
with tf.contrib.summary.record_summaries_every_n_global_steps(args.i_img):
tf.contrib.summary.image('rgb0', to8b(extras['rgb0'])[tf.newaxis])
tf.contrib.summary.image('disp0', extras['disp0'][tf.newaxis,...,tf.newaxis])
tf.contrib.summary.image('z_std', extras['z_std'][tf.newaxis,...,tf.newaxis])
"""
global_step += 1
print("", end="", flush=True)