in bench_cluster/communication/p2p.py [0:0]
def _p2p_peer_schedule(rank, world_size):
    """Return the ordered ``(peer, is_sender)`` pairs this rank exchanges with.

    Every pair of ranks exchanges exactly one message per round: the lower
    rank sends and the higher rank receives. Walking peers in increasing
    order gives each rank the sequence "receive from every lower rank, then
    send to every higher rank". All ranks agree on that order, which avoids
    the circular waits that could otherwise deadlock blocking send/recv.
    """
    return [(peer, rank < peer) for peer in range(world_size) if peer != rank]


def _p2p_exchange(tensor, rank, world_size, async_op):
    """Run one all-pairs round of point-to-point transfers for this rank.

    Received data is written into ``tensor`` (the same buffer that is sent);
    the contents are irrelevant for benchmarking.

    Returns the work handles of the non-blocking ops issued when ``async_op``
    is true, or an empty list for the blocking variant.
    """
    handles = []
    for peer, is_sender in _p2p_peer_schedule(rank, world_size):
        if async_op:
            if is_sender:
                handles.append(dist.isend(tensor, peer))
            else:
                handles.append(dist.irecv(tensor, src=peer))
        elif is_sender:
            dist.send(tensor, peer)
        else:
            dist.recv(tensor, src=peer)
    return handles


def _wait_all(handles):
    """Wait on every work handle returned by non-blocking p2p ops."""
    for work in handles:
        work.wait()


def timed_p2p(input, start_event, end_event, warmups, trials, async_op, bw_unit, raw):
    """Time all-pairs point-to-point transfers and print one result row on rank 0.

    In each trial every rank exchanges ``input`` with every other rank: it
    receives from lower ranks and sends to higher ranks. The measured time is
    averaged over ``trials``.

    Args:
        input: tensor that is sent, and also overwritten by received data.
        start_event / end_event: timing events exposing ``record()`` and
            ``elapsed_time()``. Presumably ``torch.cuda.Event`` objects created
            with ``enable_timing=True``; TODO confirm against caller.
        warmups: number of untimed rounds run first.
        trials: number of timed rounds; the reported duration is per round.
        async_op: use ``isend``/``irecv`` (waited on before timing stops)
            instead of blocking ``send``/``recv``.
        bw_unit: unit forwarded to ``get_bw``.
        raw: forwarded to ``get_metric_strings``; when false the byte size is
            humanized via ``convert_size`` before printing.
    """
    world_size = dist.get_world_size()
    rank = dist.get_rank()

    sync_all()
    # Warmups, establish connections, etc.
    pending = []
    for _ in range(warmups):
        pending.extend(_p2p_exchange(input, rank, world_size, async_op))
    # Drain outstanding non-blocking warmup ops so they cannot spill into
    # the timed region below.
    _wait_all(pending)
    sync_all()

    # Time the actual comm op `trials` times and average it.
    start_event.record()
    pending = []
    for _ in range(trials):
        pending.extend(_p2p_exchange(input, rank, world_size, async_op))
    # isend/irecv only post the transfers. Wait on every handle before
    # recording the end event so that the measured interval covers
    # completion, not just the time spent issuing the requests.
    _wait_all(pending)
    end_event.record()
    sync_all()
    # elapsed_time is presumably in milliseconds (CUDA event convention); convert to seconds.
    duration = start_event.elapsed_time(end_event) / 1000

    # maintain and clean performance data
    avg_duration = duration / trials
    size = input.element_size() * input.nelement()
    n = world_size
    # Each rank takes part in exactly n - 1 transfers of `size` bytes per round:
    # it receives from every lower rank and sends to every higher rank.
    tput, busbw = get_bw(bw_unit, 'p2p', size * (n - 1), avg_duration)
    tput_str, busbw_str, duration_str = get_metric_strings(raw, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not raw:
        size = convert_size(size)

    print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")