def sh_allreduce_backward()

in sparse_autoencoder/train.py


    def sh_allreduce_backward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity in the forward pass; all-reduces the gradient across the
        # shard-parallel communicator (self.sh_comm) in the backward pass.
        # No-op when the model is not sharded.
        if self.sh_comm is None:
            return x

        class AllreduceBackward(torch.autograd.Function):
            @staticmethod
            def forward(ctx, input):
                # Forward is the identity; nothing needs to be saved for backward.
                return input

            @staticmethod
            def backward(ctx, grad_output):
                # Clone so the in-place all_reduce does not mutate a tensor
                # owned by autograd.
                grad_output = grad_output.clone()
                assert self.sh_comm is not None
                # Sum the gradient over all shard ranks; launched asynchronously
                # so it can overlap with the rest of the backward pass.
                self.sh_comm.all_reduce(grad_output, async_op=True)
                return grad_output

        return AllreduceBackward.apply(x)  # type: ignore
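
A minimal, self-contained sketch of the same pattern (identity forward, gradient
all-reduce in backward), runnable in a single process. The FakeComm class below is
a hypothetical stand-in for the shard-parallel communicator and is not part of the
repository; it only assumes the communicator exposes an
all_reduce(tensor, async_op=...) method, and it doubles the tensor in place so the
backward-pass reduction is visible without a real distributed setup.

import torch


class FakeComm:
    # Hypothetical stand-in for the shard-parallel communicator. It mimics the
    # assumed all_reduce(tensor, async_op=...) interface by doubling the tensor
    # in place, as if two ranks contributed identical gradients.
    def all_reduce(self, tensor: torch.Tensor, async_op: bool = False):
        tensor.mul_(2.0)
        return None


def allreduce_backward(comm, x: torch.Tensor) -> torch.Tensor:
    # Identity forward; all-reduce the incoming gradient in backward.
    class AllreduceBackward(torch.autograd.Function):
        @staticmethod
        def forward(ctx, input):
            return input

        @staticmethod
        def backward(ctx, grad_output):
            grad_output = grad_output.clone()
            comm.all_reduce(grad_output, async_op=True)
            return grad_output

    return AllreduceBackward.apply(x)


if __name__ == "__main__":
    comm = FakeComm()
    x = torch.ones(3, requires_grad=True)
    y = allreduce_backward(comm, x).sum()
    y.backward()
    # Gradient of sum() is all ones; the fake all-reduce doubled it in backward.
    print(x.grad)  # tensor([2., 2., 2.])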