def fine_tune()

in depth_fine_tuning.py [0:0]


    def fine_tune(self, writer=None):
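        """Fine-tune the depth model on this video and periodically validate.

        Assumes the caller has set up `self.model`, `self.params`, and the
        range/output directories; `writer` may be an existing TensorBoard
        SummaryWriter, otherwise one is created under `out_dir/tensorboard`.
        """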
        meta_file = pjoin(self.range_dir, "metadata_scaled.npz")

        dataset = VideoDataset(self.base_dir, meta_file)
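        # Both loaders below wrap this same dataset; only shuffling differs.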
        train_data_loader = DataLoader(
            dataset,
            batch_size=self.params.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=torch.cuda.is_available(),
        )
        val_data_loader = DataLoader(
            dataset,
            batch_size=self.params.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=torch.cuda.is_available(),
        )

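        # Let cuDNN auto-tune convolution kernels; benchmark mode is fastest
        # when input shapes stay constant across iterations.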
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

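        # JointLoss receives a clone of the initial weights, presumably to
        # penalize drift from the pre-trained parameters during fine-tuning.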
        criterion = JointLoss(
            self.params,
            parameters_init=[p.clone() for p in self.model.parameters()],
        )

        if writer is None:
            log_dir = pjoin(self.out_dir, "tensorboard")
            os.makedirs(log_dir, exist_ok=True)
            writer = SummaryWriter(log_dir=log_dir)

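        # Build the optimizer via the project's factory; the betas are the
        # usual Adam defaults and apply to Adam-style optimizers.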
        opt = optimizer.create(
            self.params.optimizer,
            self.model.parameters(),
            self.params.learning_rate,
            betas=(0.9, 0.999)
        )

        eval_dir = pjoin(self.out_dir, "eval")
        os.makedirs(eval_dir, exist_ok=True)

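        # Switch to training mode before the loop (affects dropout/batch norm).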
        self.model.train()

        def suffix(epoch, niters):
            return f"_e{epoch:04d}_iter{niters:06d}"

        def validate(epoch, niters):
            loss_meta = self.eval_and_save(
                criterion, val_data_loader, suffix(epoch, niters)
            )
            if writer is not None:
                log_loss_stats(
                    writer, "validation", loss_meta, epoch, log_histogram=True
                )
            print(f"Done Validation for epoch {epoch} ({niters} iterations)")

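        # Validate once before any fine-tuning to record a baseline.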
        self.vis_depth_scale = None
        validate(0, 0)

        # Training loop.
        total_iters = 0
        for epoch in range(self.params.num_epochs):
            epoch_start_time = time.perf_counter()

            for data in train_data_loader:
                data = to_device(data)
                stacked_img, metadata = data

                depth = self.model(stacked_img, metadata)

                opt.zero_grad()
                loss, loss_meta = criterion(
                    depth, metadata, parameters=self.model.parameters()
                )

                pairs = metadata['geometry_consistency']['indices']
                pairs = pairs.cpu().numpy().tolist()

                print(f"Epoch = {epoch}, pairs = {pairs}, loss = {loss[0]}")
                if torch.isnan(loss):
                    print("Loss is NaN. Skipping.")
                    continue

                loss.backward()
                opt.step()

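                # total_iters counts samples (the batch dimension), so the
                # logging frequencies below are expressed in samples.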
                total_iters += stacked_img.shape[0]

                if writer is not None and total_iters % self.params.print_freq == 0:
                    log_loss(writer, 'Train', loss, loss_meta, total_iters)

                if writer is not None and total_iters % self.params.display_freq == 0:
                    write_summary(
                        writer, 'Train', stacked_img, depth, metadata, total_iters
                    )

            epoch_end_time = time.perf_counter()
            epoch_duration = epoch_end_time - epoch_start_time
            print(f"Epoch {epoch} took {epoch_duration:.2f}s.")

            if (epoch + 1) % self.params.val_epoch_freq == 0:
                validate(epoch + 1, total_iters)

            if (epoch + 1) % self.params.save_epoch_freq == 0:
                file_name = pjoin(self.checkpoints_dir, f"{epoch + 1:04d}.pth")
                self.model.save(file_name)

        # Validate the last epoch, unless it was just done in the loop above.
        if self.params.num_epochs % self.params.val_epoch_freq != 0:
            validate(self.params.num_epochs, total_iters)

        print("Finished Training")