# def run_benchmark()
#
# in dataflux_pytorch/benchmark/checkpointing/simulated/benchmark.py [0:0]


def run_benchmark(world_size: int, layer_size: int, project: str,
                  filepath: str, padding_size: int, sample_count: int,
                  use_fsspec: bool) -> None:
    """Time distributed checkpoint save/load and print a summary on rank 0.

    Args:
        world_size: Number of participating ranks the tensors are sharded over.
        layer_size: First dimension of each dummy tensor (second is fixed at 1000).
        project: Project identifier forwarded to BenchmarkStrategy.
        filepath: Checkpoint destination path.
        padding_size: Total number of dummy tensors across all ranks.
        sample_count: Number of timed save/load repetitions.
        use_fsspec: Whether BenchmarkStrategy should use fsspec-based I/O.
    """
    # Only join a process group when launched under a coordinator; otherwise
    # this is a single-process run and NODE_RANK defaults to 0.
    if os.environ.get("COORDINATOR_ADDRESS"):
        init_processes()
    rank = int(os.environ.get("NODE_RANK", 0))

    benchmark_strategy = BenchmarkStrategy(project=project,
                                           path=filepath,
                                           use_fsspec=use_fsspec)
    # According to `create_default_local_load_plan` https://github.com/pytorch/pytorch/blob/v2.3.1/torch/distributed/checkpoint/default_planner.py#L227
    # each key will be read only once from the state_dict, hence assigning different names to different tensor will force the load function to only read
    # tensor shard corresponding to given node.
    state_dict = {
        f'dummy_tensor_{i}': torch.randn(layer_size, 1000)
        for i in range(padding_size)
        if i % world_size == rank
    }

    # Wait until the state_dict is populated properly across all the nodes.
    # Guarded so a single-process run (no process group) does not crash.
    if dist.is_initialized():
        dist.barrier()

    save_checkpoint_times = time_checkpoint_operation(benchmark_strategy,
                                                      state_dict, filepath,
                                                      sample_count, 'save',
                                                      rank, world_size,
                                                      padding_size, layer_size)

    load_checkpoint_times = time_checkpoint_operation(benchmark_strategy,
                                                      state_dict, filepath,
                                                      sample_count, 'load',
                                                      rank, world_size,
                                                      padding_size, layer_size)

    if rank == 0:
        # statistics.stdev requires at least two samples; report 0.0 instead
        # of raising StatisticsError when sample_count == 1.
        save_stdev = (statistics.stdev(save_checkpoint_times)
                      if len(save_checkpoint_times) > 1 else 0.0)
        load_stdev = (statistics.stdev(load_checkpoint_times)
                      if len(load_checkpoint_times) > 1 else 0.0)
        # Adjacent f-string literals instead of backslash continuations, which
        # accidentally embedded the source indentation into the printed output.
        print(f"Time taken to save checkpoint: "
              f"{statistics.mean(save_checkpoint_times):.4f} seconds "
              f"(stdev {save_stdev:.4f})")
        print(f"All save times: {save_checkpoint_times}")
        print(f"Time taken to load checkpoint: "
              f"{statistics.mean(load_checkpoint_times):.4f} seconds "
              f"(stdev {load_stdev:.4f})")
        print(f"All load times: {load_checkpoint_times}")

        # Rank 0 always owns 'dummy_tensor_0' since 0 % world_size == 0, so
        # its element size is available here. Each tensor is layer_size x 1000.
        tensor_size_per_instance = (1000 * layer_size *
                                    state_dict['dummy_tensor_0'].element_size())
        tensors_per_rank = padding_size // world_size
        total_size_bytes = tensors_per_rank * tensor_size_per_instance * world_size
        print(f"Size of distributed tensors (rank {rank}): "
              f"{format_size(tensors_per_rank * tensor_size_per_instance)}")
        print(f"Total size of all tensors: "
              f"{format_size(total_size_bytes)}")
        print("######################")

    cleanup()