def main()

in dataflux_pytorch/benchmark/checkpointing/singlenode/train.py


def main():
    """Checkpoints a PyTorch Ligthning demo model to GCS using gcsfs or DatafluxLightningCheckpoint.

    This function uses PyTorch Lightning to train and checkpoint a demo
    Transformer model on the WikiText2 dataset. It takes in the GCS location
    to save checkpoints to, the checkpoint strategy, and other configuration
    variables. Depending on the --checkpoint flag, checkpoints are written
    either through gcsfs with PyTorch Lightning's built-in TorchCheckpointIO,
    or with DatafluxLightningCheckpoint (synchronous or asynchronous), an
    implementation of the CheckpointIO interface that is passed to the
    Trainer as a plugin.

    Typical usage example:

      Run DatafluxLightningCheckpoint over 10 steps:

      python3 train.py --project=my-project --ckpt_dir_path=gs://bucket-name/path/to/dir/ --save_only_latest --layers=1000 --steps=10

      Run gcsfs over 10 steps:

      python3 train.py --project=my-project --ckpt_dir_path=gs://bucket-name/path/to/dir/ --layers=1000 --steps=10 --checkpoint=no-dataflux

    """
    args = parse_args()
    if args.steps < 1:
        raise ValueError("Steps need to greater than 0.")

    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=1)
    model = LightningTransformer(vocab_size=dataset.vocab_size,
                                 nlayers=args.layers)

    # Checkpoint strategy selection.
    if args.checkpoint == DF_LIGHTNING:
        print("Running with Dataflux Lightning...")
        ckpt = DatafluxLightningCheckpoint(project_name=args.project)
    elif args.checkpoint == ASYNC_DF_LIGHTNING:
        print("Running with Dataflux Lightning AsyncCheckpointIO...")
        ckpt = BenchmarkDatafluxLightningAsyncCheckpoint(
            project_name=args.project)
    elif args.checkpoint == NO_DF:
        print("Running Lightning without Dataflux...")
        ckpt = TorchCheckpointIO()
    elif args.checkpoint == NO_LIGHTNING:
        print("Running without Lightning...")
        # DatafluxLightningCheckpoint is still used here to provide an
        # argument during construction of the Trainer.
        ckpt = DatafluxLightningCheckpoint(project_name=args.project)
    else:
        raise ValueError("Invalid choice for --checkpoint")

    # Save once per step, and if `save_only_latest`, replace the last checkpoint each time.
    # Replacing is implemented by saving the new checkpoint, and then deleting the previous one.
    # If `save_only_latest` is False, a new checkpoint is created for each step.
    checkpoint_callback = ModelCheckpoint(
        save_top_k=1 if args.save_only_latest else -1,
        every_n_train_steps=1,
        filename="checkpoint-{epoch:02d}-{step:02d}",
        enable_version_counter=True,
    )
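    # With monitor unset, save_top_k=1 keeps only the most recent checkpoint
    # and save_top_k=-1 keeps all of them; enable_version_counter appends
    # -v1, -v2, ... when a checkpoint filename would otherwise collide.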
    trainer = Trainer(
        default_root_dir=args.ckpt_dir_path,
        plugins=[ckpt],
        callbacks=[checkpoint_callback],
        min_epochs=4,
        max_epochs=5,
        max_steps=1,
        accelerator="cpu",
    )
    trainer.fit(model, dataloader)

    if args.checkpoint == NO_LIGHTNING:
        bucket_name, ckpt_path = path_utils.parse_gcs_path(args.ckpt_dir_path)
        ckpt = dataflux_checkpoint.DatafluxCheckpoint(
            project_name=args.project, bucket_name=bucket_name)
    # Measure save checkpoint.
    start = time.time()
    for i in range(args.steps):
        if args.checkpoint == NO_LIGHTNING:
            CKPT_PATH = (ckpt_path + f'checkpoint_{i}.ckpt'
                         if ckpt_path else f'checkpoints/checkpoint_{i}.ckpt')
            with ckpt.writer(CKPT_PATH) as writer:
                torch.save(model.state_dict(), writer)
        else:
            trainer.save_checkpoint(
                os.path.join(args.ckpt_dir_path, f'ckpt_{i}.ckpt'))
    end = time.time()
    # Command to clear the kernel cache; the platform check matches macOS and
    # Linux.
    if args.clear_kernel_cache and sys.platform in [
            "darwin", "linux", "linux2", "linux3"
    ]:
        print("Clearing kernel cache...")
        os.system("sync && sudo sysctl -w vm.drop_caches=3")
    print("Average time to save one checkpoint: " +
          str((end - start) / args.steps) + " seconds")

    # If using async, call finalize to shut down the threadpool executor.
    if args.checkpoint == ASYNC_DF_LIGHTNING:
        ckpt.finalize()

    # Measure load checkpoint.
    start = time.time()
    for i in range(args.steps):
        if args.checkpoint == NO_LIGHTNING:
            # Recompute the per-step path so each saved checkpoint is read
            # back, mirroring the save loop above.
            CKPT_PATH = (ckpt_path + f'checkpoint_{i}.ckpt'
                         if ckpt_path else f'checkpoints/checkpoint_{i}.ckpt')
            with ckpt.reader(CKPT_PATH) as reader:
                read_state_dict = torch.load(reader)
        else:
            _ = ckpt.load_checkpoint(
                os.path.join(args.ckpt_dir_path, f'ckpt_{i}.ckpt'))
    end = time.time()
    print("Average time to load one checkpoint: " +
          str((end - start) / args.steps) + " seconds")