def init_broadcast_()

in sparse_autoencoder/train.py [0:0]


    def init_broadcast_(self, autoencoder):
        """Synchronize the autoencoder's initial parameters across ranks.

        Up to two broadcasts are issued, each only when the matching
        communicator is configured (not ``None``):

        1. Over ``self.dp_comm``: each tensor from ``autoencoder.parameters()``
           (passed through ``maybe_transpose``) is broadcast from the rank
           that ``replica_shard_to_rank`` computes for replica 0 and this
           process's shard index ``self.sh_rank``.
        2. Over ``self.sh_comm``: ``autoencoder.pre_bias`` is broadcast from
           the rank computed for shard 0 of this process's replica
           (``self.dp_rank``).

        Args:
            autoencoder: object exposing ``parameters()`` and ``pre_bias``;
                presumably a torch module whose tensors are overwritten in
                place by the broadcast (trailing ``_`` naming) — confirm
                against the communicator implementation.
        """
        dp_comm = self.dp_comm
        if dp_comm is not None:
            for param in autoencoder.parameters():
                payload = maybe_transpose(param.data)
                src_rank = replica_shard_to_rank(
                    replica_idx=0,
                    shard_idx=self.sh_rank,
                    n_op_shards=self.n_op_shards,
                )
                dp_comm.broadcast(payload, src_rank)

        sh_comm = self.sh_comm
        if sh_comm is not None:
            # pre_bias is the same across all shards, so shard 0's copy is
            # sent to the other shards of this replica.
            bias = autoencoder.pre_bias.data
            src_rank = replica_shard_to_rank(
                replica_idx=self.dp_rank,
                shard_idx=0,
                n_op_shards=self.n_op_shards,
            )
            sh_comm.broadcast(bias, src_rank)