# def main()
#
# in projects/deep_video_compression/train.py [0:0]


def main(cfg: DictConfig):
    """Train the DVC model through each configured training stage.

    Builds the model, Weights & Biases logger, and Vimeo-90k data module
    from the Hydra config, samples a fixed set of validation examples for
    image logging, then runs ``run_training_stage`` once per stage in
    ``cfg.training_stages`` (sorted key order), threading the returned
    model through successive stages.

    Args:
        cfg: Hydra/OmegaConf config. Reads ``cfg.model`` (DVC kwargs),
            ``cfg.data`` (datamodule kwargs), ``cfg.ngpu``,
            ``cfg.training_stages``, and ``cfg.logging``
            (``save_root``, ``image_seed``, ``num_log_images``).
    """
    root = Path(cfg.logging.save_root)  # if relative, uses Hydra outputs dir
    model = DVC(**cfg.model)
    logger = WandbLogger(
        save_dir=str(root.absolute()),
        project="DVC",
        config=OmegaConf.to_container(cfg),  # saves the Hydra config to wandb
    )
    data = Vimeo90kSeptupletLightning(
        frames_per_group=7,
        **cfg.data,
        # pin host memory only when training on GPU (ngpu != 0)
        pin_memory=cfg.ngpu != 0,
    )

    # set up image logging
    # Seeded RNG makes the logged-image selection reproducible across runs.
    rng = np.random.default_rng(cfg.logging.image_seed)
    # Manually trigger setup so val_dataset exists before trainer.fit.
    data.setup()
    val_dataset = data.val_dataset
    # Pick num_log_images random (but seed-fixed) validation indices.
    log_image_indices = rng.permutation(len(val_dataset))[: cfg.logging.num_log_images]
    # NOTE(review): assumes val_dataset[ind] yields tensors of identical
    # shape, as required by torch.stack — confirm against the dataset class.
    log_images = torch.stack([val_dataset[ind] for ind in log_image_indices])
    image_logger = WandbImageCallback(log_images)

    # run through each stage and optimize
    # NOTE(review): if stage keys are strings, sorted() orders them
    # lexicographically ("10" < "2") — verify key naming in the config.
    for stage in sorted(cfg.training_stages.keys()):
        # Each stage returns the (further-trained) model for the next stage.
        model = run_training_stage(stage, root, model, data, logger, image_logger, cfg)