in crypten/cuda/cuda_tensor.py [0:0]
def __torch_function__(self, func, types, args=(), kwargs=None):
    """Dispatch a torch function call that involves ``CUDALongTensor`` operands.

    If ``func`` has an entry in ``HANDLED_FUNCTIONS`` and every type in
    ``types`` is a subclass of ``torch.Tensor`` or ``CUDALongTensor``, the
    registered implementation is called with the arguments unchanged.

    Otherwise the call falls back to ``func`` itself. Every positional and
    keyword argument exposing a ``tensor()`` method is first replaced by the
    result of that method. Tensor results are then re-wrapped as
    ``CUDALongTensor``.

    Args:
        func: The torch function being invoked.
        types: Types of the overloaded arguments, as passed in by torch's
            ``__torch_function__`` protocol.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call; ``None`` is treated as no
            keyword arguments.

    Returns:
        The registered handler's result. In the fallback case, the result of
        ``func`` with these wrapped as ``CUDALongTensor``:

        - a top-level tensor result;
        - tensors directly inside a returned list or tuple.

        A tuple result is rebuilt as a plain ``tuple``. Any other result is
        returned as-is.
    """
    if kwargs is None:
        kwargs = {}

    # A dedicated implementation exists and all overloaded types are ones it
    # knows how to handle: delegate directly.
    if func in HANDLED_FUNCTIONS and all(
        issubclass(t, (torch.Tensor, CUDALongTensor)) for t in types
    ):
        return HANDLED_FUNCTIONS[func](*args, **kwargs)

    # Fallback: strip wrappers (anything exposing ``tensor()``) before calling
    # ``func``. Keyword arguments are unwrapped too, so tensors passed by
    # keyword (e.g. ``out=``) reach ``func`` in the same form as positional
    # ones instead of still wrapped.
    args = [a.tensor() if hasattr(a, "tensor") else a for a in args]
    kwargs = {
        k: v.tensor() if hasattr(v, "tensor") else v for k, v in kwargs.items()
    }
    result = func(*args, **kwargs)

    # Re-wrap tensor outputs so callers keep working with CUDALongTensor.
    if torch.is_tensor(result):
        return CUDALongTensor(result)
    if isinstance(result, list):
        return [CUDALongTensor(t) if torch.is_tensor(t) else t for t in result]
    if isinstance(result, tuple):
        return tuple(
            CUDALongTensor(t) if torch.is_tensor(t) else t for t in result
        )
    return result