def all_broadcast()

in sparse_autoencoder/train.py [0:0]


    def all_broadcast(self, x: torch.Tensor) -> torch.Tensor:
        """Broadcast ``x`` across both communicators and return it.

        Two broadcasts are issued, in this order:

        1. Over ``self.dp_comm``. The source is the rank that
           ``replica_shard_to_rank`` maps to replica 0 and this process's
           shard index (``self.sh_rank``).
        2. Over ``self.sh_comm``. The source is the rank mapped to this
           process's replica index (``self.dp_rank``) and shard 0.

        Each step is skipped when its communicator is ``None``. The
        ``broadcast`` results are discarded and the very same ``x`` object is
        returned. The broadcasts therefore presumably write into ``x`` in place
        (NOTE(review): confirm against the communicator implementation).

        Args:
            x: Tensor to synchronize.

        Returns:
            ``x`` (same object as the argument).
        """

        def _src_rank(replica, shard):
            # Translate a (replica, shard) grid coordinate into the rank that
            # the communicators expect as the broadcast source.
            return replica_shard_to_rank(
                replica_idx=replica,
                shard_idx=shard,
                n_op_shards=self.n_op_shards,
            )

        # Step 1: replica 0 of this shard is the source for the dp group
        # (presumably the data-parallel group, judging by the name).
        if self.dp_comm is not None:
            self.dp_comm.broadcast(x, _src_rank(0, self.sh_rank))

        # Step 2: shard 0 of this replica is the source for the sh group
        # (presumably the sharding group). This runs after step 1, so it
        # forwards whatever step 1 produced.
        if self.sh_comm is not None:
            self.sh_comm.broadcast(x, _src_rank(self.dp_rank, 0))

        return x