def main()

in projects/nerf/train_nerf.py [0:0]


def main(cfg: DictConfig):
    """Train a NeRF ``RadianceFieldRenderer`` as described by a Hydra config.

    The function:

    * seeds numpy and torch;
    * builds the model and moves it to CUDA when available (CPU otherwise);
    * optionally resumes model/optimizer/stats from ``cfg.checkpoint_path``;
    * loads the train/val datasets (optionally pre-caching their rays);
    * runs Adam training up to ``cfg.optimizer.max_epochs`` with an
      exponentially decaying learning rate;
    * every ``cfg.validation_epoch_interval`` epochs renders one random
      validation view (optionally plotting to Visdom);
    * every ``cfg.checkpoint_epoch_interval`` epochs writes a checkpoint.

    Args:
        cfg: The experiment configuration. Fields read here include
            ``seed``, ``resume``, ``checkpoint_path``, ``stats_print_interval``,
            ``validation_epoch_interval``, ``checkpoint_epoch_interval`` and
            the ``data``, ``raysampler``, ``implicit_function``, ``optimizer``
            and ``visualization`` groups.
    """

    # Set the relevant seeds for reproducibility.
    np.random.seed(cfg.seed)
    torch.manual_seed(cfg.seed)

    # Device on which to run.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        warnings.warn(
            "Please note that although executing on CPU is supported, "
            "the training is unlikely to finish in reasonable time."
        )
        device = "cpu"

    # Use the dedicated fine-pass sample count when the config provides one;
    # otherwise fall back to the coarse-pass count.
    n_pts_per_ray_fine = (
        cfg.raysampler.n_pts_per_ray_fine
        if "n_pts_per_ray_fine" in cfg.raysampler
        else cfg.raysampler.n_pts_per_ray
    )

    # Initialize the Radiance Field model.
    model = RadianceFieldRenderer(
        image_size=cfg.data.image_size,
        n_pts_per_ray=cfg.raysampler.n_pts_per_ray,
        n_pts_per_ray_fine=n_pts_per_ray_fine,
        n_rays_per_image=cfg.raysampler.n_rays_per_image,
        min_depth=cfg.raysampler.min_depth,
        max_depth=cfg.raysampler.max_depth,
        stratified=cfg.raysampler.stratified,
        stratified_test=cfg.raysampler.stratified_test,
        chunk_size_test=cfg.raysampler.chunk_size_test,
        n_harmonic_functions_xyz=cfg.implicit_function.n_harmonic_functions_xyz,
        n_harmonic_functions_dir=cfg.implicit_function.n_harmonic_functions_dir,
        n_hidden_neurons_xyz=cfg.implicit_function.n_hidden_neurons_xyz,
        n_hidden_neurons_dir=cfg.implicit_function.n_hidden_neurons_dir,
        n_layers_xyz=cfg.implicit_function.n_layers_xyz,
        density_noise_std=cfg.implicit_function.density_noise_std,
        visualization=cfg.visualization.visdom,
    )

    # Move the model to the relevant device.
    model.to(device)

    # Init stats to None before loading.
    stats = None
    optimizer_state_dict = None
    start_epoch = 0

    # An empty ``cfg.checkpoint_path`` disables both resuming and saving.
    checkpoint_path = ""
    if cfg.checkpoint_path:
        # Resolve relative to the directory the script was launched from
        # (Hydra may change the working directory for each run).
        checkpoint_path = os.path.join(
            hydra.utils.get_original_cwd(), cfg.checkpoint_path
        )

        # Make the root of the experiment directory.
        checkpoint_dir = os.path.split(checkpoint_path)[0]
        os.makedirs(checkpoint_dir, exist_ok=True)

        # Resume training if requested.
        if cfg.resume and os.path.isfile(checkpoint_path):
            print(f"Resuming from checkpoint {checkpoint_path}.")
            # map_location lets a checkpoint written on a GPU machine be
            # resumed on the CPU-only fallback path selected above.
            loaded_data = torch.load(checkpoint_path, map_location=device)
            model.load_state_dict(loaded_data["model"])
            # NOTE: unpickling can execute arbitrary code; only resume from
            # checkpoint files you trust.
            stats = pickle.loads(loaded_data["stats"])
            print(f"   => resuming from epoch {stats.epoch}.")
            optimizer_state_dict = loaded_data["optimizer"]
            start_epoch = stats.epoch

    # Initialize the optimizer.
    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=cfg.optimizer.lr,
    )

    # Load the optimizer state dict in case we are resuming.
    if optimizer_state_dict is not None:
        optimizer.load_state_dict(optimizer_state_dict)
        optimizer.last_epoch = start_epoch

    # Init the stats object.
    if stats is None:
        stats = Stats(
            ["loss", "mse_coarse", "mse_fine", "psnr_coarse", "psnr_fine", "sec/it"],
        )

    # Learning rate scheduler setup.

    # Following the original code, we use exponential decay of the
    # learning rate: current_lr = base_lr * gamma ** (epoch / step_size)
    def lr_lambda(epoch):
        return cfg.optimizer.lr_scheduler_gamma ** (
            epoch / cfg.optimizer.lr_scheduler_step_size
        )

    # The learning rate scheduling is implemented with LambdaLR PyTorch scheduler.
    # The constructor performs one initial step, so passing
    # ``start_epoch - 1`` makes the first trained epoch use
    # ``lr_lambda(start_epoch)``. (The deprecated ``verbose`` argument is
    # omitted; its default was already False.)
    lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda, last_epoch=start_epoch - 1
    )

    # Initialize the cache for storing variables needed for visualization.
    visuals_cache = collections.deque(maxlen=cfg.visualization.history_size)

    # Init the visualization visdom env.
    if cfg.visualization.visdom:
        viz = Visdom(
            server=cfg.visualization.visdom_server,
            port=cfg.visualization.visdom_port,
            use_incoming_socket=False,
        )
    else:
        viz = None

    # Load the training/validation data.
    train_dataset, val_dataset, _ = get_nerf_datasets(
        dataset_name=cfg.data.dataset_name,
        image_size=cfg.data.image_size,
    )

    if cfg.data.precache_rays:
        # Precache the projection rays.
        model.eval()
        with torch.no_grad():
            for dataset in (train_dataset, val_dataset):
                cache_cameras = [e["camera"].to(device) for e in dataset]
                cache_camera_hashes = [e["camera_idx"] for e in dataset]
                model.precache_rays(cache_cameras, cache_camera_hashes)

    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=0,
        collate_fn=trivial_collate,
    )

    # The validation dataloader is just an endless stream of random samples.
    val_dataloader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=1,
        num_workers=0,
        collate_fn=trivial_collate,
        sampler=torch.utils.data.RandomSampler(
            val_dataset,
            replacement=True,
            num_samples=cfg.optimizer.max_epochs,
        ),
    )

    # Set the model to the training mode.
    model.train()

    # Run the main training loop.
    for epoch in range(start_epoch, cfg.optimizer.max_epochs):
        stats.new_epoch()  # Init a new epoch.
        for iteration, batch in enumerate(train_dataloader):
            # NOTE(review): unpacking ``.values()`` relies on the dataset's
            # dict key order (presumably image, camera, camera_idx) — confirm.
            image, camera, camera_idx = batch[0].values()
            image = image.to(device)
            camera = camera.to(device)

            optimizer.zero_grad()

            # Run the forward pass of the model.
            nerf_out, metrics = model(
                camera_idx if cfg.data.precache_rays else None,
                camera,
                image,
            )

            # The loss is a sum of coarse and fine MSEs
            loss = metrics["mse_coarse"] + metrics["mse_fine"]

            # Take the training step.
            loss.backward()
            optimizer.step()

            # Update stats with the current metrics.
            stats.update(
                {"loss": float(loss), **metrics},
                stat_set="train",
            )

            if iteration % cfg.stats_print_interval == 0:
                stats.print(stat_set="train")

            # Update the visualization cache.
            if viz is not None:
                visuals_cache.append(
                    {
                        "camera": camera.cpu(),
                        "camera_idx": camera_idx,
                        "image": image.cpu().detach(),
                        "rgb_fine": nerf_out["rgb_fine"].cpu().detach(),
                        "rgb_coarse": nerf_out["rgb_coarse"].cpu().detach(),
                        "rgb_gt": nerf_out["rgb_gt"].cpu().detach(),
                        "coarse_ray_bundle": nerf_out["coarse_ray_bundle"],
                    }
                )

        # Adjust the learning rate.
        lr_scheduler.step()

        # Validation
        if epoch % cfg.validation_epoch_interval == 0 and epoch > 0:

            # Sample a validation camera/image.
            val_batch = next(iter(val_dataloader))
            val_image, val_camera, val_camera_idx = val_batch[0].values()
            val_image = val_image.to(device)
            val_camera = val_camera.to(device)

            # Activate eval mode of the model (lets us do a full rendering pass).
            model.eval()
            with torch.no_grad():
                val_nerf_out, val_metrics = model(
                    val_camera_idx if cfg.data.precache_rays else None,
                    val_camera,
                    val_image,
                )

            # Update stats with the validation metrics.
            stats.update(val_metrics, stat_set="val")
            stats.print(stat_set="val")

            if viz is not None:
                # Plot that loss curves into visdom.
                stats.plot_stats(
                    viz=viz,
                    visdom_env=cfg.visualization.visdom_env,
                    plot_file=None,
                )
                # Visualize the intermediate results.
                visualize_nerf_outputs(
                    val_nerf_out, visuals_cache, viz, cfg.visualization.visdom_env
                )

            # Set the model back to train mode.
            model.train()

        # Checkpoint.
        if (
            epoch % cfg.checkpoint_epoch_interval == 0
            and checkpoint_path
            and epoch > 0
        ):
            print(f"Storing checkpoint {checkpoint_path}.")
            data_to_store = {
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "stats": pickle.dumps(stats),
            }
            # Write to a temporary file and then rename it over the target so
            # that an interruption mid-save cannot corrupt the previous
            # checkpoint.
            tmp_checkpoint_path = checkpoint_path + ".tmp"
            torch.save(data_to_store, tmp_checkpoint_path)
            os.replace(tmp_checkpoint_path, checkpoint_path)