in depth_fine_tuning.py [0:0]
def fine_tune(self, writer=None):
    """Fine-tune the depth model on the frames of the video.

    The method builds the training/validation data loaders, prepares a depth
    stream through a ``PoseOptimizer``, loads the initially predicted depth
    maps from ``<base_dir>/depth_<model_type>/depth`` and then runs
    ``self.params.num_epochs`` epochs of optimization with ``JointLoss``.
    Depending on ``self.params`` it also validates, saves checkpoints, saves
    intermediate depth streams, re-optimizes the poses (``recon == "i3d"``)
    and post-filters the depth.

    Validation schedule (``self.params.val_epoch_freq``):
        * negative: never validate;
        * zero: validate once before and once after training;
        * positive ``k``: validate before training, every ``k`` epochs, and
          after the last epoch if it was not already covered.

    Args:
        writer: Optional summary writer used for logging. When it is ``None``
            and ``self.params.save_tensorboard`` is set, a ``SummaryWriter``
            is created in ``self.params.tensorboard_log_path`` or, if that is
            empty, in ``<out_dir>/tensorboard``.

    Raises:
        KeyError: If a batch references a frame for which no initial depth
            file (``frame_XXXXXX.raw``) was found.

    Side effects:
        Sets ``self.depth_dir`` and ``self.depth_orig``, switches the model to
        training mode and enables cuDNN benchmarking globally.
    """
    # COLMAP reconstructions read their metadata from disk; which file is used
    # depends on how the scale is calibrated.
    meta_file = None
    if self.params.recon == "colmap":
        if self.params.scaling == "extrinsics":
            meta_file = pjoin(self.range_dir, "metadata_scaled.npz")
        else:
            meta_file = pjoin(self.base_dir, "colmap_dense", "metadata.npz")

    print("Start depth finetuning...")

    use_temporal_smooth_loss = (
        self.params.lambda_smooth_disparity > 0
        or self.params.lambda_smooth_reprojection > 0
        or self.params.lambda_smooth_depth_ratio > 0
    )

    dataset = VideoDataset(
        self.base_dir,
        self.frames,
        self.params.min_mask_ratio,
        use_temporal_smooth_loss,
        meta_file,
        self.params.recon,
    )

    # Both loaders iterate over the same dataset; only shuffling differs.
    loader_kwargs = dict(
        batch_size=self.params.batch_size,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
    )
    train_data_loader = DataLoader(dataset, shuffle=True, **loader_kwargs)
    val_data_loader = DataLoader(dataset, shuffle=False, **loader_kwargs)

    # Even if we're using the COLMAP pipeline, we're initializing the pose
    # optimizer here, because it will create a depth video container for us.
    pose_optimizer = PoseOptimizer(
        self.base_dir, self.params.model_type, self.frames, self.params.opt
    )
    is_i3d = self.params.recon == "i3d"
    if is_i3d:
        pose_optimizer.optimize_poses()

    stream_freq = self.params.save_intermediate_depth_streams_freq
    if stream_freq > 0:
        self.depth_dir = os.path.join(self.out_dir, "depth_e0000")
        pose_optimizer.duplicate_last_depth_stream("e0000", self.depth_dir)
    else:
        self.depth_dir = self.out_dir
        pose_optimizer.duplicate_last_depth_stream("fine_tuned", self.depth_dir)

    if is_i3d:
        dataset.update_poses(pose_optimizer.depth_video)

    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    # Only enable back-propagation for the PCA related parameters if specified.
    use_pca = self.params.model_type == "midas2_pca"
    if use_pca:
        tunable_names = ("model.scale_params", "model.shift_params")
        for name, param in self.model.named_parameters():
            param.requires_grad = name in tunable_names

    # The loss receives a snapshot of the parameters as they are right now.
    criterion = JointLoss(
        self.params, parameters_init=[p.clone() for p in self.model.parameters()]
    )

    if self.params.save_tensorboard and writer is None:
        if self.params.tensorboard_log_path:
            log_dir = self.params.tensorboard_log_path
        else:
            log_dir = pjoin(self.out_dir, "tensorboard")
        # Print the prompt to view the tensorboard.
        print(get_tensorboard_prompt(log_dir))
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=log_dir)

    # Only include tunable PCA parameters in the optimizer if specified.
    if use_pca:
        trainable_params = filter(
            lambda p: p.requires_grad, self.model.parameters()
        )
    else:
        trainable_params = self.model.parameters()
    opt = optimizer.create(
        self.params.optimizer,
        trainable_params,
        self.params.learning_rate,
        betas=(0.9, 0.999),
    )

    self.model.train()

    def validate(epoch, niters):
        """Run evaluation on the validation loader and log its loss stats."""
        val_start_time = time.perf_counter()
        loss_meta = self.eval_and_save(criterion, val_data_loader, epoch, niters)
        if writer is not None:
            log_loss_stats(
                writer, "validation", loss_meta, epoch, log_histogram=True
            )
        val_duration = time.perf_counter() - val_start_time
        print(
            f"Complete Validation for epoch {epoch} ({niters} iterations) in {val_duration:.2f}s."
        )

    val_freq = self.params.val_epoch_freq
    if val_freq >= 0:
        validate(epoch=0, niters=0)

    # Disable inplace relu for batch-wise PCA modulation
    # NOTE(review): this helper is defined but never called in this method;
    # confirm whether it is supposed to be applied to self.model.
    def disable_relu_inplace(model) -> None:
        for child_name, child in model.named_children():
            if isinstance(child, torch.nn.ReLU):
                setattr(model, child_name, torch.nn.ReLU(inplace=False))
            else:
                disable_relu_inplace(child)

    # Retrieve initially computed depth predictions for loss computation.
    initial_depth_dir = osp.join(
        self.base_dir, f"depth_{self.params.model_type}", "depth"
    )
    depth_names = sorted(
        n for n in os.listdir(initial_depth_dir) if os.path.splitext(n)[-1] == ".raw"
    )
    all_depth_orig = {}
    for depth_name in depth_names:
        depth_path = osp.join(initial_depth_dir, depth_name)
        # The reciprocal of the stored image is what the loss is fed with.
        depth_orig = 1.0 / image_io.load_raw_float32_image(depth_path)
        all_depth_orig[depth_name] = torch.from_numpy(depth_orig)

    def retrieve_depth_orig(metadata) -> torch.Tensor:
        """
        Retrieve the corresponding original depths for loss computation.

        The frame indices of the batch are flattened in order and the
        matching initial depth maps are stacked along a new first dimension.
        The result is also stored in ``self.depth_orig``.
        """
        indices = metadata["geometry_consistency"]["indices"]
        flat_indices = itertools.chain.from_iterable(
            indices.cpu().numpy().tolist()
        )
        frame_names = [f"frame_{idx:06d}.raw" for idx in flat_indices]
        # Fail with a clear message instead of letting torch.stack choke on
        # a None entry.
        missing = [n for n in frame_names if n not in all_depth_orig]
        if missing:
            raise KeyError(
                f"No initial depth found in {initial_depth_dir} for: {missing}"
            )
        self.depth_orig = torch.stack([all_depth_orig[n] for n in frame_names])
        return self.depth_orig

    # Training loop.
    total_iters = 0
    for epoch in range(self.params.num_epochs):
        epoch_start_time = time.perf_counter()

        for data in train_data_loader:
            if use_pca:
                print(f"'scale_params': {self.model.model.scale_params}")
                print(f"'shift_params': {self.model.model.shift_params}")
            iter_start_time = time.perf_counter()

            data = to_device(data)
            stacked_img, metadata = data
            print(f"Size of stacked_img: {stacked_img.shape}")
            print(f"Current batch_size: {self.params.batch_size}")

            depth = self.model(stacked_img, metadata)

            # Apply per-frame scales
            if self.params.recon == "colmap" and self.params.scaling == "depth":
                indices = metadata["geometry_consistency"]["indices"]
                num_pairs, frames_per_pair = indices.shape[0], indices.shape[1]
                # Allocated on the device of the prediction and initialised to
                # the neutral scale so no entry is ever left undefined.
                scale = torch.ones(
                    num_pairs, frames_per_pair, 1, 1, device=depth.device
                )
                for pair in range(num_pairs):
                    for i in range(frames_per_pair):
                        frame = int(indices[pair][i])
                        ref_disp = self.load_reference_disparity(frame)
                        valid = ~np.logical_or(
                            np.isinf(ref_disp), np.isnan(ref_disp)
                        )
                        est_disp = 1.0 / depth[pair, i, :].detach().cpu()
                        pixel_scales = (est_disp / ref_disp)[valid]
                        # The median keeps the estimate robust to outliers.
                        image_scale = np.median(pixel_scales)
                        scale[pair, i] = float(image_scale)
                        print(f"Frame {frame}: scale = {image_scale}.")
                depth = depth * scale

            opt.zero_grad()

            # Retrieve original depth predictions for contrast loss computation.
            depth_orig = retrieve_depth_orig(metadata)
            _, h, w = depth_orig.shape
            # Reshape (x, h, w) to (b, n, h, w) to match depth.
            depth_orig = depth_orig.view(-1, 2, h, w)
            depth_orig = depth_orig.to(depth.device)

            # Loss computation.
            loss, loss_meta, _ = criterion(
                stacked_img,
                depth_orig,
                depth,
                metadata,
                parameters=self.model.parameters(),
            )

            pairs = metadata["geometry_consistency"]["indices"]
            pairs = pairs.cpu().numpy().tolist()
            print(f"Epoch = {epoch}, pairs = {pairs}, loss = {loss[0]}")

            if torch.isnan(loss):
                # Do not let a NaN loss corrupt the weights.
                print("Loss is NaN. Skipping.")
                continue

            loss.backward()
            opt.step()

            # Iterations are counted in samples, not in batches.
            total_iters += stacked_img.shape[0]
            print(f"total_iters: {total_iters}")

            if writer is not None and total_iters % self.params.print_freq == 0:
                log_loss(writer, "Train", loss, loss_meta, total_iters)

            if writer is not None and total_iters % self.params.display_freq == 0:
                write_summary(
                    writer, "Train", stacked_img, depth, metadata, total_iters
                )

            iter_duration = time.perf_counter() - iter_start_time
            print(f"Iteration took {iter_duration:.2f}s.")

        epoch_duration = time.perf_counter() - epoch_start_time
        print(f"Epoch {epoch} took {epoch_duration:.2f}s.")

        finished_epochs = epoch + 1
        save_stream_now = stream_freq > 0 and finished_epochs % stream_freq == 0

        # A frequency of 0 means "no periodic validation"; guarding with > 0
        # also avoids a modulo by zero.
        if val_freq > 0 and finished_epochs % val_freq == 0:
            validate(finished_epochs, total_iters)

        if (
            self.params.save_checkpoints
            and finished_epochs % self.params.save_epoch_freq == 0
        ):
            file_name = pjoin(self.checkpoints_dir, f"{finished_epochs:04d}.pth")
            self.model.save(file_name)

        if save_stream_now:
            self.save_depth(frames=self.frames)

        if is_i3d and finished_epochs % self.params.pose_opt_freq == 0:
            if stream_freq > 0:
                # Create new depth stream for optimized poses.
                epoch_str = f"e{epoch:04d}_opt"
                self.depth_dir = os.path.join(self.out_dir, f"depth_{epoch_str}")
                pose_optimizer.duplicate_last_depth_stream(
                    epoch_str, self.depth_dir
                )

            # Pose optimization with depth/spatial deformation
            pose_opt_start_time = time.perf_counter()
            pose_optimizer.optimize_poses()
            dataset.update_poses(pose_optimizer.depth_video)
            pose_opt_duration = time.perf_counter() - pose_opt_start_time
            print(f"Complete pose optimization in {pose_opt_duration:.2f}s")

            if save_stream_now:
                self.save_depth(frames=self.frames)

        if save_stream_now and finished_epochs < self.params.num_epochs:
            # Create depth stream for the next epoch.
            epoch_str = f"e{finished_epochs:04d}"
            self.depth_dir = os.path.join(self.out_dir, f"depth_{epoch_str}")
            pose_optimizer.duplicate_last_depth_stream(epoch_str, self.depth_dir)

    # Validate the last epoch, unless it was just done in the loop above (or,
    # when nothing was trained, by the initial validation).
    num_epochs = self.params.num_epochs
    if val_freq > 0:
        last_epoch_validated = num_epochs % val_freq == 0
    else:
        last_epoch_validated = num_epochs == 0
    if val_freq >= 0 and not last_epoch_validated:
        validate(epoch=num_epochs, niters=total_iters)

    if self.params.post_filter:
        pose_optimizer.filter_depth(self.params.filter_radius)
        print("Finished Filtering.")