in log_analyzer/utils.py [0:0]
def calc_bw_log(comm_type: CommType, size, duration, group_size):  # size: Bytes; duration: ms
    """Compute throughput and bus bandwidth for one logged communication op.

    Args:
        comm_type: Kind of operation. ``CommType.barrier`` and
            ``CommType.computation`` carry no payload and yield ``(0, 0)``.
        size: Message size in bytes.
        duration: Elapsed time in milliseconds.
        group_size: Number of participating ranks; a falsy value
            (``None`` or ``0``) is treated as 1.

    Returns:
        ``(tput, busbw)`` in GiB/s (``1024**3`` bytes per second), each
        rounded to two decimals. ``tput`` is ``size / duration``. ``busbw``
        scales ``tput`` depending on the op type:

        * ``(n - 1) / n`` for all_gather and reduce_scatter;
        * ``2 * (n - 1) / n`` for all_reduce;
        * unchanged for every other type.

        Returns ``(0, 0)`` when ``duration`` is not positive, because no
        meaningful rate exists (a zero duration previously raised
        ``ZeroDivisionError``).
    """
    if comm_type in (CommType.barrier, CommType.computation):
        return 0, 0
    if duration <= 0:
        # Zero (e.g. sub-resolution timestamps) or negative (clock skew)
        # durations cannot produce a meaningful rate.
        return 0, 0

    n = group_size if group_size else 1
    tput = size / (duration / 1000)  # bytes per second
    if comm_type in (CommType.all_gather, CommType.reduce_scatter):
        busbw = tput * ((n - 1) / n)
    elif comm_type == CommType.all_reduce:
        busbw = tput * (2 * (n - 1) / n)
    else:  # broadcast, reduce, gather, scatter, isend, irecv, ...
        busbw = tput

    gib = 1024 * 1024 * 1024
    return round(tput / gib, 2), round(busbw / gib, 2)