in bench_cluster/communication/all_reduce.py [0:0]
def timed_all_reduce(input, start_event, end_event, warmups, trials, async_op, bw_unit, raw):
    """Benchmark ``dist.all_reduce`` on ``input`` and print one result row.

    Runs ``warmups`` untimed all-reduce calls, then ``trials`` timed calls
    bracketed by ``start_event`` / ``end_event``, and reports the average
    per-call duration together with the bandwidth figures returned by
    ``get_bw``.

    Args:
        input: Buffer to reduce; must provide ``element_size()`` and
            ``nelement()`` (presumably a ``torch.Tensor`` -- confirm
            against the caller).
        start_event: Timing event recorded just before the timed loop.
            Must provide ``record()`` and ``elapsed_time(other)``
            (presumably a ``torch.cuda.Event(enable_timing=True)`` --
            confirm against the caller).
        end_event: Timing event recorded just after the timed loop.
        warmups: Number of untimed all-reduce calls issued first.
        trials: Number of timed all-reduce calls; the measured duration
            is divided by this value.
        async_op: Forwarded unchanged to ``dist.all_reduce``.
        bw_unit: Forwarded unchanged to ``get_bw``.
        raw: Forwarded to ``get_metric_strings``; when falsy, the message
            size is additionally passed through ``convert_size`` before
            printing.

    Returns:
        None. The result row is emitted through ``print_rank_0``.

    Raises:
        ZeroDivisionError: If ``trials`` is 0 (the average is computed as
            ``duration / trials``).
    """
    sync_all()
    # Warmups, establish connections, etc.
    for _ in range(warmups):
        dist.all_reduce(input, async_op=async_op)
    sync_all()

    # Time the whole batch of trials with a single pair of events and
    # average afterwards, rather than timing each call individually.
    # NOTE(review): the object returned by all_reduce is discarded, so with
    # async_op=True nothing here waits on the individual operations; the
    # measurement then relies on how the backend orders its work relative to
    # these events and on sync_all() -- confirm this is intended.
    start_event.record()
    for _ in range(trials):
        dist.all_reduce(input, async_op=async_op)
    end_event.record()
    sync_all()
    # Presumably elapsed_time() reports milliseconds, making this seconds --
    # confirm against the event type the caller passes in.
    duration = start_event.elapsed_time(end_event) / 1000

    # maintain and clean performance data
    avg_duration = duration / trials
    size = input.element_size() * input.nelement()  # bytes per element * element count
    tput, busbw = get_bw(bw_unit, 'all_reduce', size, avg_duration)
    tput_str, busbw_str, duration_str = get_metric_strings(raw, tput, busbw, avg_duration)
    desc = f'{input.nelement()}x{input.element_size()}'

    if not raw:
        # Replace the raw size with whatever convert_size() returns for display.
        size = convert_size(size)

    print_rank_0(f"{size:<20} {desc:25s} {duration_str:20s} {tput_str:20s} {busbw_str:20s}")