in sample_workloads/megatron-gke/docker/monitor_collectives.py [0:0]
def traced_broadcast(tensor, src, group=None, async_op=False):
    """Tracing wrapper installed in place of torch.distributed.broadcast.

    If this rank is chosen to record the collective, emits a call
    description (operation name, message size in bytes, group, root rank)
    before delegating to the original, untraced broadcast. The recorded
    message size supports deriving [Ring-B/W] = [Message Size]/[Kernel Time]
    for large messages; see
    https://images.nvidia.com/events/sc15/pdfs/NCCL-Woolley.pdf

    Args:
        tensor: Forwarded to torch.distributed.broadcast.
        src: Forwarded to torch.distributed.broadcast.
        group: Forwarded to torch.distributed.broadcast.
        async_op: Forwarded to torch.distributed.broadcast.

    Returns:
        Output of torch.distributed.broadcast (the untraced original).
    """
    should_record = _should_rank_record_comm(group, root_rank=src)
    if should_record:
        # Message size in bytes: element count times per-element byte width.
        num_bytes = tensor.nelement() * tensor.element_size()
        _emit_call_description('broadcast', num_bytes, group, root_rank=src)
    return torch.distributed.untraced_broadcast(tensor, src, group, async_op)