def _shunt_torch_communication_calls()

in sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py [0:0]


def _shunt_torch_communication_calls():
  """Replaces torch.distributed.<target_collective> with a traced version.

  For every name in `target_collectives`:
    * the current `torch.distributed.<name>` is saved as
      `torch.distributed.untraced_<name>`, and
    * `torch.distributed.<name>` is rebound to this module's
      `traced_<name>` function.

  Calling this function more than once is safe. A collective that already
  points at its traced replacement is skipped, so `untraced_<name>` keeps
  referring to the original torch implementation. It is never overwritten
  with the traced wrapper itself.

  Raises:
    AttributeError: if `torch.distributed` lacks one of the target
      collectives, or this module does not define the matching
      `traced_<name>` function. Collectives processed before the failing
      one remain patched.
  """
  target_collectives = (
      'barrier',
      'broadcast_object_list',
      'broadcast',
      'gather',
      'scatter',
      'reduce',
      'reduce_scatter',
      'reduce_scatter_tensor',
      'all_reduce',
      'all_gather',
      'all_gather_into_tensor',
      'all_to_all',
      'all_to_all_single',
      'batch_isend_irecv',
      'isend',
      'irecv',
      'send',
      'recv',
  )

  this_module = sys.modules[__name__]
  for collective in target_collectives:
    original_fn = getattr(torch.distributed, collective)
    replaced_fn = getattr(this_module, 'traced_' + collective)
    # Already shunted (e.g. this function was called a second time).
    # Re-saving here would make untraced_<name> point at the traced wrapper
    # and drop the real torch implementation.
    # NOTE(review): if traced_<name> delegates to untraced_<name>, as the
    # naming suggests, that would recurse forever.
    if original_fn is replaced_fn:
      continue
    setattr(torch.distributed, 'untraced_' + collective, original_fn)
    setattr(torch.distributed, collective, replaced_fn)