def _sh_comm_op()

in sparse_autoencoder/train.py [0:0]


    def _sh_comm_op(self, x, op):
        if isinstance(x, (float, int)):
            x = torch.tensor(x, device="cuda")

        if not x.is_cuda:
            x = x.cuda()

        if self.sh_comm is None:
            return x

        out = x.clone()
        self.sh_comm.all_reduce(x, op=op, async_op=True)
        return out