in sample_workloads/lit-gpt-demo/utilities/monitor_collectives.py [0:0]
def shunt_torch_communication():
    """Set up tracing of torch.distributed collectives unless it is disabled.

    `_identify_trace_mode()` is called first. If `_TRACE_MODE` is then
    'none', nothing is shunted: rank 0 prints a "disabled" notice and the
    function returns. Otherwise `_shunt_torch_communication_objects()` and
    `_shunt_torch_communication_calls()` are invoked, after which rank 0
    prints a short status report.

    The rank is read from the `RANK` environment variable; when the variable
    is absent the process is treated as rank 0.
    """

    def _on_rank_zero():
        # Looked up at each call site rather than cached, so the environment
        # is consulted at the moment the message is about to be printed.
        return int(os.environ.get("RANK", "0")) == 0

    # Presumably populates the module-level _TRACE_MODE -- confirm in helper.
    _identify_trace_mode()

    if _TRACE_MODE == 'none':
        if _on_rank_zero():
            print('Tracing torch.distributed collectives disabled.', flush=True)
        return

    _shunt_torch_communication_objects()
    _shunt_torch_communication_calls()

    # NOTE(review): every status line below is assumed to be rank-0 only --
    # confirm against the intended output on the other ranks.
    if not _on_rank_zero():
        return

    print('NVTX and print tracing of torch.distributed collectives enabled.',
          flush=True)
    print(f"{_GPU_SERIAL=}, {_VM_ID=}")
    if not _SHOULD_PRINT:
        print('Collectives are traced but will not be printed to stdout',
              flush=True)