in sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py [0:0]
def traced_all_gather(tensor_list, tensor, group=None, async_op=False):
"""Intercepts invocations of torch.distributed.all_gather.
Let n := [Group Size]
Calculate [Ring-B/W] = (n-1)/n * [Message Size] / [Kernel Time]
Assuming equal tensor sizes.
"""
if _should_rank_record_comm(group):
message_size = functools.reduce(
lambda size, x: size + x.nelement() * x.element_size(), tensor_list, 0)
_emit_call_description('all_gather', message_size, group)
return torch.distributed.untraced_all_gather(
tensor_list, tensor, group, async_op)
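

# --- Illustrative sketches (assumptions, not part of the original file) ---
# The tracer above forwards to `torch.distributed.untraced_all_gather`, which
# implies that patching code elsewhere in this module stashes the original
# collective under that name before rebinding `all_gather`. A minimal sketch
# of such a patch step, assuming that convention (the name
# `_patch_all_gather` is hypothetical):
def _patch_all_gather():
  if not hasattr(torch.distributed, 'untraced_all_gather'):
    # Keep a handle to the untraced implementation so the wrapper can forward
    # to it, then install the wrapper in its place.
    torch.distributed.untraced_all_gather = torch.distributed.all_gather
    torch.distributed.all_gather = traced_all_gather


# The docstring's [Ring-B/W] formula could then be evaluated offline from the
# emitted call description plus a measured kernel time; a hypothetical helper:
def _ring_bandwidth_bytes_per_sec(message_size, group_size, kernel_time_sec):
  # In a ring all-gather each rank moves (n-1)/n of the total message over
  # its link, so the achieved bandwidth is (n-1)/n * bytes / seconds.
  return (group_size - 1) / group_size * message_size / kernel_time_sec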