in training/utils/distributed.py [0:0]
import io
import tempfile
from typing import Any

import torch

# Depends on this module's get_rank() / broadcast() helpers and the
# _PRIMARY_RANK constant, defined elsewhere in the file.


def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any:
    """Broadcast an object from a source rank to all workers.

    Args:
        obj: Object to broadcast; must be serializable with ``torch.save``
        src: Source rank for the broadcast (defaults to the primary rank)
        use_disk: If enabled, avoids redundant CPU memory copies by staging
            the serialized data on disk rather than in an in-memory buffer
    """
    # Either broadcast from the primary rank to the fleet (default),
    # or treat the rank given as ``src`` as the origin
    if get_rank() == src:
        # Sender: serialize the object, broadcast its length, then its bytes
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data_view = buffer.getbuffer()
        length_tensor = torch.LongTensor([len(data_view)])
        length_tensor = broadcast(length_tensor, src=src)
        data_tensor = torch.ByteTensor(data_view)
        data_tensor = broadcast(data_tensor, src=src)
    else:
        # Receiver: learn the payload size first, then receive the bytes into
        # a pre-allocated buffer
        length_tensor = torch.LongTensor([0])
        length_tensor = broadcast(length_tensor, src=src)
        data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8)
        data_tensor = broadcast(data_tensor, src=src)
        if use_disk:
            # Stage the received bytes in a temporary file so the tensor's
            # memory can be released before deserialization, avoiding a
            # redundant in-memory copy of the serialized data
            with tempfile.TemporaryFile("r+b") as f:
                f.write(data_tensor.numpy())
                # Drop the reference to the data tensor so that Python can
                # garbage-collect it before torch.load materializes the object
                del data_tensor
                f.seek(0)
                obj = torch.load(f)
        else:
            buffer = io.BytesIO(data_tensor.numpy())
            obj = torch.load(buffer)
    return obj
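
For reference, a minimal usage sketch: rank 0 constructs a config dict and every other rank receives an identical copy. This assumes torch.distributed has already been initialized (e.g. via init_process_group) and that the primary rank is 0; the sync_config helper and the config contents are illustrative, not part of this module.

import torch.distributed as dist

from training.utils.distributed import broadcast_object

def sync_config() -> dict:
    # Only the source rank needs the real object; other ranks pass None and
    # get the deserialized copy back from broadcast_object
    cfg = {"lr": 0.1, "epochs": 90} if dist.get_rank() == 0 else None
    return broadcast_object(cfg, src=0)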