def main()

in projects/scale_hyperprior_lightning/train.py

Hydra entry point that trains a ScaleHyperprior model on the Vimeo-90k septuplet dataset with PyTorch Lightning, refusing to clobber checkpoints from an earlier run unless explicitly told to resume or overwrite.


from pathlib import Path

import hydra
from omegaconf import DictConfig
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Project-local modules; the exact import paths are assumed here.
from model import ScaleHyperprior
from lightning_model import ScaleHyperpriorLightning
from data_module import Vimeo90kSeptupletLightning


# get_original_cwd() below only works when main() runs under @hydra.main;
# the config_path/config_name values are assumed.
@hydra.main(config_path="conf", config_name="base")
def main(cfg: DictConfig):

    # Hydra changes the working directory per run, so resolve the save
    # directory relative to the directory the job was launched from.
    save_dir: Path = Path(hydra.utils.get_original_cwd()) / cfg.save_dir

    # Refuse to train over existing checkpoints unless the user explicitly
    # opts in to resuming or overwriting.
    if (
        not cfg.overwrite
        and not cfg.resume_training
        and len(list(save_dir.glob("*.ckpt"))) > 0
    ):
        raise RuntimeError(
            "Checkpoints detected in save directory: set resume_training=True"
            " to restore trainer state from these checkpoints, or set overwrite=True"
            " to ignore them."
        )

    save_dir.mkdir(exist_ok=True, parents=True)
    # ModelCheckpoint writes last.ckpt when configured with save_last=True;
    # its presence is what allows resumption below.
    last_checkpoint = save_dir / "last.ckpt"

    # Wrap the compression model in the LightningModule that defines its
    # training and validation steps.
    model = ScaleHyperprior(**cfg.model)
    lightning_model = ScaleHyperpriorLightning(model, **cfg.training_loop)

    # Pin host memory for faster host-to-device transfers, but only when
    # training on at least one GPU.
    data = Vimeo90kSeptupletLightning(**cfg.data, pin_memory=cfg.ngpu != 0)

    # Each logger is instantiated from its own entry in the Hydra config.
    loggers = [hydra.utils.instantiate(logger_cfg) for logger_cfg in cfg.loggers]
    trainer = Trainer(
        **cfg.trainer,
        logger=loggers,
        callbacks=[
            LearningRateMonitor(),
            ModelCheckpoint(**cfg.save_model),
        ],
        # Restore full trainer state from last.ckpt only when it exists and
        # the user asked to resume.
        resume_from_checkpoint=last_checkpoint
        if last_checkpoint.exists() and cfg.resume_training
        else None,
    )

    trainer.fit(lightning_model, datamodule=data)


if __name__ == "__main__":
    main()
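
For reference, here is a minimal sketch of a config that satisfies this entry point, built with OmegaConf directly rather than Hydra's YAML files. The top-level keys (save_dir, overwrite, resume_training, ngpu, model, training_loop, data, loggers, trainer, save_model) are the ones main() actually reads; every nested key and value is an illustrative assumption, not the project's real defaults.

from omegaconf import OmegaConf

# Top-level keys mirror what main() reads; nested keys and values are
# illustrative assumptions, not the project's real defaults.
example_cfg = OmegaConf.create(
    {
        "save_dir": "checkpoints",
        "overwrite": False,
        "resume_training": False,
        "ngpu": 1,
        "model": {"network_channels": 128, "compression_channels": 192},
        "training_loop": {"distortion_lambda": 1e-2},
        "data": {"data_dir": "vimeo90k", "batch_size": 16, "num_workers": 4},
        "loggers": [
            {
                "_target_": "pytorch_lightning.loggers.TensorBoardLogger",
                "save_dir": "logs",
            }
        ],
        "trainer": {"max_epochs": 300, "gpus": 1},
        # save_last=True makes ModelCheckpoint write the last.ckpt that
        # resume_training looks for.
        "save_model": {"dirpath": "checkpoints", "save_last": True},
    }
)

Under Hydra, the same keys can be overridden from the command line with dotted syntax, e.g. python train.py resume_training=True trainer.max_epochs=500.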