def run_training_stage()

in projects/deep_video_compression/train.py [0:0]


def run_training_stage(stage, root, model, data, logger, image_logger, cfg):
    """Run a single training stage based on the stage config.

    The stage's checkpoint directory is ``stage_cfg.save_dir`` when that is
    set, otherwise ``root / stage``. Training resumes from ``last.ckpt`` in
    that directory when the file exists and ``cfg.checkpoint.overwrite`` is
    falsy; otherwise training starts without restoring trainer state.

    Args:
        stage: Key into ``cfg.training_stages``; also the name of the default
            checkpoint sub-directory under ``root``.
        root: Base output directory. It is joined with ``stage`` using ``/``,
            so it is presumably a ``pathlib.Path`` (confirm against caller).
        model: Model wrapped by ``DvcModule`` and later passed to
            ``recompose_model``.
        data: Passed to ``trainer.fit`` as ``datamodule``.
        logger: Passed to ``pl.Trainer`` as ``logger``.
        image_logger: Callback added to the trainer next to
            ``LearningRateMonitor`` and ``ModelCheckpoint``.
        cfg: Config exposing ``training_stages``, ``checkpoint``, ``module``
            and ``trainer``.

    Returns:
        Whatever ``lightning_model.recompose_model(model)`` returns after
        fitting.

    Raises:
        RuntimeError: If the checkpoint directory already contains ``*.ckpt``
            files while both ``cfg.checkpoint.overwrite`` and
            ``cfg.checkpoint.resume_training`` are falsy.
    """
    print(f"training stage: {stage}")
    stage_cfg = cfg.training_stages[stage]
    if stage_cfg.save_dir is None:
        save_dir = root / stage
    else:
        save_dir = Path(stage_cfg.save_dir)

    # Read the flag once and use plain truthiness everywhere, so the guard
    # below and the resume decision further down can never disagree (the
    # previous ``is True`` test ignored truthy non-bool values such as ``1``).
    overwrite = bool(cfg.checkpoint.overwrite)

    # Refuse to silently mix new checkpoints with stale ones. ``any`` stops at
    # the first match instead of listing the whole directory.
    if (
        not overwrite
        and not cfg.checkpoint.resume_training
        and any(save_dir.glob("*.ckpt"))
    ):
        raise RuntimeError(
            "Checkpoints detected in save directory: set resume_training=True "
            "to restore trainer state from these checkpoints, "
            "or set overwrite=True to ignore them."
        )

    save_dir.mkdir(exist_ok=True, parents=True)

    # ``None`` tells the trainer to start without restoring any state.
    last_checkpoint = save_dir / "last.ckpt"
    if overwrite or not last_checkpoint.exists():
        last_checkpoint = None

    lightning_model = DvcModule(model, **merge_configs(cfg.module, stage_cfg.module))

    # NOTE(review): ``resume_from_checkpoint`` is deprecated in newer PyTorch
    # Lightning releases in favour of ``trainer.fit(..., ckpt_path=...)`` --
    # confirm against the pinned version before migrating.
    trainer = pl.Trainer(
        **merge_configs(cfg.trainer, stage_cfg.trainer),
        logger=logger,
        callbacks=[
            LearningRateMonitor(),
            ModelCheckpoint(dirpath=save_dir, **cfg.checkpoint.model_checkpoint),
            image_logger,
        ],
        resume_from_checkpoint=last_checkpoint,
    )

    trainer.fit(lightning_model, datamodule=data)

    return lightning_model.recompose_model(model)