in sample_workloads/megatron-gke/docker/monitor_collectives.py [0:0]
def traced_barrier(group=None, async_op=False, device_ids=None):
    """Stand-in for torch.distributed.barrier that logs the call first.

    The call is logged only when this rank is selected to record
    communications. After that, the arguments are handed unchanged and in the
    same order to the original (untraced) barrier.

    Args:
      group: Forwarded to torch.distributed.barrier; also used to decide
        whether this rank records the call.
      async_op: Forwarded to torch.distributed.barrier.
      device_ids: Forwarded to torch.distributed.barrier.

    Returns:
      Whatever torch.distributed.untraced_barrier returns for these arguments.
    """
    record_this_rank = _should_rank_record_comm(group)
    if record_this_rank:
        # A barrier moves no user data. NOTE(review): message_size=1 looks like
        # a nominal placeholder size; confirm what the collector expects.
        _emit_call_description('barrier', message_size=1, group=group)
    # Arguments are forwarded positionally, exactly as the caller supplied
    # them. This presumes untraced_barrier keeps barrier's parameter order;
    # verify where it is installed.
    forwarded = (group, async_op, device_ids)
    return torch.distributed.untraced_barrier(*forwarded)