in dataflux_pytorch/benchmark/checkpointing/multinode/train.py [0:0]
def main(ckpt_dir_path: str, ckpt_restore_path: str = ""):
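    """Benchmark checkpoint save/load round-trips for a demo transformer.

    Saves a checkpoint NUM_SAVE_CALLS times under ckpt_dir_path, then loads
    it NUM_LOAD_CALLS times from ckpt_restore_path, printing per-call and
    average wall-clock times on rank 0.
    """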
    args = parse_args()
    validate(args)
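    # When the launcher provides a coordinator address, join the distributed
    # process group before doing any CUDA work.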
    if os.environ.get("COORDINATOR_ADDRESS"):
        init_processes()
    torch.cuda.empty_cache()
    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=1)
    model = DemoTransformer(vocab_size=dataset.vocab_size,
                            nlayers=int(os.environ.get("NUM_LAYERS", 10)))
    strategy = get_strategy(args, os.getenv("PROJECT"))
    num_save_calls = int(os.environ.get("NUM_SAVE_CALLS", 3))
    num_nodes = int(os.environ.get("NUM_NODES", 1))
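    # Run exactly one training step so the trainer holds real model/optimizer
    # state to checkpoint; built-in checkpointing and logging are disabled so
    # only the explicit save/load calls below touch storage.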
    trainer = Trainer(
        enable_checkpointing=False,
        logger=False,
        default_root_dir=ckpt_dir_path,
        plugins=[],
        min_epochs=1,
        max_epochs=1,
        max_steps=1,
        accelerator="gpu",
        strategy=strategy,
        devices=os.environ.get("NUM_DEVICES", 'auto'),
        num_nodes=num_nodes,
    )
    trainer.fit(model, dataloader)
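    # Save phase: time each distributed checkpoint save; every rank records
    # its own wall-clock time, and rank 0 prints per-call results.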
print(f"Saving checkpoint to {ckpt_dir_path} {num_save_calls} times.")
save_checkpoint_times = []
for i in range(num_save_calls):
start = time.time()
trainer.save_checkpoint(
os.path.join(ckpt_dir_path, f'checkpoints/ckpt_{i}.ckpt/'))
end = time.time()
if torch.distributed.get_rank() == 0:
print(
f"Saved checkpoint to {ckpt_dir_path} in {end - start} seconds."
)
save_checkpoint_times.append(end - start)
if torch.distributed.get_rank() == 0:
print(f"Saved checkpoint to {ckpt_dir_path} {num_save_calls} times.")
    num_load_calls = int(os.environ.get("NUM_LOAD_CALLS", 3))
    load_checkpoint_times = []
    if args.save_only:
        print("Skipping loads because you set --save_only")
        num_load_calls = 0
        load_checkpoint_times = [0]
    elif args.load_only:
        print(f"Copying contents of {ckpt_dir_path} to {ckpt_restore_path}")
        copy_bucket_to_local(ckpt_dir_path.removeprefix("gs://"),
                             os.path.dirname(ckpt_restore_path))
        save_checkpoint_times = [0]
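    # Load phase: rebuild the model, strategy, and trainer from scratch on
    # each iteration so no cached state is reused across measurements.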
    for i in range(num_load_calls):
        del trainer, strategy, model
        model = DemoTransformer(vocab_size=dataset.vocab_size,
                                nlayers=int(os.environ.get("NUM_LAYERS", 10)))
        new_ckpt_dir_path = os.path.join(ckpt_restore_path, f'ckpt_{i}.ckpt/')
        strategy = get_strategy(args, os.getenv("PROJECT"))
        trainer = Trainer(
            enable_checkpointing=False,
            logger=False,
            default_root_dir=ckpt_dir_path,
            plugins=[],
            min_epochs=1,
            max_epochs=1,
            max_steps=1,
            accelerator="gpu",
            strategy=strategy,
            devices=os.environ.get("NUM_DEVICES", 'auto'),
            num_nodes=num_nodes,
        )
        trainer.fit(model, dataloader)
        start = time.time()
        trainer.strategy.load_checkpoint(new_ckpt_dir_path)
        end = time.time()
        if torch.distributed.get_rank() == 0:
            print(
                f"Loaded checkpoint from {new_ckpt_dir_path} in {end - start} seconds."
            )
        load_checkpoint_times.append(end - start)
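    # Rank 0 aggregates and reports the timings.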
    if torch.distributed.get_rank() == 0:
        avg_load_time = statistics.mean(load_checkpoint_times)
        avg_save_time = statistics.mean(save_checkpoint_times)
        print_times(args, avg_save_time, avg_load_time)
        if not args.load_only:
            print(f"All save times: {save_checkpoint_times}")
        if not args.save_only:
            print(f"All load times: {load_checkpoint_times}")