in sparse_autoencoder/train.py [0:0]
def dp_allreduce_(self, autoencoder) -> None:
    """Issue the data-parallel all-reduces for ``autoencoder`` in place.

    Does nothing when ``self.dp_comm`` is ``None``. Otherwise:

    * for every parameter whose ``.grad`` is set, the gradient (passed
      through ``maybe_transpose``) is all-reduced with ``ReduceOp.AVG``;
    * ``autoencoder.stats_last_nonzero`` is then all-reduced with
      ``ReduceOp.MIN``.

    Every collective is launched with ``async_op=True`` and whatever
    ``all_reduce`` returns is discarded here.
    NOTE(review): presumably the caller (or the communicator) synchronizes
    before the reduced tensors are consumed -- confirm.
    """
    comm = self.dp_comm
    if comm is None:
        return

    # Parameters are visited in ``parameters()`` order; ones without a
    # gradient are skipped.
    for weight in autoencoder.parameters():
        grad = weight.grad
        if grad is None:
            continue
        comm.all_reduce(maybe_transpose(grad), op=ReduceOp.AVG, async_op=True)

    # make sure statistics for dead neurons are correct
    comm.all_reduce(  # type: ignore
        autoencoder.stats_last_nonzero,
        op=ReduceOp.MIN,
        async_op=True,
    )