def dp_allreduce_()

in sparse_autoencoder/train.py [0:0]


    def dp_allreduce_(self, autoencoder) -> None:
        """All-reduce gradients and dead-neuron statistics over ``self.dp_comm``.

        Every parameter of ``autoencoder`` whose ``grad`` is not ``None`` has
        that gradient reduced with ``ReduceOp.AVG``, and
        ``autoencoder.stats_last_nonzero`` is reduced with ``ReduceOp.MIN``.
        The method is a no-op when ``self.dp_comm`` is ``None``.

        All collectives are launched with ``async_op=True`` so they can be in
        flight together, and every returned work handle is waited on before
        this method returns. Without the wait, nothing orders the reductions
        before the caller's next use of the tensors (e.g. an optimizer step).

        Args:
            autoencoder: object exposing ``parameters()`` and a
                ``stats_last_nonzero`` attribute.
        """
        if self.dp_comm is None:
            return

        pending = []

        for param in autoencoder.parameters():
            if param.grad is None:
                continue
            # NOTE(review): maybe_transpose is defined elsewhere; presumably it
            # returns a view sharing storage with param.grad so the reduction
            # updates the gradient itself -- confirm.
            pending.append(
                self.dp_comm.all_reduce(
                    maybe_transpose(param.grad), op=ReduceOp.AVG, async_op=True
                )
            )

        # make sure statistics for dead neurons are correct
        pending.append(
            self.dp_comm.all_reduce(  # type: ignore
                autoencoder.stats_last_nonzero, op=ReduceOp.MIN, async_op=True
            )
        )

        # Assumes dp_comm.all_reduce forwards to torch.distributed.all_reduce,
        # which returns a work handle for async ops, or None when this rank is
        # not part of the group -- hence the None check.
        for work in pending:
            if work is not None:
                work.wait()