def main()

in dataflux_pytorch/benchmark/checkpointing/multinode/train.py


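# Imports this function relies on. The standard-library and torch imports are
# exact; the Trainer and WikiText2 paths assume the lightning>=2.0 package
# layout. parse_args, validate, init_processes, get_strategy, DemoTransformer,
# copy_bucket_to_local, and print_times are repo-local helpers defined
# elsewhere in train.py.
import os
import statistics
import time

import torch
from torch.utils.data import DataLoader
from lightning.pytorch import Trainer
from lightning.pytorch.demos import WikiText2
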
def main(ckpt_dir_path: str, ckpt_restore_path: str = ""):
    args = parse_args()
    validate(args)

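    # COORDINATOR_ADDRESS is set by the multi-node launcher; when present,
    # initialize the torch.distributed process group before building the model.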
    if os.environ.get("COORDINATOR_ADDRESS"):
        init_processes()
    torch.cuda.empty_cache()
    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=1)

    model = DemoTransformer(vocab_size=dataset.vocab_size,
                            nlayers=int(os.environ.get("NUM_LAYERS", 10)))
    strategy = get_strategy(args, os.getenv("PROJECT"))
    num_save_calls = int(os.environ.get("NUM_SAVE_CALLS", 3))
    num_nodes = int(os.environ.get("NUM_NODES", 1))

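    # Single-step Trainer: max_steps=1 runs just enough training to populate
    # real model and optimizer state for the checkpoint benchmark.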
    trainer = Trainer(
        enable_checkpointing=False,
        logger=False,
        default_root_dir=ckpt_dir_path,
        plugins=[],
        min_epochs=1,
        max_epochs=1,
        max_steps=1,
        accelerator="gpu",
        strategy=strategy,
        devices=os.environ.get("NUM_DEVICES", 'auto'),
        num_nodes=num_nodes,
    )
    trainer.fit(model, dataloader)
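    # Save benchmark: time each save_checkpoint call. With a distributed
    # strategy, each ckpt_{i}.ckpt/ path (note the trailing slash) is written
    # as a directory rather than a single file.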
    print(f"Saving checkpoint to {ckpt_dir_path} {num_save_calls} times.")
    save_checkpoint_times = []
    for i in range(num_save_calls):
        start = time.time()
        trainer.save_checkpoint(
            os.path.join(ckpt_dir_path, f'checkpoints/ckpt_{i}.ckpt/'))
        end = time.time()
        if torch.distributed.get_rank() == 0:
            print(
                f"Saved checkpoint to {ckpt_dir_path} in {end - start} seconds."
            )
        save_checkpoint_times.append(end - start)
    if torch.distributed.get_rank() == 0:
        print(f"Saved checkpoint to {ckpt_dir_path} {num_save_calls} times.")
    num_load_calls = int(os.environ.get("NUM_LOAD_CALLS", 3))
    load_checkpoint_times = []
    if args.save_only:
        print("Skipping loads because you set --save_only")
        num_load_calls = 0
        load_checkpoint_times = [0]
    elif args.load_only:
        print(f"Copying contents of {ckpt_dir_path} to {ckpt_restore_path}")
        copy_bucket_to_local(ckpt_dir_path.removeprefix("gs://"),
                             os.path.dirname(ckpt_restore_path))
        save_checkpoint_times = [0]
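    # Rebuild the model, strategy, and trainer from scratch on every iteration
    # so each load is measured from a cold start.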
    for i in range(num_load_calls):
        del trainer, strategy, model
        model = DemoTransformer(vocab_size=dataset.vocab_size,
                                nlayers=int(os.environ.get("NUM_LAYERS", 10)))
        new_ckpt_dir_path = os.path.join(ckpt_restore_path, f'ckpt_{i}.ckpt/')
        strategy = get_strategy(args, os.getenv("PROJECT"))
        trainer = Trainer(
            enable_checkpointing=False,
            logger=False,
            default_root_dir=ckpt_dir_path,
            plugins=[],
            min_epochs=1,
            max_epochs=1,
            max_steps=1,
            accelerator="gpu",
            strategy=strategy,
            devices=os.environ.get("NUM_DEVICES", 'auto'),
            num_nodes=num_nodes,
        )
        trainer.fit(model, dataloader)
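        # fit() materializes the distributed model so the strategy has state
        # to load into; only the load_checkpoint call itself is timed.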
        start = time.time()
        trainer.strategy.load_checkpoint(new_ckpt_dir_path)
        end = time.time()

        if torch.distributed.get_rank() == 0:
            print(
                f"Loaded checkpoint from {new_ckpt_dir_path} in {end - start} seconds."
            )
        load_checkpoint_times.append(end - start)

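    # Rank 0 reports the mean save/load latency across all timed calls.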
    if torch.distributed.get_rank() == 0:
        avg_load_time = statistics.mean(load_checkpoint_times)
        avg_save_time = statistics.mean(save_checkpoint_times)
        print_times(args, avg_save_time, avg_load_time)
        if not args.load_only:
            print(f"All save times: {save_checkpoint_times}")
        if not args.save_only:
            print(f"All load times: {load_checkpoint_times}")