in sparse_autoencoder/train.py [0:0]
def init_broadcast_(self, autoencoder):
    """Broadcast freshly initialized weights so every rank starts identical.

    Within each data-parallel group, replica 0 is the source of truth for
    every parameter; within each op-shard group, shard 0 then broadcasts
    ``pre_bias`` (the original comment notes it is the same across shards).
    No-ops when the corresponding communicator is absent.
    """
    if self.dp_comm is not None:
        for param in autoencoder.parameters():
            # Source rank: replica 0 of this process's own shard.
            source_rank = replica_shard_to_rank(
                replica_idx=0,
                shard_idx=self.sh_rank,
                n_op_shards=self.n_op_shards,
            )
            self.dp_comm.broadcast(maybe_transpose(param.data), source_rank)
    if self.sh_comm is not None:
        # pre_bias is the same across all shards, so shard 0 of this
        # process's own replica broadcasts it.
        source_rank = replica_shard_to_rank(
            replica_idx=self.dp_rank,
            shard_idx=0,
            n_op_shards=self.n_op_shards,
        )
        self.sh_comm.broadcast(autoencoder.pre_bias.data, source_rank)