def shunt_torch_communication()

in sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py [0:0]


def shunt_torch_communication():
  """Install the torch.distributed collective tracing shunts, if enabled.

  Calls `_identify_trace_mode()` and then inspects the module-level
  `_TRACE_MODE`. When the mode is 'none', nothing is shunted and the function
  returns early. Otherwise the communication objects and calls are shunted
  via `_shunt_torch_communication_objects()` and
  `_shunt_torch_communication_calls()`.

  Status messages are printed only by the process whose `RANK` environment
  variable is 0 (a missing `RANK` is treated as 0), so a multi-process job
  emits each message once.

  Raises:
    ValueError: if `RANK` is set to something `int()` cannot parse.
  """
  _identify_trace_mode()

  # Parse RANK once instead of repeating the lookup for every status block.
  is_rank_zero = int(os.environ.get("RANK", "0")) == 0

  if _TRACE_MODE == 'none':
    if is_rank_zero:
      print('Tracing torch.distributed collectives disabled.', flush=True)
    return

  _shunt_torch_communication_objects()
  _shunt_torch_communication_calls()

  if not is_rank_zero:
    return

  print('NVTX and print tracing of torch.distributed collectives enabled.',
        flush=True)
  # Flushed like every other status line so it is not left sitting in a
  # block-buffered stdout (e.g. when output is piped to a log collector).
  print(f"{_GPU_SERIAL=}, {_VM_ID=}", flush=True)

  if not _SHOULD_PRINT:
    print('Collectives are traced but will not be printed to stdout', flush=True)