in bench_cluster/communication/utils.py [0:0]
def _floor_power_of_two(n):
    """Return the largest power of two that is <= ``n`` (0 when ``n`` < 1).

    Uses integer bit arithmetic instead of ``math.log(n, 2)``: the floating
    point logarithm raises ``ValueError`` for 0 and is not guaranteed to be
    exact for values that are themselves powers of two.
    """
    n = int(n)
    if n < 1:
        return 0
    return 1 << (n.bit_length() - 1)


def max_numel(comm_op, dtype, mem_factor, local_rank):
    """Return the largest element count to benchmark for ``comm_op``.

    The budget is ``total_memory * mem_factor`` bytes of the CUDA device
    ``local_rank``, converted to a number of elements with
    ``_element_size(dtype)``.

    Args:
        comm_op: One of ``'all_reduce'``, ``'p2p'``, ``'broadcast'``,
            ``'all_gather'`` or ``'all_to_all'``.
        dtype: Element type; only passed to ``_element_size`` to obtain the
            per-element size (presumably in bytes -- confirm against
            ``_element_size``).
        mem_factor: Multiplier applied to the device's total memory.
        local_rank: Device index handed to
            ``torch.cuda.get_device_properties``.

    Returns:
        int: Number of elements per GPU. This is 0 when the budget is too
        small to hold a single element (per rank, for the collective ops
        that divide by world size).

    Raises:
        SystemExit: With a non-zero status when ``comm_op`` is not supported.
    """
    dtype_size = _element_size(dtype)
    max_memory_per_gpu = torch.cuda.get_device_properties(local_rank).total_memory * mem_factor
    total_elements = int(max_memory_per_gpu // dtype_size)

    if comm_op in ('all_reduce', 'p2p', 'broadcast'):
        return total_elements

    if comm_op == 'all_gather':
        # all_gather performance is lower for non-powers of two, and the
        # output buffer size scales with world size. Therefore divide by
        # world size and round down to the nearest power of two.
        return _floor_power_of_two(total_elements // dist.get_world_size())

    if comm_op == 'all_to_all':
        # The element count must be divisible by world_size, and all_to_all
        # performance is lower for non-powers of two. Round the per-rank
        # chunk down to a power of two and scale back up: the result never
        # exceeds the budget, is always a multiple of world_size, and equals
        # floor_pow2(total_elements) whenever world_size is a power of two.
        world_size = dist.get_world_size()
        return world_size * _floor_power_of_two(total_elements // world_size)

    print(f"This communication operation: {comm_op} is not supported yet")
    # Unsupported operation is an error: exit with a failing status so
    # launchers/schedulers do not record the run as successful.
    raise SystemExit(1)