sparse_autoencoder/train.py [103:112]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                    replica_shard_to_rank(
                        replica_idx=0,
                        shard_idx=self.sh_rank,
                        n_op_shards=self.n_op_shards,
                    ),
                )
        
        if self.sh_comm is not None:
            # pre_bias is the same across all shards
            self.sh_comm.broadcast(
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



sparse_autoencoder/train.py [163:171]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                replica_shard_to_rank(
                    replica_idx=0,
                    shard_idx=self.sh_rank,
                    n_op_shards=self.n_op_shards,
                ),
            )

        if self.sh_comm is not None:
            self.sh_comm.broadcast(
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



