in projects/deep_video_compression/train.py [0:0]
def run_training_stage(stage, root, model, data, logger, image_logger, cfg):
    """Run a single training stage based on the stage config.

    Args:
        stage: Key into ``cfg.training_stages`` selecting the stage to run.
            Also used as the sub-directory name under ``root`` when the
            stage config leaves ``save_dir`` unset.
        root: Base output directory; must support ``root / stage``.
        model: Model wrapped in a ``DvcModule`` for training and passed to
            ``recompose_model`` once fitting finishes.
        data: Datamodule handed to ``trainer.fit``.
        logger: Logger handed to ``pl.Trainer``.
        image_logger: Callback appended to the trainer's callback list.
        cfg: Global config. This function reads ``training_stages``,
            ``checkpoint``, ``module`` and ``trainer`` from it.

    Returns:
        Whatever ``lightning_model.recompose_model(model)`` returns.

    Raises:
        RuntimeError: If the save directory already holds ``*.ckpt`` files
            and neither ``cfg.checkpoint.overwrite`` nor
            ``cfg.checkpoint.resume_training`` is set.
    """
    print(f"training stage: {stage}")
    stage_cfg = cfg.training_stages[stage]
    checkpoint_cfg = cfg.checkpoint

    if stage_cfg.save_dir is None:
        save_dir = root / stage
    else:
        save_dir = Path(stage_cfg.save_dir)

    # Refuse to start when earlier checkpoints are present and the caller has
    # not said what to do with them. any() stops at the first match instead of
    # materialising the whole directory listing.
    if (
        not checkpoint_cfg.overwrite
        and not checkpoint_cfg.resume_training
        and any(save_dir.glob("*.ckpt"))
    ):
        raise RuntimeError(
            "Checkpoints detected in save directory: set resume_training=True "
            "to restore trainer state from these checkpoints, "
            "or set overwrite=True to ignore them."
        )

    save_dir.mkdir(exist_ok=True, parents=True)

    # Only resume when a last checkpoint exists and overwriting was not asked
    # for. The overwrite flag is tested by truthiness, exactly as in the guard
    # above, so a truthy non-bool value cannot pass the guard and then still
    # resume from the old checkpoint.
    last_checkpoint = save_dir / "last.ckpt"
    if checkpoint_cfg.overwrite or not last_checkpoint.exists():
        last_checkpoint = None

    # Stage-level settings are combined with the global ones by merge_configs.
    lightning_model = DvcModule(model, **merge_configs(cfg.module, stage_cfg.module))
    trainer = pl.Trainer(
        **merge_configs(cfg.trainer, stage_cfg.trainer),
        logger=logger,
        callbacks=[
            LearningRateMonitor(),
            ModelCheckpoint(dirpath=save_dir, **checkpoint_cfg.model_checkpoint),
            image_logger,
        ],
        # NOTE(review): newer PyTorch Lightning releases replace this argument
        # with trainer.fit(..., ckpt_path=...); confirm against the pinned
        # version before upgrading.
        resume_from_checkpoint=last_checkpoint,
    )
    trainer.fit(lightning_model, datamodule=data)
    return lightning_model.recompose_model(model)