# def handle_torch_Index()
#
# in contrib/torch_utils.py [0:0]


# Install torch-aware versions of the Index methods on `the_class`.
#
# Each nested `torch_replacement_<name>` function below forwards numpy input
# to the matching `<name>_numpy` method; for torch tensors most of them call
# the matching `<name>_c` method on raw pointers (remove_ids only accepts
# non-torch input). The replacements are installed with torch_replace_method
# at the end of this function.
def handle_torch_Index(the_class):
    def torch_replacement_add(self, x):
        """Add the vectors in ``x`` to the index.

        A numpy array is handed to ``add_numpy``. Otherwise ``x`` must be a
        2-D ``torch.Tensor`` whose second dimension equals ``self.d``; a
        CUDA tensor additionally requires an index exposing ``getDevice``.
        """
        if type(x) is np.ndarray:
            # numpy input: defer to the faiss __init__.py base method
            return self.add_numpy(x)

        assert type(x) is torch.Tensor
        num_vecs, dim = x.shape
        assert dim == self.d
        vecs_ptr = swig_ptr_from_FloatTensor(x)

        if not x.is_cuda:
            # CPU torch tensor: call add_c directly
            self.add_c(num_vecs, vecs_ptr)
            return

        assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

        # GPU tensor: run the call with proper stream ordering
        with using_stream(self.getResources()):
            self.add_c(num_vecs, vecs_ptr)

    def torch_replacement_add_with_ids(self, x, ids):
        """Add the vectors in ``x`` to the index together with ``ids``.

        A numpy ``x`` is handed, with ``ids``, to ``add_with_ids_numpy``.
        Otherwise ``x`` must be a 2-D ``torch.Tensor`` whose second dimension
        equals ``self.d`` and ``ids`` a ``torch.Tensor`` of shape ``(n,)``,
        where ``n`` is the number of rows of ``x``.
        """
        if type(x) is np.ndarray:
            # numpy input: defer to the faiss __init__.py base method
            return self.add_with_ids_numpy(x, ids)

        assert type(x) is torch.Tensor
        num_vecs, dim = x.shape
        assert dim == self.d
        vecs_ptr = swig_ptr_from_FloatTensor(x)

        assert type(ids) is torch.Tensor
        assert ids.shape == (num_vecs, ), 'not same number of vectors as ids'
        labels_ptr = swig_ptr_from_IndicesTensor(ids)

        if not x.is_cuda:
            # CPU torch tensor: call add_with_ids_c directly
            self.add_with_ids_c(num_vecs, vecs_ptr, labels_ptr)
            return

        assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

        # GPU tensor: run the call with proper stream ordering
        with using_stream(self.getResources()):
            self.add_with_ids_c(num_vecs, vecs_ptr, labels_ptr)

    def torch_replacement_assign(self, x, k, labels=None):
        """Run ``assign`` on the rows of ``x`` and return the label tensor.

        A numpy ``x`` is handed to ``assign_numpy``. Otherwise ``x`` must be
        a 2-D ``torch.Tensor`` whose second dimension equals ``self.d``.
        ``labels`` may be a pre-allocated ``torch.Tensor`` of shape
        ``(n, k)``; when omitted, an int64 tensor of that shape is created on
        the device of ``x``. The (possibly newly created) ``labels`` tensor
        is returned.
        """
        if type(x) is np.ndarray:
            # numpy input: defer to the faiss __init__.py base method
            return self.assign_numpy(x, k, labels)

        assert type(x) is torch.Tensor
        num_queries, dim = x.shape
        assert dim == self.d
        queries_ptr = swig_ptr_from_FloatTensor(x)

        if labels is not None:
            # caller-provided output buffer: validate it
            assert type(labels) is torch.Tensor
            assert labels.shape == (num_queries, k)
        else:
            labels = torch.empty(num_queries, k, device=x.device, dtype=torch.int64)
        labels_ptr = swig_ptr_from_IndicesTensor(labels)

        if not x.is_cuda:
            # CPU torch tensor: call assign_c directly
            self.assign_c(num_queries, queries_ptr, labels_ptr, k)
            return labels

        assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

        # GPU tensor: run the call with proper stream ordering
        with using_stream(self.getResources()):
            self.assign_c(num_queries, queries_ptr, labels_ptr, k)

        return labels

    def torch_replacement_train(self, x):
        """Train the index on the vectors in ``x``.

        A numpy array is handed to ``train_numpy``. Otherwise ``x`` must be a
        2-D ``torch.Tensor`` whose second dimension equals ``self.d``; a
        CUDA tensor additionally requires an index exposing ``getDevice``.
        """
        if type(x) is np.ndarray:
            # numpy input: defer to the faiss __init__.py base method
            return self.train_numpy(x)

        assert type(x) is torch.Tensor
        num_vecs, dim = x.shape
        assert dim == self.d
        vecs_ptr = swig_ptr_from_FloatTensor(x)

        if not x.is_cuda:
            # CPU torch tensor: call train_c directly
            self.train_c(num_vecs, vecs_ptr)
            return

        assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

        # GPU tensor: run the call with proper stream ordering
        with using_stream(self.getResources()):
            self.train_c(num_vecs, vecs_ptr)

    def torch_replacement_search(self, x, k, D=None, I=None):
        """Search the index with the rows of ``x`` and return ``(D, I)``.

        A numpy ``x`` is handed to ``search_numpy``. Otherwise ``x`` must be
        a 2-D ``torch.Tensor`` whose second dimension equals ``self.d``.
        ``D`` and ``I`` may be pre-allocated ``torch.Tensor`` buffers of shape
        ``(n, k)``; any that is omitted is created on the device of ``x``
        (float32 for ``D``, int64 for ``I``).
        """
        if type(x) is np.ndarray:
            # numpy input: defer to the faiss __init__.py base method
            return self.search_numpy(x, k, D, I)

        assert type(x) is torch.Tensor
        num_queries, dim = x.shape
        assert dim == self.d
        queries_ptr = swig_ptr_from_FloatTensor(x)

        if D is not None:
            # caller-provided output buffer: validate it
            assert type(D) is torch.Tensor
            assert D.shape == (num_queries, k)
        else:
            D = torch.empty(num_queries, k, device=x.device, dtype=torch.float32)
        dist_ptr = swig_ptr_from_FloatTensor(D)

        if I is not None:
            # caller-provided output buffer: validate it
            assert type(I) is torch.Tensor
            assert I.shape == (num_queries, k)
        else:
            I = torch.empty(num_queries, k, device=x.device, dtype=torch.int64)
        idx_ptr = swig_ptr_from_IndicesTensor(I)

        if not x.is_cuda:
            # CPU torch tensor: call search_c directly
            self.search_c(num_queries, queries_ptr, k, dist_ptr, idx_ptr)
            return D, I

        assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

        # GPU tensor: run the call with proper stream ordering
        with using_stream(self.getResources()):
            self.search_c(num_queries, queries_ptr, k, dist_ptr, idx_ptr)

        return D, I

    def torch_replacement_search_and_reconstruct(self, x, k, D=None, I=None, R=None):
        """Search with the rows of ``x`` and return ``(D, I, R)``.

        A numpy ``x`` is handed to ``search_and_reconstruct_numpy``.
        Otherwise ``x`` must be a 2-D ``torch.Tensor`` whose second dimension
        ``d`` equals ``self.d``. ``D``, ``I`` and ``R`` may be pre-allocated
        ``torch.Tensor`` buffers of shapes ``(n, k)``, ``(n, k)`` and
        ``(n, k, d)``; any that is omitted is created on the device of ``x``
        (float32 for ``D`` and ``R``, int64 for ``I``).
        """
        if type(x) is np.ndarray:
            # numpy input: defer to the faiss __init__.py base method
            return self.search_and_reconstruct_numpy(x, k, D, I, R)

        assert type(x) is torch.Tensor
        num_queries, dim = x.shape
        assert dim == self.d
        queries_ptr = swig_ptr_from_FloatTensor(x)

        if D is not None:
            # caller-provided output buffer: validate it
            assert type(D) is torch.Tensor
            assert D.shape == (num_queries, k)
        else:
            D = torch.empty(num_queries, k, device=x.device, dtype=torch.float32)
        dist_ptr = swig_ptr_from_FloatTensor(D)

        if I is not None:
            # caller-provided output buffer: validate it
            assert type(I) is torch.Tensor
            assert I.shape == (num_queries, k)
        else:
            I = torch.empty(num_queries, k, device=x.device, dtype=torch.int64)
        idx_ptr = swig_ptr_from_IndicesTensor(I)

        if R is not None:
            # caller-provided output buffer: validate it
            assert type(R) is torch.Tensor
            assert R.shape == (num_queries, k, dim)
        else:
            R = torch.empty(num_queries, k, dim, device=x.device, dtype=torch.float32)
        recons_ptr = swig_ptr_from_FloatTensor(R)

        if not x.is_cuda:
            # CPU torch tensor: call search_and_reconstruct_c directly
            self.search_and_reconstruct_c(
                num_queries, queries_ptr, k, dist_ptr, idx_ptr, recons_ptr)
            return D, I, R

        assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

        # GPU tensor: run the call with proper stream ordering
        with using_stream(self.getResources()):
            self.search_and_reconstruct_c(
                num_queries, queries_ptr, k, dist_ptr, idx_ptr, recons_ptr)

        return D, I, R

    def torch_replacement_remove_ids(self, x):
        """Forward ``x`` to ``remove_ids_numpy``.

        There is no torch implementation yet, so a ``torch.Tensor`` argument
        is rejected by an assertion.
        """
        is_torch_input = type(x) is torch.Tensor
        assert not is_torch_input, 'remove_ids not yet implemented for torch'
        return self.remove_ids_numpy(x)

    def torch_replacement_reconstruct(self, key, x=None):
        """Reconstruct the vector stored under ``key`` and return it.

        If ``x`` is a numpy array the call is handed to
        ``reconstruct_numpy``. Otherwise ``x`` may be a pre-allocated
        ``torch.Tensor`` of shape ``(self.d,)``; when omitted, a float32
        tensor of that shape is created on the CPU, or on the index's CUDA
        device if the index exposes ``getDevice``.
        """
        # This module defaults to torch outputs since no tensor input is
        # required here; a numpy buffer means the caller overrides that
        # default (type(None) is never np.ndarray, so None falls through).
        if type(x) is np.ndarray:
            # defer to the faiss __init__.py base method
            return self.reconstruct_numpy(key, x)

        # Default output device: same device as the index when it is a GPU
        # index, the CPU otherwise
        if hasattr(self, 'getDevice'):
            out_device = torch.device('cuda', self.getDevice())
        else:
            out_device = torch.device('cpu')

        if x is not None:
            # caller-provided output buffer: validate it
            assert type(x) is torch.Tensor
            assert x.shape == (self.d, )
        else:
            x = torch.empty(self.d, device=out_device, dtype=torch.float32)
        out_ptr = swig_ptr_from_FloatTensor(x)

        if not x.is_cuda:
            # CPU torch tensor: call reconstruct_c directly
            self.reconstruct_c(key, out_ptr)
            return x

        assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

        # GPU tensor: run the call with proper stream ordering
        with using_stream(self.getResources()):
            self.reconstruct_c(key, out_ptr)

        return x

    def torch_replacement_reconstruct_n(self, n0, ni, x=None):
        """Run ``reconstruct_n`` for ``(n0, ni)`` and return the output tensor.

        If ``x`` is a numpy array the call is handed to
        ``reconstruct_n_numpy``. Otherwise ``x`` may be a pre-allocated
        ``torch.Tensor`` of shape ``(ni, self.d)``; when omitted, a float32
        tensor of that shape is created on the CPU, or on the index's CUDA
        device if the index exposes ``getDevice``.
        """
        # This module defaults to torch outputs since no tensor input is
        # required here; a numpy buffer means the caller overrides that
        # default (type(None) is never np.ndarray, so None falls through).
        if type(x) is np.ndarray:
            # defer to the faiss __init__.py base method
            return self.reconstruct_n_numpy(n0, ni, x)

        # Default output device: same device as the index when it is a GPU
        # index, the CPU otherwise
        if hasattr(self, 'getDevice'):
            out_device = torch.device('cuda', self.getDevice())
        else:
            out_device = torch.device('cpu')

        if x is not None:
            # caller-provided output buffer: validate it
            assert type(x) is torch.Tensor
            assert x.shape == (ni, self.d)
        else:
            x = torch.empty(ni, self.d, device=out_device, dtype=torch.float32)
        out_ptr = swig_ptr_from_FloatTensor(x)

        if not x.is_cuda:
            # CPU torch tensor: call reconstruct_n_c directly
            self.reconstruct_n_c(n0, ni, out_ptr)
            return x

        assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

        # GPU tensor: run the call with proper stream ordering
        with using_stream(self.getResources()):
            self.reconstruct_n_c(n0, ni, out_ptr)

        return x

    def torch_replacement_update_vectors(self, keys, x):
        """Call ``update_vectors`` for the entries ``keys`` with the rows of ``x``.

        Numpy ``keys`` are handed, with ``x``, to ``update_vectors_numpy``.
        Otherwise ``keys`` must be a 1-D ``torch.Tensor`` of length ``n`` and
        ``x`` a ``torch.Tensor`` of shape ``(n, self.d)``. The CPU/GPU path is
        chosen from ``x.is_cuda``.
        """
        if type(keys) is np.ndarray:
            # numpy input: defer to the faiss __init__.py base method
            return self.update_vectors_numpy(keys, x)

        assert type(keys) is torch.Tensor
        (num_keys, ) = keys.shape
        keys_ptr = swig_ptr_from_IndicesTensor(keys)

        assert type(x) is torch.Tensor
        assert x.shape == (num_keys, self.d)
        vecs_ptr = swig_ptr_from_FloatTensor(x)

        if not x.is_cuda:
            # CPU torch tensor: call update_vectors_c directly
            self.update_vectors_c(num_keys, keys_ptr, vecs_ptr)
            return

        assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

        # GPU tensor: run the call with proper stream ordering
        with using_stream(self.getResources()):
            self.update_vectors_c(num_keys, keys_ptr, vecs_ptr)

    # Until the GPU version is implemented, we do not support pre-allocated
    # output buffers
    def torch_replacement_range_search(self, x, thresh):
        """Range-search the index with the rows of ``x``; return ``(lims, D, I)``.

        A numpy ``x`` is handed to ``range_search_numpy``. Otherwise ``x``
        must be a 2-D CPU ``torch.Tensor`` whose second dimension equals
        ``self.d``, and the index must not expose ``getDevice``. The result
        buffers are copied into new torch tensors: ``lims`` (int64, length
        ``n + 1``) plus ``D`` and ``I`` built from the result's distances and
        labels, each of length ``lims[-1]``.
        """
        if type(x) is np.ndarray:
            # numpy input: defer to the faiss __init__.py base method
            return self.range_search_numpy(x, thresh)

        assert type(x) is torch.Tensor
        num_queries, dim = x.shape
        assert dim == self.d
        queries_ptr = swig_ptr_from_FloatTensor(x)

        assert not x.is_cuda, 'Range search using GPU tensor not yet implemented'
        assert not hasattr(self, 'getDevice'), 'Range search on GPU index not yet implemented'

        result = faiss.RangeSearchResult(num_queries)
        self.range_search_c(num_queries, queries_ptr, thresh, result)

        # Copy the result buffers out of `result`.
        # FIXME: no rev_swig_ptr equivalent for torch.Tensor, so go through
        # numpy and convert to torch
        # NOTE: torch does not support np.uint64, just np.int64
        lims_np = faiss.rev_swig_ptr(result.lims, num_queries + 1).copy().astype('int64')
        lims = torch.from_numpy(lims_np)
        num_results = int(lims[-1])
        dist_np = faiss.rev_swig_ptr(result.distances, num_results).copy()
        D = torch.from_numpy(dist_np)
        labels_np = faiss.rev_swig_ptr(result.labels, num_results).copy()
        I = torch.from_numpy(labels_np)

        return lims, D, I

    def torch_replacement_sa_encode(self, x, codes=None):
        """Encode the rows of ``x`` with ``sa_encode`` and return the codes.

        A numpy ``x`` is handed to ``sa_encode_numpy``. Otherwise ``x`` must
        be a 2-D ``torch.Tensor`` whose second dimension equals ``self.d``.
        ``codes`` may be a pre-allocated ``torch.Tensor`` of shape
        ``(n, self.sa_code_size())``; when omitted, a uint8 tensor of that
        shape is created (no explicit device is requested). The (possibly
        newly created) ``codes`` tensor is returned.

        Raises AssertionError when a validation step fails, including when a
        caller-supplied ``codes`` is not a ``torch.Tensor``.
        """
        if type(x) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.sa_encode_numpy(x, codes)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if codes is None:
            codes = torch.empty(n, self.sa_code_size(), dtype=torch.uint8)
        else:
            # Validate the caller's buffer like the other replacements do for
            # their output buffers; isinstance (not an exact type check) so
            # tensor subclasses keep being accepted.
            assert isinstance(codes, torch.Tensor), 'codes must be a torch.Tensor'
            assert codes.shape == (n, self.sa_code_size())
        codes_ptr = swig_ptr_from_UInt8Tensor(codes)

        if x.is_cuda:
            assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

            # On the GPU, use proper stream ordering
            with using_stream(self.getResources()):
                self.sa_encode_c(n, x_ptr, codes_ptr)
        else:
            # CPU torch
            self.sa_encode_c(n, x_ptr, codes_ptr)

        return codes

    def torch_replacement_sa_decode(self, codes, x=None):
        """Decode ``codes`` with ``sa_decode`` and return the output tensor.

        Numpy ``codes`` are handed to ``sa_decode_numpy``. Otherwise ``codes``
        must be a 2-D ``torch.Tensor`` whose second dimension equals
        ``self.sa_code_size()``. ``x`` may be a pre-allocated ``torch.Tensor``
        of shape ``(n, self.d)``; when omitted, a float32 tensor of that shape
        is created (no explicit device is requested). The CPU/GPU path is
        chosen from ``codes.is_cuda``.
        """
        if type(codes) is np.ndarray:
            # numpy input: defer to the faiss __init__.py base method
            return self.sa_decode_numpy(codes, x)

        assert type(codes) is torch.Tensor
        num_codes, code_size = codes.shape
        assert code_size == self.sa_code_size()
        codes_ptr = swig_ptr_from_UInt8Tensor(codes)

        if x is not None:
            # caller-provided output buffer: validate it
            assert type(x) is torch.Tensor
            assert x.shape == (num_codes, self.d)
        else:
            x = torch.empty(num_codes, self.d, dtype=torch.float32)
        out_ptr = swig_ptr_from_FloatTensor(x)

        if not codes.is_cuda:
            # CPU torch tensor: call sa_decode_c directly
            self.sa_decode_c(num_codes, codes_ptr, out_ptr)
            return x

        assert hasattr(self, 'getDevice'), 'GPU tensor on CPU index not allowed'

        # GPU tensor: run the call with proper stream ordering
        with using_stream(self.getResources()):
            self.sa_decode_c(num_codes, codes_ptr, out_ptr)

        return x


    # (method name, torch-aware replacement, extra keyword arguments for
    # torch_replace_method), listed in installation order. Only
    # update_vectors and search_and_reconstruct are installed with
    # ignore_missing=True.
    replacements = (
        ('add', torch_replacement_add, {}),
        ('add_with_ids', torch_replacement_add_with_ids, {}),
        ('assign', torch_replacement_assign, {}),
        ('train', torch_replacement_train, {}),
        ('search', torch_replacement_search, {}),
        ('remove_ids', torch_replacement_remove_ids, {}),
        ('reconstruct', torch_replacement_reconstruct, {}),
        ('reconstruct_n', torch_replacement_reconstruct_n, {}),
        ('range_search', torch_replacement_range_search, {}),
        ('update_vectors', torch_replacement_update_vectors,
         {'ignore_missing': True}),
        ('search_and_reconstruct', torch_replacement_search_and_reconstruct,
         {'ignore_missing': True}),
        ('sa_encode', torch_replacement_sa_encode, {}),
        ('sa_decode', torch_replacement_sa_decode, {}),
    )
    for method_name, replacement, extra_kwargs in replacements:
        torch_replace_method(the_class, method_name, replacement, **extra_kwargs)