in sparse_autoencoder/train.py [0:0]
def sh_allreduce_scale(self, scaler):
    """Make a grad scaler's loss-scale state agree across the shard communicator.

    Reduces ``scaler._scale`` and ``scaler._growth_tracker`` with
    ``ReduceOp.MIN`` over ``self.sh_comm``. Every participating rank then
    holds the smallest (most conservative) value.

    Does nothing when ``self.sh_comm`` is ``None``. It also does nothing when
    ``scaler`` has no ``_scale`` attribute, or ``_scale`` is ``None``.

    NOTE(review): assumes ``scaler`` is a torch GradScaler-like object and
    that ``sh_comm.all_reduce`` reduces its tensor in place and returns an
    async work handle (or ``None``) like ``torch.distributed.all_reduce``.
    Confirm against the Comm wrapper.

    Args:
        scaler: Object carrying private ``_scale`` / ``_growth_tracker``
            tensors.

    Returns:
        None.
    """
    if self.sh_comm is None:
        return
    if not (hasattr(scaler, "_scale") and scaler._scale is not None):
        return

    # Launch both collectives before waiting on either, so they may overlap.
    works = [
        self.sh_comm.all_reduce(scaler._scale, op=ReduceOp.MIN, async_op=True),
        self.sh_comm.all_reduce(
            scaler._growth_tracker, op=ReduceOp.MIN, async_op=True
        ),
    ]
    # The original discarded these handles, so the reduced values were not
    # guaranteed to be complete (or stream-ordered) before the caller next
    # read the scaler state. Wait on each one; tolerate wrappers that
    # return no handle.
    for work in works:
        if work is not None:
            work.wait()