in sample_workloads/megatron-gke/docker/monitor_collectives.py [0:0]
def traced_broadcast_object_list(object_list, src=0, group=None, device=None):
    """Traced stand-in for torch.distributed.broadcast_object_list.

    When `_should_rank_record_comm` approves recording for this rank, every
    entry of `object_list` is pickled to measure its serialized size, and the
    summed size is reported through `_emit_call_description`. The call is then
    forwarded unchanged to the untraced implementation.

    Args:
      object_list: Forwarded to torch.distributed.broadcast_object_list
      src: Forwarded to torch.distributed.broadcast_object_list; also used as
        the root rank for the recording decision and the emitted description
      group: Forwarded to torch.distributed.broadcast_object_list
      device: Forwarded to torch.distributed.broadcast_object_list

    Returns:
      Whatever the untraced torch.distributed.broadcast_object_list returns
    """
    if _should_rank_record_comm(group, root_rank=src):
        # Pickling here duplicates the serialization the underlying call
        # performs. That is accepted for now because this invocation is not
        # expected to sit on the critical path.
        total_pickled_bytes = sum(
            len(pickle.dumps(item)) for item in object_list)
        _emit_call_description(
            'broadcast_object_list', total_pickled_bytes, group,
            root_rank=src)

    return torch.distributed.untraced_broadcast_object_list(
        object_list, src, group, device)