# Excerpt: run_benchmark()
# from dataflux_pytorch/benchmark/checkpointing/simulated/llama2.py


def run_benchmark(world_size: int, project: str, filepath: str,
                  sample_count: int, use_fsspec: bool,
                  model_parameter_size: str, optimizer: str) -> None:
    """Time distributed checkpoint save/load for a simulated LLAMA2 model.

    Builds a per-rank state dict, times ``sample_count`` save and load
    operations through ``BenchmarkStrategy``, gathers per-rank tensor sizes,
    and (on rank 0) prints timing and size statistics.

    Args:
        world_size: Total number of ranks participating in the benchmark.
        project: Project identifier passed to ``BenchmarkStrategy``.
        filepath: Checkpoint destination path.
        sample_count: Number of save/load iterations to time.
        use_fsspec: Whether ``BenchmarkStrategy`` should use fsspec.
        model_parameter_size: LLAMA2 size label used to build the simulated
            state dict (semantics defined by ``create_llama2_state_dict``).
        optimizer: Optimizer name used when constructing optimizer state.
    """
    # Only set up the process group when a coordinator is configured;
    # dist.barrier() below requires an initialized process group either way.
    if os.environ.get("COORDINATOR_ADDRESS"):
        init_processes()
    rank = int(os.environ.get("NODE_RANK", 0))

    benchmark_strategy = BenchmarkStrategy(project=project,
                                           path=filepath,
                                           use_fsspec=use_fsspec)
    print(f'Constructing state dict for LLAMA2 {model_parameter_size}')
    state_dict = create_llama2_state_dict(world_size=world_size,
                                          rank=rank,
                                          parameters=model_parameter_size,
                                          optimizer=optimizer,
                                          empty=False)

    # Ensure all ranks have built their state dict before timing starts.
    dist.barrier()
    save_checkpoint_times = time_checkpoint_operation(
        benchmark_strategy, state_dict, filepath, sample_count, 'save', rank,
        world_size, model_parameter_size, optimizer)

    load_checkpoint_times = time_checkpoint_operation(
        benchmark_strategy, state_dict, filepath, sample_count, 'load', rank,
        world_size, model_parameter_size, optimizer)

    # Calculate tensor size for the current rank.
    # NOTE(review): assumes every value in state_dict is a tensor with
    # element_size()/numel() — confirm against create_llama2_state_dict.
    local_tensor_size = sum(v.element_size() * v.numel()
                            for v in state_dict.values())
    # Gather the per-rank sizes onto every rank.
    sizes = [0] * world_size
    dist.all_gather_object(sizes, local_tensor_size)
    total_size_bytes = sum(sizes)

    if rank == 0:
        _print_timing("save", save_checkpoint_times)
        _print_timing("load", load_checkpoint_times)

        print("Size of distributed tensors per rank:")
        for r, size in enumerate(sizes):
            print(f"Rank {r}: {format_size(size)}")
        print(f"Total size of all tensors: {format_size(total_size_bytes)}")

        print("######################")
    cleanup()


def _print_timing(operation: str, times: list) -> None:
    """Print mean/stdev timing for one checkpoint operation.

    The original code used backslash line-continuations inside f-strings,
    which embedded the source indentation into the printed message; this
    formats the line cleanly. ``statistics.stdev`` needs at least two
    samples, so a single-sample run reports 0.0 instead of raising
    ``StatisticsError``.
    """
    stdev = statistics.stdev(times) if len(times) > 1 else 0.0
    print(f"Time taken to {operation} checkpoint: "
          f"{statistics.mean(times):.4f} seconds (stdev {stdev:.4f})")
    print(f"All {operation} times: {times}")