def broadcast_object()

in training/utils/distributed.py
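
The function depends on three module-level helpers defined elsewhere in training/utils/distributed.py: _PRIMARY_RANK, get_rank, and broadcast. Their real definitions are not part of this excerpt; the following is a minimal sketch of plausible stand-ins, assuming a torch.distributed process group with a CPU-capable backend such as gloo (with NCCL, tensors would first have to be moved to the current CUDA device):

import torch
import torch.distributed as dist

_PRIMARY_RANK = 0  # assumption: rank 0 acts as the primary worker


def get_rank() -> int:
    # Rank of the current process within the default process group.
    return dist.get_rank()


def broadcast(tensor: torch.Tensor, src: int = _PRIMARY_RANK) -> torch.Tensor:
    # torch.distributed.broadcast mutates `tensor` in place on every rank;
    # returning it lets callers write `t = broadcast(t, src=src)`.
    dist.broadcast(tensor, src=src)
    return tensor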


from typing import Any
import io
import tempfile

import torch


def broadcast_object(obj: Any, src: int = _PRIMARY_RANK, use_disk: bool = True) -> Any:
    """Broadcast an object from a source to all workers.

    Args:
        obj: Object to broadcast; must be serializable with torch.save
        src: Source rank for the broadcast (defaults to the primary rank)
        use_disk: If True, receiving ranks stage the serialized bytes in a
            temporary file on disk rather than an in-memory buffer, avoiding
            a redundant CPU copy of the data during deserialization
    """
    # Broadcast from the primary rank by default, or from whichever rank
    # `src` names
    if get_rank() == src:
        # Emit data: serialize the object, then broadcast its byte length
        # followed by the raw bytes so receivers know how much to allocate
        buffer = io.BytesIO()
        torch.save(obj, buffer)
        data_view = buffer.getbuffer()
        length_tensor = torch.LongTensor([len(data_view)])
        length_tensor = broadcast(length_tensor, src=src)
        data_tensor = torch.ByteTensor(data_view)
        data_tensor = broadcast(data_tensor, src=src)
    else:
        # Fetch from the source: receive the payload length first, then
        # allocate a byte buffer of that size to receive the data into
        length_tensor = torch.LongTensor([0])
        length_tensor = broadcast(length_tensor, src=src)
        data_tensor = torch.empty([length_tensor.item()], dtype=torch.uint8)
        data_tensor = broadcast(data_tensor, src=src)
        if use_disk:
            with tempfile.TemporaryFile("r+b") as f:
                f.write(data_tensor.numpy())
                # drop our reference to the data tensor so Python can
                # garbage-collect it before deserialization
                del data_tensor
                f.seek(0)
                obj = torch.load(f)
        else:
            buffer = io.BytesIO(data_tensor.numpy())
            obj = torch.load(buffer)
    return obj
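
For context, a minimal usage sketch; the names and values here are illustrative, and a process group is assumed to have been initialized beforehand:

import torch.distributed as dist

dist.init_process_group(backend="gloo")  # assumption: CPU-friendly backend

# Every rank must call broadcast_object; non-source ranks pass a placeholder,
# since their `obj` argument is ignored.
if dist.get_rank() == 0:
    state = {"epoch": 12, "lr": 0.01}  # built only on the source rank
else:
    state = None

state = broadcast_object(state, src=0)  # now identical on every rank

Passing use_disk=False trades the temporary file for an io.BytesIO buffer, which is faster but briefly holds two CPU copies of the serialized bytes on each receiving rank.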