# in depth_fine_tuning.py [0:0]
def fine_tune(self, writer=None):
    """Fine-tune the depth model on this video's frames.

    Builds train/validation DataLoaders over the same VideoDataset, runs
    the optimization loop against JointLoss, validates and checkpoints at
    the configured epoch frequencies, and logs scalars/images to
    TensorBoard.

    Args:
        writer: Optional SummaryWriter. If None, a new one is created
            under ``<out_dir>/tensorboard``.
    """
    meta_file = pjoin(self.range_dir, "metadata_scaled.npz")

    dataset = VideoDataset(self.base_dir, meta_file)
    # Both loaders iterate the SAME dataset; only shuffling differs, so
    # "validation" here measures loss on the training frames (this is a
    # per-video fine-tuning setup, not a held-out split).
    train_data_loader = DataLoader(
        dataset,
        batch_size=self.params.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
    )
    val_data_loader = DataLoader(
        dataset,
        batch_size=self.params.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=torch.cuda.is_available(),
    )

    torch.backends.cudnn.enabled = True
    # benchmark=True autotunes conv kernels; beneficial when input sizes
    # are fixed across iterations — presumably true for video frames here.
    torch.backends.cudnn.benchmark = True

    # Snapshot the pretrained parameters so JointLoss can compare the
    # current weights against their initial values (presumably a
    # regularization term — verify against JointLoss implementation).
    criterion = JointLoss(self.params,
        parameters_init=[p.clone() for p in self.model.parameters()])

    if writer is None:
        log_dir = pjoin(self.out_dir, "tensorboard")
        os.makedirs(log_dir, exist_ok=True)
        writer = SummaryWriter(log_dir=log_dir)

    opt = optimizer.create(
        self.params.optimizer,
        self.model.parameters(),
        self.params.learning_rate,
        betas=(0.9, 0.999)
    )

    eval_dir = pjoin(self.out_dir, "eval")
    os.makedirs(eval_dir, exist_ok=True)

    self.model.train()

    def suffix(epoch, niters):
        # Filename suffix tagging eval artifacts with epoch + iteration.
        return "_e{:04d}_iter{:06d}".format(epoch, niters)

    def validate(epoch, niters):
        # Evaluate over the (unshuffled) loader and persist results.
        loss_meta = self.eval_and_save(
            criterion, val_data_loader, suffix(epoch, niters)
        )
        # writer is always non-None here (created above if absent); the
        # guard is kept for safety.
        if writer is not None:
            log_loss_stats(
                writer, "validation", loss_meta, epoch, log_histogram=True
            )
        print(f"Done Validation for epoch {epoch} ({niters} iterations)")

    # Reset so eval_and_save re-derives the depth visualization scale
    # on the first validation pass — TODO confirm against eval_and_save.
    self.vis_depth_scale = None

    # Baseline validation before any fine-tuning happens.
    validate(0, 0)

    # Training loop.
    total_iters = 0
    for epoch in range(self.params.num_epochs):
        epoch_start_time = time.perf_counter()

        for data in train_data_loader:
            data = to_device(data)
            stacked_img, metadata = data

            depth = self.model(stacked_img, metadata)

            opt.zero_grad()
            loss, loss_meta = criterion(
                depth, metadata, parameters=self.model.parameters())

            pairs = metadata['geometry_consistency']['indices']
            pairs = pairs.cpu().numpy().tolist()
            # NOTE(review): loss is indexed here (loss[0]) but treated as
            # a scalar by torch.isnan below — confirm the shape JointLoss
            # returns; 0-dim tensors would fail this indexing.
            print(f"Epoch = {epoch}, pairs = {pairs}, loss = {loss[0]}")

            # Skip the update on NaN loss rather than corrupting weights
            # with NaN gradients. Safe because backward() has not run yet.
            if torch.isnan(loss):
                print("Loss is NaN. Skipping.")
                continue

            loss.backward()
            opt.step()

            # total_iters counts SAMPLES (batch dimension), not batches,
            # so print_freq/display_freq are in samples and only trigger
            # when the batch size divides them — TODO confirm intended.
            total_iters += stacked_img.shape[0]

            if writer is not None and total_iters % self.params.print_freq == 0:
                log_loss(writer, 'Train', loss, loss_meta, total_iters)

            if writer is not None and total_iters % self.params.display_freq == 0:
                write_summary(
                    writer, 'Train', stacked_img, depth, metadata, total_iters
                )

        epoch_end_time = time.perf_counter()
        epoch_duration = epoch_end_time - epoch_start_time
        print(f"Epoch {epoch} took {epoch_duration:.2f}s.")

        if (epoch + 1) % self.params.val_epoch_freq == 0:
            validate(epoch + 1, total_iters)

        if (epoch + 1) % self.params.save_epoch_freq == 0:
            file_name = pjoin(self.checkpoints_dir, f"{epoch + 1:04d}.pth")
            self.model.save(file_name)

    # Validate the last epoch, unless it was just done in the loop above.
    if self.params.num_epochs % self.params.val_epoch_freq != 0:
        validate(self.params.num_epochs, total_iters)

    print("Finished Training")