in sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py [0:0]
def traced_all_gather_into_tensor(
    output_tensor, input_tensor, group=None, async_op=False):
    """Intercepts invocations of torch.distributed.all_gather_into_tensor.

    Modeled on `traced_all_gather`. If this rank should record collectives
    for `group`, a call description tagged 'all_gather' is emitted. Its
    message size is the byte size of `output_tensor`, computed as
    element_size() * nelement().

    Whether or not anything is recorded, the call is then forwarded
    positionally to `torch.distributed.untraced_all_gather_into_tensor`.
    That function's return value is passed back unchanged.
    """
    if _should_rank_record_comm(group):
        # The size is measured on the output buffer, not the input shard.
        gathered_bytes = output_tensor.element_size() * output_tensor.nelement()
        _emit_call_description('all_gather', gathered_bytes, group)
    untraced_impl = torch.distributed.untraced_all_gather_into_tensor
    return untraced_impl(output_tensor, input_tensor, group, async_op)