`sh_allreduce_forward` — excerpt from `sparse_autoencoder/train.py`


    def sh_allreduce_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sum ``x`` across the shard communicator on the forward pass.

        The backward pass is the identity: gradients flow through unchanged.
        If no shard communicator is configured (``self.sh_comm is None``),
        ``x`` is returned untouched.

        NOTE(review): ``all_reduce`` is issued with ``async_op=True`` and the
        returned work handle is discarded — presumably ``sh_comm`` (or a later
        step) synchronizes before the tensor is consumed; confirm downstream.
        """
        if self.sh_comm is None:
            # No sharding configured — nothing to reduce.
            return x

        class _AllreduceFwdIdentityBwd(torch.autograd.Function):
            # Forward: in-place all-reduce over the shard group.
            # Backward: pass the gradient through unchanged.
            @staticmethod
            def forward(ctx, tensor):
                assert self.sh_comm is not None
                self.sh_comm.all_reduce(tensor, async_op=True)
                return tensor

            @staticmethod
            def backward(ctx, grad):
                return grad

        return _AllreduceFwdIdentityBwd.apply(x)  # type: ignore