def traced_broadcast_object_list()

in sample_workloads/megatron-gke/docker/monitor_collectives.py [0:0]


def traced_broadcast_object_list(object_list, src=0, group=None, device=None):
  """Tracing wrapper around torch.distributed.broadcast_object_list.

  When `_should_rank_record_comm` approves recording for this rank, every
  entry of `object_list` is pickled, the serialized byte counts are summed,
  and the total is reported through `_emit_call_description` under the name
  'broadcast_object_list'. The call is then delegated, with all arguments
  unchanged, to torch.distributed.untraced_broadcast_object_list.

  Args:
    object_list: Objects to broadcast; forwarded to the untraced call and,
      when recording, pickled here to measure the message size.
    src: Root rank of the broadcast; forwarded to the untraced call and used
      as `root_rank` for the recording decision and the emitted description.
    group: Process group; forwarded to the untraced call and to the
      recording helpers.
    device: Forwarded to the untraced call only.

  Returns:
    Whatever torch.distributed.untraced_broadcast_object_list returns.
  """
  if _should_rank_record_comm(group, root_rank=src):
    # Pickling here only measures the payload; per the original author's
    # note the underlying call serializes the objects again. That duplicated
    # work is accepted because this path is not expected to be critical.
    payload_bytes = sum(len(pickle.dumps(item)) for item in object_list)
    _emit_call_description(
        'broadcast_object_list', payload_bytes, group, root_rank=src)

  return torch.distributed.untraced_broadcast_object_list(
      object_list, src, group, device)