in dataflux_pytorch/benchmark/checkpointing/simulated/llama2.py
def run_benchmark(world_size: int, project: str, filepath: str,
                  sample_count: int, use_fsspec: bool,
                  model_parameter_size: str, optimizer: str) -> None:
    # When a coordinator address is provided (multi-node run), initialize
    # the distributed processes before doing anything else.
    if os.environ.get("COORDINATOR_ADDRESS"):
        init_processes()
    rank = int(os.environ.get("NODE_RANK", 0))

    benchmark_strategy = BenchmarkStrategy(project=project,
                                           path=filepath,
                                           use_fsspec=use_fsspec)

    print(f'Constructing state dict for LLAMA2 {model_parameter_size}')
    state_dict = create_llama2_state_dict(world_size=world_size,
                                          rank=rank,
                                          parameters=model_parameter_size,
                                          optimizer=optimizer,
                                          empty=False)
    # Ensure every rank has finished constructing its shard before timing.
    dist.barrier()

    # Time repeated checkpoint saves, then repeated loads.
    save_checkpoint_times = time_checkpoint_operation(
        benchmark_strategy, state_dict, filepath, sample_count, 'save', rank,
        world_size, model_parameter_size, optimizer)
    load_checkpoint_times = time_checkpoint_operation(
        benchmark_strategy, state_dict, filepath, sample_count, 'load', rank,
        world_size, model_parameter_size, optimizer)

    # Calculate the tensor size (in bytes) held by the current rank.
    local_tensor_size = sum(v.element_size() * v.numel()
                            for v in state_dict.values())

    # Gather the per-rank sizes onto every rank.
    sizes = [0] * world_size
    dist.all_gather_object(sizes, local_tensor_size)
    total_size_bytes = sum(sizes)
    if rank == 0:
        print("Time taken to save checkpoint: "
              f"{statistics.mean(save_checkpoint_times):.4f} seconds "
              f"(stdev {statistics.stdev(save_checkpoint_times):.4f})")
        print(f"All save times: {save_checkpoint_times}")
        print("Time taken to load checkpoint: "
              f"{statistics.mean(load_checkpoint_times):.4f} seconds "
              f"(stdev {statistics.stdev(load_checkpoint_times):.4f})")
        print(f"All load times: {load_checkpoint_times}")
        print("Size of distributed tensors per rank:")
        for r, size in enumerate(sizes):
            print(f"Rank {r}: {format_size(size)}")
        print(f"Total size of all tensors: {format_size(total_size_bytes)}")
        print("######################")

    cleanup()
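
Usage sketch (assumption): the full llama2.py defines its own command-line entry point, which is not shown in this excerpt. The snippet below is only a minimal illustration of how run_benchmark could be driven from an `if __name__ == "__main__"` block under a torchrun-style launcher; the flag names, defaults, and parameter-size choices are hypothetical and not taken from the original file.

# --- Illustrative entry point (sketch, not part of the original file) ---
if __name__ == "__main__":
    import argparse

    # Hypothetical flags for illustration; the real benchmark's CLI may differ.
    parser = argparse.ArgumentParser(
        description="Simulated Llama 2 checkpointing benchmark (sketch)")
    parser.add_argument("--project", required=True,
                        help="GCP project that owns the checkpoint bucket")
    parser.add_argument("--ckpt-dir-path", required=True,
                        help="gs:// (or local) path where checkpoints are written")
    parser.add_argument("--world-size", type=int,
                        default=int(os.environ.get("WORLD_SIZE", 1)))
    parser.add_argument("--sample-count", type=int, default=8,
                        help="Number of save/load iterations to time")
    parser.add_argument("--use-fsspec", action="store_true")
    parser.add_argument("--model-parameter-size", default="7b",
                        help="e.g. 7b, 13b, 70b (assumed values)")
    parser.add_argument("--optimizer", default="sgd")
    args = parser.parse_args()

    run_benchmark(world_size=args.world_size,
                  project=args.project,
                  filepath=args.ckpt_dir_path,
                  sample_count=args.sample_count,
                  use_fsspec=args.use_fsspec,
                  model_parameter_size=args.model_parameter_size,
                  optimizer=args.optimizer)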