in sparse_autoencoder/train.py [0:0]
def _sh_comm_op(self, x, op):
    """Apply the reduction ``op`` to ``x`` via ``self.sh_comm`` and return the result.

    Args:
        x: A Python ``float``/``int`` or a tensor. Scalars are wrapped in a
            CUDA tensor and non-CUDA tensors are moved to the GPU first.
        op: Reduction operator, forwarded unchanged as the ``op`` keyword of
            ``self.sh_comm.all_reduce``.

    Returns:
        The CUDA form of ``x`` itself when ``self.sh_comm`` is ``None``;
        otherwise a new tensor holding the reduced value. A caller-supplied
        tensor is never modified in place.
    """
    if isinstance(x, (float, int)):
        x = torch.tensor(x, device="cuda")
    if not x.is_cuda:
        x = x.cuda()

    # No communicator configured: there is nothing to reduce across.
    if self.sh_comm is None:
        return x

    # Reduce a private copy so the caller's tensor keeps its local value.
    # The collective must target the tensor that is returned; reducing ``x``
    # while returning the clone would hand back the un-reduced input.
    out = x.clone()
    work = self.sh_comm.all_reduce(out, op=op, async_op=True)

    # NOTE(review): assumes ``all_reduce`` returns a torch.distributed-style
    # work handle when ``async_op=True`` -- confirm against the communicator
    # class. Waiting orders the collective before any later use of ``out``;
    # the guard keeps this a no-op if no handle is returned.
    wait = getattr(work, "wait", None)
    if callable(wait):
        wait()
    return out