in sparse_autoencoder/train.py [0:0]
def all_broadcast(self, x: torch.Tensor) -> torch.Tensor:
    """Broadcast ``x`` from global rank 0 to every rank across both parallelism axes."""
    # Step 1: broadcast across data-parallel replicas. Each rank receives the
    # tensor from the rank in replica 0 that holds the same shard index.
    if self.dp_comm is not None:
        self.dp_comm.broadcast(
            x,
            replica_shard_to_rank(
                replica_idx=0,
                shard_idx=self.sh_rank,
                n_op_shards=self.n_op_shards,
            ),
        )
    # Step 2: broadcast across op-shards. Each rank receives the tensor from
    # shard 0 of its own replica, which after step 1 already holds replica 0's
    # copy -- so every rank ends up with the (replica 0, shard 0) tensor.
    if self.sh_comm is not None:
        self.sh_comm.broadcast(
            x,
            replica_shard_to_rank(
                replica_idx=self.dp_rank,
                shard_idx=0,
                n_op_shards=self.n_op_shards,
            ),
        )
    return x
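
For context, `replica_shard_to_rank` maps a (replica, shard) coordinate to a flat global rank; its definition is not shown in this excerpt. A minimal sketch, assuming the row-major layout implied by the call sites above (each replica owning a contiguous block of `n_op_shards` ranks):

def replica_shard_to_rank(replica_idx: int, shard_idx: int, n_op_shards: int) -> int:
    # Assumed layout: replica r occupies ranks [r * n_op_shards, (r + 1) * n_op_shards),
    # consistent with how all_broadcast picks its source ranks above.
    return replica_idx * n_op_shards + shard_idx

Under this layout, the data-parallel broadcast sources from rank `self.sh_rank` (replica 0's copy of the local shard) and the shard broadcast sources from rank `self.dp_rank * n_op_shards` (the local replica's shard 0), which is how the two calls together propagate global rank 0's tensor to all ranks.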