def __torch_function__()

in crypten/cuda/cuda_tensor.py [0:0]


    def __torch_function__(self, func, types, args=(), kwargs=None):
        """Dispatch a torch function call that involves a ``CUDALongTensor``.

        If ``func`` has an override registered in ``HANDLED_FUNCTIONS`` and
        every type in ``types`` is a subclass of ``torch.Tensor`` or
        ``CUDALongTensor``, the override is called with the original
        arguments.

        Otherwise, ``func`` is called directly. Before the call, every
        positional and keyword argument exposing a ``tensor`` attribute is
        replaced by ``arg.tensor()``. A tensor result is wrapped in
        ``CUDALongTensor``. So is each tensor element directly inside a
        returned list or tuple; a returned tuple comes back as a plain
        ``tuple``.

        Args:
            func: The torch function being invoked.
            types: Types of the arguments that implement ``__torch_function__``.
            args: Positional arguments for ``func``.
            kwargs: Keyword arguments for ``func``; ``None`` means no kwargs.

        Returns:
            The override's result, or ``func``'s result with tensors wrapped
            as described above.
        """
        if kwargs is None:
            kwargs = {}
        if func in HANDLED_FUNCTIONS and all(
            issubclass(t, (torch.Tensor, CUDALongTensor)) for t in types
        ):
            return HANDLED_FUNCTIONS[func](*args, **kwargs)

        def _unwrap(obj):
            # Objects exposing ``tensor`` are presumably wrappers such as
            # CUDALongTensor; hand the underlying tensor to the native op.
            return obj.tensor() if hasattr(obj, "tensor") else obj

        def _wrap(obj):
            return CUDALongTensor(obj) if torch.is_tensor(obj) else obj

        args = [_unwrap(arg) for arg in args]
        # Keyword arguments (e.g. ``out=``) need the same unwrapping.
        # Otherwise a wrapper passed by keyword reaches the native op and
        # would likely be dispatched back here instead of being executed.
        kwargs = {key: _unwrap(value) for key, value in kwargs.items()}
        result = func(*args, **kwargs)

        if torch.is_tensor(result):
            return CUDALongTensor(result)
        if isinstance(result, list):
            return [_wrap(item) for item in result]
        if isinstance(result, tuple):
            return tuple(_wrap(item) for item in result)
        return result