in sparse_autoencoder/train.py [0:0]
def sh_allreduce_forward(self, x: torch.Tensor) -> torch.Tensor:
    """All-reduce ``x`` over ``self.sh_comm`` in the forward pass only.

    When ``self.sh_comm`` is ``None`` the input is returned untouched.
    Otherwise ``x`` is handed to ``self.sh_comm.all_reduce`` and passed
    through a custom autograd function whose backward forwards the incoming
    gradient unchanged (no collective is issued in the backward pass).

    Args:
        x: Tensor passed to ``self.sh_comm.all_reduce``.

    Returns:
        ``x`` itself when there is no communicator, otherwise the output of
        the autograd function applied to ``x``.
    """
    # Bind once so the nested function closes over a value already known to
    # be non-None instead of re-reading (and re-asserting) the attribute.
    comm = self.sh_comm
    if comm is None:
        return x

    class AllreduceForward(torch.autograd.Function):
        @staticmethod
        def forward(ctx, tensor):
            # The collective is launched asynchronously; the handle must be
            # waited on before the tensor is handed to downstream consumers,
            # otherwise they may read it before the reduction has finished.
            # Assumes `all_reduce` follows torch.distributed.all_reduce's
            # convention (work handle, or None when there is nothing to
            # wait for) -- TODO confirm against the communicator wrapper.
            work = comm.all_reduce(tensor, async_op=True)
            if work is not None:
                work.wait()
            # NOTE(review): the reduction presumably writes into `tensor`
            # in place (the result is read back from the same object), yet
            # ctx.mark_dirty(tensor) is not called -- confirm this is
            # intentional before changing it, since it affects callers.
            return tensor

        @staticmethod
        def backward(ctx, grad_output):
            # Gradient passes straight through.
            return grad_output

    return AllreduceForward.apply(x)  # type: ignore