in dataflux_pytorch/benchmark/checkpointing/simulated/benchmark.py [0:0]
def run_benchmark(world_size: int, layer_size: int, project: str,
                  filepath: str, padding_size: int, sample_count: int,
                  use_fsspec: bool) -> None:
    """Benchmark distributed checkpoint save/load using simulated tensors.

    Each rank builds its disjoint shard of ``padding_size`` dummy tensors of
    shape ``(layer_size, 1000)``, times ``sample_count`` checkpoint save and
    load cycles, and rank 0 reports timing and size statistics.

    Args:
        world_size: Number of participating ranks.
        layer_size: First dimension of every dummy tensor.
        project: Project identifier forwarded to ``BenchmarkStrategy``.
        filepath: Destination path for the checkpoint.
        padding_size: Total number of dummy tensors across all ranks.
        sample_count: Number of save/load iterations to time.
        use_fsspec: Whether ``BenchmarkStrategy`` should use fsspec I/O.
    """

    def _stats(times) -> str:
        # stdev requires at least two samples; report 0 for a single run
        # instead of raising statistics.StatisticsError.
        stdev = statistics.stdev(times) if len(times) > 1 else 0.0
        return f"{statistics.mean(times):.4f} seconds (stdev {stdev:.4f})"

    if os.environ.get("COORDINATOR_ADDRESS"):
        init_processes()
    rank = int(os.environ.get("NODE_RANK", 0))
    benchmark_strategy = BenchmarkStrategy(project=project,
                                           path=filepath,
                                           use_fsspec=use_fsspec)
    # According to `create_default_local_load_plan` https://github.com/pytorch/pytorch/blob/v2.3.1/torch/distributed/checkpoint/default_planner.py#L227
    # each key will be read only once from the state_dict, hence assigning different names to different tensor will force the load function to only read
    # tensor shard corresponding to given node.
    state_dict = {
        f'dummy_tensor_{i}': torch.randn(layer_size, 1000)
        for i in range(padding_size) if i % world_size == rank
    }
    # Wait until the state_dict is populated properly across all the nodes.
    dist.barrier()
    save_checkpoint_times = time_checkpoint_operation(benchmark_strategy,
                                                      state_dict, filepath,
                                                      sample_count, 'save',
                                                      rank, world_size,
                                                      padding_size, layer_size)
    load_checkpoint_times = time_checkpoint_operation(benchmark_strategy,
                                                      state_dict, filepath,
                                                      sample_count, 'load',
                                                      rank, world_size,
                                                      padding_size, layer_size)
    if rank == 0:
        print(f"Time taken to save checkpoint: {_stats(save_checkpoint_times)}")
        print(f"All save times: {save_checkpoint_times}")
        print(f"Time taken to load checkpoint: {_stats(load_checkpoint_times)}")
        print(f"All load times: {load_checkpoint_times}")
        # Guard against padding_size == 0: state_dict is empty then, and the
        # original 'dummy_tensor_0' lookup would raise KeyError.
        sample_tensor = next(iter(state_dict.values()), None)
        element_size = (sample_tensor.element_size()
                        if sample_tensor is not None else 0)
        tensor_size_per_instance = 1000 * layer_size * element_size
        tensors_per_rank = padding_size // world_size
        total_size_bytes = (tensors_per_rank * tensor_size_per_instance *
                            world_size)
        print(f"Size of distributed tensors (rank {rank}): "
              f"{format_size(tensors_per_rank * tensor_size_per_instance)}")
        print(f"Total size of all tensors: {format_size(total_size_bytes)}")
        print("######################")
    cleanup()