in dataflux_pytorch/benchmark/checkpointing/multinode/train_async_save.py [0:0]
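# NOTE: imports needed by this excerpt. The lightning/torch paths are standard;
# the demo.* paths are assumptions based on the repo layout and should be
# verified against the full file.
import os
import statistics

import torch
import torch.distributed as dist
from torch.utils.data import DataLoader

from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.demos import WikiText2

from demo.lightning.checkpoint.multinode.strategies import DatafluxFSDPStrategy
from demo.lightning.checkpoint.multinode.train import (DemoTransformer,
                                                       init_processes)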
def main():
    # Benchmark configuration is taken entirely from environment variables.
    project = os.getenv("PROJECT")
    num_nodes = int(os.environ.get("NUM_NODES", 1))
    devices = os.environ.get("NUM_DEVICES", "auto")
    use_async = os.getenv("USE_ASYNC", "1") == "1"
    ckpt_dir_path = os.getenv("CKPT_DIR_PATH")
    num_layers = int(os.environ.get("NUM_LAYERS", 10))
    min_epochs = int(os.environ.get("MIN_EPOCHS", 4))
    max_epochs = int(os.environ.get("MAX_EPOCHS", 5))
    max_steps = int(os.environ.get("MAX_STEPS", 3))
    steps_per_save = int(os.environ.get("STEPS_PER_SAVE", 1))
    run_count = int(os.getenv("NUM_SAVE_CALLS", 1))
    simulated_workload_duration = int(
        os.getenv("SIMULATED_WORKLOAD_DURATION", 1))
    rank = int(os.environ.get("NODE_RANK", 0))
    if os.environ.get("COORDINATOR_ADDRESS"):
        init_processes()
        dist.init_process_group("gloo",
                                rank=rank,
                                world_size=num_nodes)
    torch.cuda.empty_cache()
    dataset = WikiText2()
    dataloader = DataLoader(dataset, num_workers=1)
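    # Run trainer.fit() `run_count` times and record the wall time of each run.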
    trainer_fit_times = []
    for i in range(run_count):
        model = DemoTransformer(
            vocab_size=dataset.vocab_size,
            nlayers=num_layers,
            simulated_workload_duration=simulated_workload_duration,
        )
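        # Keep every checkpoint that is written (save_top_k=-1), saving one
        # every `steps_per_save` training steps.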
        checkpoint_callback = ModelCheckpoint(
            save_top_k=-1,
            every_n_train_steps=steps_per_save,
            filename="checkpoint-{epoch:02d}-{step:02d}",
            enable_version_counter=True,
            dirpath=ckpt_dir_path,
        )
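        # Checkpoints are written as sharded FSDP state dicts through the
        # Dataflux strategy; with use_async=True the save is performed
        # asynchronously, which appears to be the behavior this benchmark
        # (train_async_save) exercises.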
        strategy = DatafluxFSDPStrategy(
            project_name=project,
            storage_client=None,
            state_dict_type="sharded",
            use_orig_params=False,
            use_async=use_async,
        )
        trainer = Trainer(
            default_root_dir=ckpt_dir_path,
            plugins=[],
            callbacks=[checkpoint_callback],
            max_steps=max_steps,
            min_epochs=min_epochs,
            max_epochs=max_epochs,
            accelerator="gpu",
            strategy=strategy,
            devices=devices,
            num_nodes=num_nodes,
            enable_progress_bar=False,
        )
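        # Time trainer.fit() with CUDA events. elapsed_time() requires both
        # events to have completed, so synchronize before reading it; the
        # result is in milliseconds, hence the division by 1000.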
        init_start_event = torch.cuda.Event(enable_timing=True)
        init_end_event = torch.cuda.Event(enable_timing=True)
        init_start_event.record()
        trainer.fit(model, dataloader)
        init_end_event.record()
        torch.cuda.synchronize()
        total_time = init_start_event.elapsed_time(init_end_event) / 1000
        print(
            f"Individual run {i+1} of {run_count} trainer.fit() #{rank} took {total_time} seconds."
        )
        trainer_fit_times.append(total_time)
    # All runs complete; only the rank-0 node reports aggregate statistics.
    if rank == 0:
        avg_trainer_fit_time = statistics.mean(trainer_fit_times)
        print("##################################")
        print(f"Average time for trainer.fit(): {avg_trainer_fit_time} seconds")
        print("##################################")
        print(f"All trainer.fit() times: {trainer_fit_times}")
    # Cleanup: only tear down the process group if one was created above.
    if dist.is_initialized():
        dist.destroy_process_group()
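
# Assumed entry point, not shown in this excerpt.
if __name__ == "__main__":
    main()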