in throughput/all_reduce_bench.py
def run(local_rank):
    hostname = socket.gethostname()
    id = f"{hostname}:{local_rank}"  # per-process label, e.g. "node1:0"
    global_rank = dist.get_rank()

    # N*M float32 elements at 4 bytes each
    printflock(f"{id} data size: {M*N*4/1e9} GB")
    mat = torch.rand(N, M, dtype=torch.float32).cuda(local_rank)

    for i in range(TRIALS):
        # synchronize all ranks so every trial starts together
        dist.barrier()
        if global_rank == 0:
            print(f"\n\n\n-----------trial-{i}----------------")
        timed_allreduce(mat, id)
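
run() delegates the actual measurement to timed_allreduce and reads TRIALS, N, M, and printflock from module scope; those pieces live elsewhere in all_reduce_bench.py. Below is a minimal sketch of how the missing helpers and the entry point might look, assuming the NCCL backend and a torchrun launch; the constants and helper bodies here are illustrative assumptions, not the file's verbatim definitions.

import os
import time

import torch
import torch.distributed as dist

TRIALS = 5             # assumed number of timing trials
N, M = 500_000, 2_000  # assumed shape: N*M*4 bytes = ~4 GB of float32

def printflock(*args):
    # Stand-in: the real printflock serializes output across ranks with a
    # file lock; a plain print suffices for this sketch.
    print(*args)

def timed_allreduce(mat, id):
    # Wall-clock a single all_reduce. torch.cuda.synchronize() is essential:
    # dist.all_reduce launches asynchronously on the GPU, so without it the
    # timer would stop before the collective has actually finished.
    start = time.perf_counter()
    dist.all_reduce(mat)
    torch.cuda.synchronize()
    duration = time.perf_counter() - start

    size = M * N * 4                 # payload in bytes
    n = dist.get_world_size()
    algbw = size / duration          # algorithmic bandwidth
    busbw = algbw * 2 * (n - 1) / n  # bus bandwidth (ring all-reduce factor)
    printflock(f"{id}: algbw {algbw/1e9:.2f} GB/s, busbw {busbw/1e9:.2f} GB/s")

if __name__ == "__main__":
    # torchrun sets LOCAL_RANK for each spawned process, e.g.:
    #   torchrun --nproc_per_node=8 throughput/all_reduce_bench.py
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")  # NCCL backend for GPU collectives
    run(local_rank)
    dist.destroy_process_group()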