def main()

in train_gsn.py [0:0]


def main(opt):
    """Build the data module, model and trainer from ``opt``, then fit or validate.

    Args:
        opt: Configuration object. The attributes read here are
            ``data_config`` (``batch_size``, ``img_res``), ``model_config``
            (``params.trajectory_mode``, ``params.voxel_res``,
            ``params.voxel_size``), ``resume_from_path``, ``log_dir``,
            ``eval_freq``, ``precision``, ``n_epochs`` and ``evaluate``.

    Side effects:
        Mutates ``opt`` in place (``data_config.samples_per_epoch`` and
        ``model_config.params.img_res`` are overwritten) and attaches ``opt``
        to the model as ``gsn.opt``.
    """
    n_gpus = torch.cuda.device_count()

    # configure dataset so that each epoch has 1k iterations; clamp the device
    # multiplier to 1 so a host without CUDA devices does not get an empty epoch
    opt.data_config.samples_per_epoch = opt.data_config.batch_size * max(n_gpus, 1) * 1000
    data_module = build_dataloader(opt.data_config)

    # build model
    opt.model_config.params.img_res = opt.data_config.img_res
    gsn = instantiate_from_config(opt.model_config)
    # add config to the model so it can be saved during checkpointing
    gsn.opt = opt

    # get real camera trajectories from dataset to sample during training
    real_Rts = data_module.train_loader.dataset.get_trajectory_Rt()
    trajectory_sampler = TrajectorySampler(real_Rts=real_Rts, mode=opt.model_config.params.trajectory_mode)
    gsn.set_trajectory_sampler(trajectory_sampler=trajectory_sampler)

    if opt.resume_from_path:
        # Load onto the CPU: without map_location the tensors are restored to
        # the device they were saved from, so every DDP process would allocate
        # the checkpoint on that one GPU (and loading fails if it is absent).
        # load_state_dict copies the values into the model's own parameters.
        # NOTE: torch.load unpickles the file; only resume from trusted checkpoints.
        checkpoint = torch.load(opt.resume_from_path, map_location='cpu')['state_dict']

        # get rid of all the inception params which are leftover from FID metric
        # (collect first, then delete, so the mapping is not mutated while iterating)
        keys_for_deletion = [key for key in checkpoint if 'fid' in key]
        for key in keys_for_deletion:
            del checkpoint[key]

        gsn.load_state_dict(checkpoint, strict=True)
        print('Resuming from checkpoint at {}'.format(opt.resume_from_path))

    # keep the last checkpoint plus the single best one by (lowest) FID
    checkpoint_callback = ModelCheckpoint(
        monitor='metrics/fid',
        save_last=True,
        dirpath=os.path.join(opt.log_dir, 'checkpoints'),
        filename='gsn-model-best-fid',
        save_top_k=1,
        mode='min',
    )

    voxel_res = opt.model_config.params.voxel_res
    voxel_size = opt.model_config.params.voxel_size
    viz_callback = GSNVizCallback(opt.log_dir, voxel_res=voxel_res, voxel_size=voxel_size)

    callback_list = [viz_callback, checkpoint_callback]

    logger = TensorBoardLogger(os.path.join(opt.log_dir, 'logs'), name="gsn")

    trainer = pl.Trainer(
        gpus=n_gpus,
        callbacks=callback_list,
        accelerator='ddp',
        num_sanity_val_steps=0,
        check_val_every_n_epoch=opt.eval_freq,
        logger=logger,
        precision=opt.precision,
        max_epochs=opt.n_epochs,
        progress_bar_refresh_rate=1,
    )

    if opt.evaluate:
        trainer.validate(
            gsn,
            val_dataloaders=data_module.val_dataloader(),
        )
    else:
        trainer.fit(
            gsn,
            train_dataloaders=data_module.train_dataloader(),
            val_dataloaders=data_module.val_dataloader(),
        )