in bench_cluster/communication/utils.py [0:0]
def get_bw(bw_unit, comm_op, size, duration, world_size=None):
    """Compute throughput and bus bandwidth for one timed communication op.

    Args:
        bw_unit: Output unit selector. ``'Gbps'`` multiplies both results by 8
            (presumably ``size`` is in bytes, so this converts to bits); any
            other value leaves the results in size-units per duration-unit.
        comm_op: One of ``"all_to_all"``, ``"all_gather"``, ``"all_reduce"``,
            ``"p2p"`` or ``"broadcast"``.
        size: Message size of the op. For ``"all_gather"`` it is scaled by the
            world size before computing throughput.
        duration: Elapsed time of the op; must be non-zero.
        world_size: Number of participating ranks. Defaults to
            ``dist.get_world_size()`` (the default process group); pass it
            explicitly for a subgroup or when no process group is initialized.

    Returns:
        ``(tput, busbw)``: algorithm throughput and bus bandwidth, both
        multiplied by 8 when ``bw_unit == 'Gbps'``.

    Raises:
        ValueError: if ``comm_op`` is not one of the supported ops.
    """
    supported_ops = ("all_to_all", "all_gather", "all_reduce", "p2p", "broadcast")
    # Validate before touching the process group; previously this path called
    # exit(0), which terminated the run with a *success* status.
    if comm_op not in supported_ops:
        raise ValueError(
            f"wrong comm_op specified: {comm_op!r}; expected one of {supported_ops}"
        )

    n = dist.get_world_size() if world_size is None else world_size

    if comm_op == "all_to_all":
        tput = size / duration
        # Bus-bandwidth correction (n-1)/n, as in the nccl-tests definition.
        busbw = (size / duration) * ((n - 1) / n)
    elif comm_op == "all_gather":
        # The op moves data from every rank, so throughput uses n * size.
        total_size = size * n
        tput = total_size / duration
        busbw = (total_size / duration) * ((n - 1) / n)
    elif comm_op == "all_reduce":
        # NOTE: tput counts 2 * size here (unlike nccl-tests' algbw, which is
        # size / duration); kept as-is for compatibility with existing results.
        tput = size * 2 / duration
        busbw = (size / duration) * (2 * (n - 1) / n)
    else:  # "p2p" or "broadcast": bus bandwidth equals throughput.
        tput = size / duration
        busbw = tput

    if bw_unit == 'Gbps':
        tput *= 8
        busbw *= 8
    return tput, busbw