in projects/deep_video_compression/train.py [0:0]
def main(cfg: DictConfig):
    """Build the DVC model, data module and loggers, then run every training stage.

    Args:
        cfg: Hydra config. This function reads ``logging.save_root``,
            ``logging.image_seed``, ``logging.num_log_images``, ``model``,
            ``data``, ``ngpu`` and ``training_stages``.
    """
    root = Path(cfg.logging.save_root)  # if relative, uses Hydra outputs dir
    model = DVC(**cfg.model)
    logger = WandbLogger(
        save_dir=str(root.absolute()),
        project="DVC",
        # Saves the Hydra config to wandb. Interpolations are resolved so the
        # logged config holds the effective values, not literal "${...}" strings.
        config=OmegaConf.to_container(cfg, resolve=True),
    )
    data = Vimeo90kSeptupletLightning(
        frames_per_group=7,
        **cfg.data,
        # Only pin host memory when GPUs are requested.
        pin_memory=cfg.ngpu != 0,
    )

    # Set up image logging: choose a reproducible (seeded) subset of
    # validation samples to hand to the image-logging callback.
    rng = np.random.default_rng(cfg.logging.image_seed)
    data.setup()  # presumably populates data.val_dataset, which is read below
    val_dataset = data.val_dataset
    log_image_indices = rng.permutation(len(val_dataset))[: cfg.logging.num_log_images]
    # NOTE: torch.stack raises on an empty list, e.g. if num_log_images == 0.
    log_images = torch.stack([val_dataset[ind] for ind in log_image_indices])
    image_logger = WandbImageCallback(log_images)

    # Run through each stage in sorted key order. Each stage starts from the
    # model returned by the previous one.
    for stage in sorted(cfg.training_stages.keys()):
        model = run_training_stage(stage, root, model, data, logger, image_logger, cfg)