in sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py [0:0]
def _shunt_torch_communication_calls():
    """Replaces torch.distributed.<target_collective> with a traced version.

    For every name in the target list, the callable currently bound at
    ``torch.distributed.<name>`` is saved as
    ``torch.distributed.untraced_<name>`` and ``torch.distributed.<name>`` is
    rebound to this module's ``traced_<name>`` attribute.

    The call is idempotent: a collective that is already bound to its traced
    replacement is skipped, so calling this function again cannot overwrite
    the saved original with the traced wrapper.

    Raises:
        AttributeError: If ``torch.distributed`` lacks one of the target
            collectives, or this module has no matching ``traced_<name>``
            attribute.
    """
    target_collectives = (
        'barrier',
        'broadcast_object_list',
        'broadcast',
        'gather',
        'scatter',
        'reduce',
        'reduce_scatter',
        'reduce_scatter_tensor',
        'all_reduce',
        'all_gather',
        'all_gather_into_tensor',
        'all_to_all',
        'all_to_all_single',
        'batch_isend_irecv',
        'isend',
        'irecv',
        'send',
        'recv',
    )
    this_module = sys.modules[__name__]
    for collective in target_collectives:
        original_fn = getattr(torch.distributed, collective)
        replaced_fn = getattr(this_module, 'traced_' + collective)
        if original_fn is replaced_fn:
            # Already shunted by an earlier call. Saving `original_fn` now
            # would store the traced wrapper under `untraced_<name>` and lose
            # the real implementation.
            # NOTE(review): presumably the traced wrappers dispatch through
            # `untraced_<name>`, in which case that would recurse forever --
            # confirm against the wrapper definitions.
            continue
        setattr(torch.distributed, 'untraced_' + collective, original_fn)
        setattr(torch.distributed, collective, replaced_fn)