def fine_tune()

in depth_fine_tuning.py


    def fine_tune(self, writer=None):
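        """Fine-tune the depth model on this video's frames.

        Runs the training loop with geometry-consistency (and optional
        temporal-smoothness) losses, optionally interleaving pose
        optimization, checkpointing, intermediate depth streams, and
        validation. If TensorBoard logging is enabled and `writer` is None,
        a SummaryWriter is created internally.
        """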
        meta_file = None

        if self.params.recon == "colmap":
            if self.params.scaling == "extrinsics":
                meta_file = pjoin(self.range_dir, "metadata_scaled.npz")
            else:
                meta_file = pjoin(self.base_dir, "colmap_dense", "metadata.npz")

        print("Starting depth fine-tuning...")

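        # Any positive smoothness weight enables the temporal smoothness
        # terms in the dataset and loss.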
        use_temporal_smooth_loss = (
            self.params.lambda_smooth_disparity > 0
            or self.params.lambda_smooth_reprojection > 0
            or self.params.lambda_smooth_depth_ratio > 0
        )

        dataset = VideoDataset(
            self.base_dir,
            self.frames,
            self.params.min_mask_ratio,
            use_temporal_smooth_loss,
            meta_file,
            self.params.recon,
        )
        train_data_loader = DataLoader(
            dataset,
            batch_size=self.params.batch_size,
            shuffle=True,
            num_workers=4,
            pin_memory=torch.cuda.is_available(),
        )
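        # The validation loader shares the same dataset but disables
        # shuffling so evaluation order is deterministic.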
        val_data_loader = DataLoader(
            dataset,
            batch_size=self.params.batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=torch.cuda.is_available(),
        )

        # Even if we're using the COLMAP pipeline, we're initializing the pose
        # optimizer here, because it will create a depth video container for us.
        pose_optimizer = PoseOptimizer(
            self.base_dir, self.params.model_type, self.frames, self.params.opt
        )

        if self.params.recon == "i3d":
            pose_optimizer.optimize_poses()

        if self.params.save_intermediate_depth_streams_freq > 0:
            self.depth_dir = os.path.join(self.out_dir, "depth_e0000")
            pose_optimizer.duplicate_last_depth_stream("e0000", self.depth_dir)
        else:
            self.depth_dir = self.out_dir
            pose_optimizer.duplicate_last_depth_stream("fine_tuned", self.depth_dir)

        if self.params.recon == "i3d":
            dataset.update_poses(pose_optimizer.depth_video)

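        # Enable cuDNN autotuning: with fixed-size inputs, benchmark mode
        # selects the fastest convolution algorithms.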
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = True

        # If using the PCA model, only enable back-propagation for the PCA-related parameters.
        if self.params.model_type == "midas2_pca":
            for name, param in self.model.named_parameters():
                if name == "model.scale_params" or name == "model.shift_params":
                    param.requires_grad = True
                else:
                    param.requires_grad = False

        # Pass a snapshot of the initial parameters to the loss, e.g. so the
        # tunable PCA parameters can be regularized toward their initial values.
        criterion = JointLoss(
            self.params, parameters_init=[p.clone() for p in self.model.parameters()]
        )

        if self.params.save_tensorboard and writer is None:
            if self.params.tensorboard_log_path:
                log_dir = self.params.tensorboard_log_path
            else:
                log_dir = pjoin(self.out_dir, "tensorboard")

            # Print the prompt to view the tensorboard.
            print(get_tensorboard_prompt(log_dir))
            os.makedirs(log_dir, exist_ok=True)
            writer = SummaryWriter(log_dir=log_dir)

        # Only include tunable PCA parameters in the optimizer if specified.
        if self.params.model_type == "midas2_pca":
            opt = optimizer.create(
                self.params.optimizer,
                filter(lambda p: p.requires_grad, self.model.parameters()),
                self.params.learning_rate,
                betas=(0.9, 0.999),
            )
        else:
            opt = optimizer.create(
                self.params.optimizer,
                self.model.parameters(),
                self.params.learning_rate,
                betas=(0.9, 0.999),
            )

        self.model.train()

        def validate(epoch, niters):
            val_start_time = time.perf_counter()

            loss_meta = self.eval_and_save(criterion, val_data_loader, epoch, niters)
            if writer is not None:
                log_loss_stats(
                    writer, "validation", loss_meta, epoch, log_histogram=True
                )

            val_end_time = time.perf_counter()
            val_duration = val_end_time - val_start_time

            print(
                f"Completed validation for epoch {epoch} ({niters} iterations) in {val_duration:.2f}s."
            )

        if self.params.val_epoch_freq >= 0:
            validate(epoch=0, niters=0)

        # Disable inplace relu for batch-wise PCA modulation
        def disable_relu_inplace(model) -> None:
            for child_name, child in model.named_children():
                if isinstance(child, torch.nn.ReLU):
                    setattr(model, child_name, torch.nn.ReLU(inplace=False))
                else:
                    disable_relu_inplace(child)

        if self.params.model_type == "midas2_pca":
            # In-place ReLU would overwrite activations that the batch-wise
            # PCA modulation needs during the backward pass.
            disable_relu_inplace(self.model)

        # Retrieve the initially computed depth predictions (stored as
        # "frame_{:06d}.raw" files) for loss computation.
        initial_depth_dir = pjoin(self.base_dir, f"depth_{self.params.model_type}", "depth")

        depth_names = [
            n for n in os.listdir(initial_depth_dir) if os.path.splitext(n)[-1] == ".raw"
        ]
        depth_names = sorted(depth_names)

        all_depth_orig = {}
        for depth_name in depth_names:
            depth_path = pjoin(initial_depth_dir, depth_name)
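            # The stored predictions are inverse depth (disparity); invert to
            # obtain depth.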
            depth_orig = 1.0 / image_io.load_raw_float32_image(depth_path)
            all_depth_orig[depth_name] = torch.from_numpy(depth_orig)

        def retrieve_depth_orig(metadata) -> torch.Tensor:
            """
            Retrieve the corresponding original depths for loss computation.
            """
            indices = metadata["geometry_consistency"]["indices"]
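            # `indices` has shape (num_pairs, 2); flatten it into one list of
            # frame indices.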
            indices_list = indices.cpu().numpy().tolist()
            indices_list = list(itertools.chain(*indices_list))

            depth_orig = []
            for idx in indices_list:
                depth_orig.append(all_depth_orig[f"frame_{idx:06d}.raw"])

            self.depth_orig = torch.stack(depth_orig)

            return self.depth_orig

        # Training loop.
        total_iters = 0
        for epoch in range(self.params.num_epochs):
            epoch_start_time = time.perf_counter()

            for data in train_data_loader:

                if self.params.model_type == "midas2_pca":
                    print(f"'scale_params': {self.model.model.scale_params}")
                    print(f"'shift_params': {self.model.model.shift_params}")

                iter_start_time = time.perf_counter()

                data = to_device(data)
                stacked_img, metadata = data
                print(f"Size of stacked_img: {stacked_img.shape}")

                print(f"Current batch_size: {self.params.batch_size}")
                depth = self.model(stacked_img, metadata)

                # Apply per-frame scales: align each frame's estimated
                # disparity to the COLMAP reference via the median per-pixel
                # disparity ratio.
                if self.params.recon == "colmap" and self.params.scaling == "depth":
                    indices = metadata["geometry_consistency"]["indices"]
                    scale = torch.empty(
                        indices.shape[0], indices.shape[1], 1, 1, device=depth.device
                    )
                    for pair in range(indices.shape[0]):
                        for i in range(2):
                            frame = int(indices[pair][i])
                            ref_disp = self.load_reference_disparity(frame)
                            valid = ~np.logical_or(
                                np.isinf(ref_disp), np.isnan(ref_disp)
                            )
                            est_disp = 1.0 / depth[pair, i, :].detach().cpu()
                            pixel_scales = (est_disp / ref_disp)[valid]
                            image_scale = np.median(pixel_scales)
                            scale[pair, i] = float(image_scale)
                            print(f"Frame {frame}: scale = {image_scale}.")
                    depth = depth * scale

                opt.zero_grad()

                # Retrieve original depth predictions for contrast loss computation.
                depth_orig = retrieve_depth_orig(metadata)
                _, h, w = depth_orig.shape
                # Reshape (2 * b, h, w) to (b, 2, h, w) to match depth.
                depth_orig = depth_orig.view(-1, 2, h, w)
                depth_orig = depth_orig.to(depth.device)

                # Loss computation.
                loss, loss_meta, _ = criterion(
                    stacked_img,
                    depth_orig,
                    depth,
                    metadata,
                    parameters=self.model.parameters(),
                )

                pairs = metadata["geometry_consistency"]["indices"]
                pairs = pairs.cpu().numpy().tolist()

                print(f"Epoch = {epoch}, pairs = {pairs}, loss = {loss.item()}")
                if torch.isnan(loss):
                    print("Loss is NaN. Skipping.")
                    continue

                loss.backward()
                opt.step()

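                # total_iters counts samples (batch elements), not batches.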
                total_iters += stacked_img.shape[0]

                print(f"total_iters: {total_iters}")

                if writer is not None and total_iters % self.params.print_freq == 0:
                    log_loss(writer, "Train", loss, loss_meta, total_iters)

                if writer is not None and total_iters % self.params.display_freq == 0:
                    write_summary(
                        writer, "Train", stacked_img, depth, metadata, total_iters
                    )

                iter_end_time = time.perf_counter()
                iter_duration = iter_end_time - iter_start_time
                print(f"Iteration took {iter_duration:.2f}s.")

            epoch_end_time = time.perf_counter()
            epoch_duration = epoch_end_time - epoch_start_time
            print(f"Epoch {epoch} took {epoch_duration:.2f}s.")

            if (
                self.params.val_epoch_freq > 0
                and (epoch + 1) % self.params.val_epoch_freq == 0
            ):
                validate(epoch + 1, total_iters)

            if (
                self.params.save_checkpoints
                and (epoch + 1) % self.params.save_epoch_freq == 0
            ):
                file_name = pjoin(self.checkpoints_dir, f"{epoch + 1:04d}.pth")
                self.model.save(file_name)

            if (
                self.params.save_intermediate_depth_streams_freq > 0
                and (epoch + 1) % self.params.save_intermediate_depth_streams_freq == 0
            ):
                self.save_depth(frames=self.frames)

            if (
                self.params.recon == "i3d"
                and (epoch + 1) % self.params.pose_opt_freq == 0
            ):
                if self.params.save_intermediate_depth_streams_freq > 0:
                    # Create new depth stream for optimized poses.
                    epoch_str = f"e{epoch:04d}_opt"
                    self.depth_dir = os.path.join(self.out_dir, f"depth_{epoch_str}")
                    pose_optimizer.duplicate_last_depth_stream(
                        epoch_str, self.depth_dir
                    )

                # Pose optimization with depth/spatial deformation
                pose_opt_start_time = time.perf_counter()

                pose_optimizer.optimize_poses()
                dataset.update_poses(pose_optimizer.depth_video)

                pose_opt_end_time = time.perf_counter()
                pose_opt_duration = pose_opt_end_time - pose_opt_start_time

                print(f"Completed pose optimization in {pose_opt_duration:.2f}s.")

                if (
                    self.params.save_intermediate_depth_streams_freq > 0
                    and (epoch + 1) % self.params.save_intermediate_depth_streams_freq
                    == 0
                ):
                    self.save_depth(frames=self.frames)

            if (
                self.params.save_intermediate_depth_streams_freq > 0
                and (epoch + 1) % self.params.save_intermediate_depth_streams_freq == 0
                and epoch + 1 < self.params.num_epochs
            ):
                # Create depth stream for the next epoch.
                epoch_str = f"e{epoch + 1:04d}"
                self.depth_dir = os.path.join(self.out_dir, f"depth_{epoch_str}")
                pose_optimizer.duplicate_last_depth_stream(epoch_str, self.depth_dir)

        # Validate the last epoch, unless it was just done in the loop above.
        if (
            self.params.val_epoch_freq > 0
            and self.params.num_epochs % self.params.val_epoch_freq != 0
        ):
            validate(epoch=self.params.num_epochs, niters=total_iters)

        if self.params.post_filter:
            pose_optimizer.filter_depth(self.params.filter_radius)
            print("Finished filtering.")
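

A minimal sketch of how this method might be driven, assuming the enclosing
class is something like `DepthFineTuner` and that `params` carries the fields
referenced above (the class name, constructor signature, and params helper are
assumptions, not confirmed by this listing; `save_depth` is the method the
training loop itself calls):

    # Hypothetical driver; every name below is illustrative.
    params = parse_fine_tuning_params()    # assumed params helper
    fine_tuner = DepthFineTuner(base_dir, frames, params)
    fine_tuner.fine_tune(writer=None)      # builds its own SummaryWriter if enabled
    fine_tuner.save_depth(frames=fine_tuner.frames)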