in sparse_autoencoder/train.py [0:0]
def sh_allreduce_backward(self, x: torch.Tensor) -> torch.Tensor:
    """Pass ``x`` through unchanged, all-reducing its gradient over ``self.sh_comm``.

    Forward is the identity. In backward, the incoming gradient is cloned and
    all-reduced through ``self.sh_comm.all_reduce``. The reduction op is
    whatever that method uses by default (presumably a sum across the shard
    group — confirm against the communicator class).

    When ``self.sh_comm`` is ``None``, ``x`` is returned as-is and no autograd
    node is inserted.

    Args:
        x: Tensor to pass through.

    Returns:
        ``x`` itself when there is no shard communicator; otherwise a tensor
        equal to ``x`` whose gradient is all-reduced during backward.
    """
    # Capture the communicator once, when the graph is built, so backward
    # does not depend on re-reading ``self`` later.
    comm = self.sh_comm
    if comm is None:
        return x

    class AllreduceBackward(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            return input

        @staticmethod
        def backward(ctx, grad_output):
            # Clone before the all-reduce writes into the tensor, so the
            # incoming gradient object itself is not modified.
            grad_output = grad_output.clone()
            assert comm is not None  # type narrowing for the checker
            work = comm.all_reduce(grad_output, async_op=True)
            # An async collective's output may only be consumed after its work
            # handle has been waited on. Without this, downstream autograd can
            # read the gradient before the reduction has completed.
            if work is not None:
                work.wait()
            return grad_output

    return AllreduceBackward.apply(x)  # type: ignore