in contrib/torch_utils.py [0:0]
def handle_torch_Index(the_class):
    """Replace the tensor-facing methods of ``the_class`` with torch-aware versions.

    Every replacement keeps the numpy behaviour: when handed an
    ``np.ndarray`` it forwards to the matching ``<name>_numpy`` method.
    Otherwise the inputs must be ``torch.Tensor`` objects; their raw
    pointers are obtained with the ``swig_ptr_from_*Tensor`` helpers and
    passed to the ``<name>_c`` SWIG method.

    Output tensors the caller does not supply are allocated on the device
    of the input data. ``reconstruct``/``reconstruct_n`` have no data input,
    so they allocate on the index's own device. ``range_search`` is CPU only.

    CUDA tensors are only accepted by indexes exposing ``getDevice``; for
    those the C call runs inside ``using_stream(index.getResources())``.
    """

    def _call_c(index, on_gpu, c_method, *args):
        # Invoke a SWIG ``*_c`` method with ``args``. If the tensor data is on
        # a CUDA device, the index must be a GPU index (it has getDevice) and
        # the call is wrapped in that index's stream context. CPU data is
        # passed straight through.
        if not on_gpu:
            # CPU torch
            return c_method(*args)
        assert hasattr(index, 'getDevice'), 'GPU tensor on CPU index not allowed'
        # On the GPU, use proper stream ordering
        with using_stream(index.getResources()):
            return c_method(*args)

    def _default_device(index):
        # Device for freshly allocated outputs when there is no input tensor
        # to take it from: the index's CUDA device if it has one, else CPU.
        if hasattr(index, 'getDevice'):
            # same device as the index
            return torch.device('cuda', index.getDevice())
        return torch.device('cpu')

    def torch_replacement_add(self, x):
        """Add the ``(n, d)`` tensor ``x`` to the index via ``add_c``."""
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.add_numpy(x)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        _call_c(self, x.is_cuda, self.add_c, n, x_ptr)

    def torch_replacement_add_with_ids(self, x, ids):
        """Add ``x`` (n, d) together with the ``(n,)`` tensor ``ids`` via ``add_with_ids_c``."""
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.add_with_ids_numpy(x, ids)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        assert type(ids) is torch.Tensor
        assert ids.shape == (n, ), 'not same number of vectors as ids'
        ids_ptr = swig_ptr_from_IndicesTensor(ids)

        # The device of x (not of ids) selects the CPU or GPU call path.
        _call_c(self, x.is_cuda, self.add_with_ids_c, n, x_ptr, ids_ptr)

    def torch_replacement_assign(self, x, k, labels=None):
        """Compute ``(n, k)`` int64 labels for ``x`` via ``assign_c`` and return them.

        A supplied ``labels`` must be an ``(n, k)`` tensor and is filled in
        place; otherwise it is allocated on ``x.device``.
        """
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.assign_numpy(x, k, labels)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if labels is None:
            labels = torch.empty(n, k, device=x.device, dtype=torch.int64)
        else:
            assert type(labels) is torch.Tensor
            assert labels.shape == (n, k)
        L_ptr = swig_ptr_from_IndicesTensor(labels)

        _call_c(self, x.is_cuda, self.assign_c, n, x_ptr, L_ptr, k)
        return labels

    def torch_replacement_train(self, x):
        """Train the index on the ``(n, d)`` tensor ``x`` via ``train_c``."""
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.train_numpy(x)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        _call_c(self, x.is_cuda, self.train_c, n, x_ptr)

    def torch_replacement_search(self, x, k, D=None, I=None):
        """Run ``search_c`` for ``x`` with ``k`` results and return ``(D, I)``.

        A supplied ``D``/``I`` must be an ``(n, k)`` tensor and is filled in
        place. Missing ones are allocated on ``x.device`` as float32 (D) and
        int64 (I).
        """
        if type(x) is np.ndarray:
            # forward to faiss __init__.py base method
            return self.search_numpy(x, k, D, I)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if D is None:
            D = torch.empty(n, k, device=x.device, dtype=torch.float32)
        else:
            assert type(D) is torch.Tensor
            assert D.shape == (n, k)
        D_ptr = swig_ptr_from_FloatTensor(D)

        if I is None:
            I = torch.empty(n, k, device=x.device, dtype=torch.int64)
        else:
            assert type(I) is torch.Tensor
            assert I.shape == (n, k)
        I_ptr = swig_ptr_from_IndicesTensor(I)

        _call_c(self, x.is_cuda, self.search_c, n, x_ptr, k, D_ptr, I_ptr)
        return D, I

    def torch_replacement_search_and_reconstruct(self, x, k, D=None, I=None, R=None):
        """Run ``search_and_reconstruct_c`` for ``x`` and return ``(D, I, R)``.

        ``D`` and ``I`` are ``(n, k)`` and ``R`` is ``(n, k, d)``. Supplied
        tensors are filled in place. Missing ones are allocated on
        ``x.device`` (float32 for D and R, int64 for I).
        """
        if type(x) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.search_and_reconstruct_numpy(x, k, D, I, R)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        if D is None:
            D = torch.empty(n, k, device=x.device, dtype=torch.float32)
        else:
            assert type(D) is torch.Tensor
            assert D.shape == (n, k)
        D_ptr = swig_ptr_from_FloatTensor(D)

        if I is None:
            I = torch.empty(n, k, device=x.device, dtype=torch.int64)
        else:
            assert type(I) is torch.Tensor
            assert I.shape == (n, k)
        I_ptr = swig_ptr_from_IndicesTensor(I)

        if R is None:
            R = torch.empty(n, k, d, device=x.device, dtype=torch.float32)
        else:
            assert type(R) is torch.Tensor
            assert R.shape == (n, k, d)
        R_ptr = swig_ptr_from_FloatTensor(R)

        _call_c(self, x.is_cuda, self.search_and_reconstruct_c,
                n, x_ptr, k, D_ptr, I_ptr, R_ptr)
        return D, I, R

    def torch_replacement_remove_ids(self, x):
        """Forward to ``remove_ids_numpy``; torch tensors are rejected."""
        # Not yet implemented
        assert type(x) is not torch.Tensor, 'remove_ids not yet implemented for torch'
        return self.remove_ids_numpy(x)

    def torch_replacement_reconstruct(self, key, x=None):
        """Fill a ``(d,)`` float32 tensor via ``reconstruct_c(key, ...)`` and return it.

        No tensor input is required. Because this module has been imported,
        the default output is a torch tensor on the index's device. Passing
        a numpy ``x`` opts back into the numpy implementation.
        """
        if x is not None and type(x) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.reconstruct_numpy(key, x)

        if x is None:
            x = torch.empty(self.d, device=_default_device(self), dtype=torch.float32)
        else:
            assert type(x) is torch.Tensor
            assert x.shape == (self.d, )
        x_ptr = swig_ptr_from_FloatTensor(x)

        _call_c(self, x.is_cuda, self.reconstruct_c, key, x_ptr)
        return x

    def torch_replacement_reconstruct_n(self, n0, ni, x=None):
        """Fill an ``(ni, d)`` float32 tensor via ``reconstruct_n_c(n0, ni, ...)``.

        As with ``reconstruct``, the default output is a torch tensor on the
        index's device, and a numpy ``x`` selects the numpy implementation.
        """
        if x is not None and type(x) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.reconstruct_n_numpy(n0, ni, x)

        if x is None:
            x = torch.empty(ni, self.d, device=_default_device(self), dtype=torch.float32)
        else:
            assert type(x) is torch.Tensor
            assert x.shape == (ni, self.d)
        x_ptr = swig_ptr_from_FloatTensor(x)

        _call_c(self, x.is_cuda, self.reconstruct_n_c, n0, ni, x_ptr)
        return x

    def torch_replacement_update_vectors(self, keys, x):
        """Call ``update_vectors_c`` with ``(n,)`` ``keys`` and ``(n, d)`` ``x``."""
        if type(keys) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.update_vectors_numpy(keys, x)

        assert type(keys) is torch.Tensor
        (n, ) = keys.shape
        keys_ptr = swig_ptr_from_IndicesTensor(keys)

        assert type(x) is torch.Tensor
        assert x.shape == (n, self.d)
        x_ptr = swig_ptr_from_FloatTensor(x)

        # The device of x (not of keys) selects the CPU or GPU call path.
        _call_c(self, x.is_cuda, self.update_vectors_c, n, keys_ptr, x_ptr)

    # Until the GPU version is implemented, we do not support pre-allocated
    # output buffers
    def torch_replacement_range_search(self, x, thresh):
        """Run ``range_search_c`` for ``x`` (CPU only) and return ``(lims, D, I)``.

        ``lims`` is an int64 tensor of length ``n + 1``. ``D`` and ``I`` hold
        ``lims[-1]`` entries copied out of the ``RangeSearchResult``.
        """
        if type(x) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.range_search_numpy(x, thresh)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        assert not x.is_cuda, 'Range search using GPU tensor not yet implemented'
        assert not hasattr(self, 'getDevice'), 'Range search on GPU index not yet implemented'

        res = faiss.RangeSearchResult(n)
        self.range_search_c(n, x_ptr, thresh, res)

        # get pointers and copy them
        # FIXME: no rev_swig_ptr equivalent for torch.Tensor, just convert
        # np to torch
        # NOTE: torch does not support np.uint64, just np.int64
        lims = torch.from_numpy(faiss.rev_swig_ptr(res.lims, n + 1).copy().astype('int64'))
        nd = int(lims[-1])
        D = torch.from_numpy(faiss.rev_swig_ptr(res.distances, nd).copy())
        I = torch.from_numpy(faiss.rev_swig_ptr(res.labels, nd).copy())
        return lims, D, I

    def torch_replacement_sa_encode(self, x, codes=None):
        """Encode ``x`` (n, d) into ``(n, sa_code_size())`` codes via ``sa_encode_c``.

        A supplied ``codes`` tensor is filled in place. Otherwise a uint8
        buffer is allocated on ``x.device``, like the outputs of ``search``.
        """
        if type(x) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.sa_encode_numpy(x, codes)

        assert type(x) is torch.Tensor
        n, d = x.shape
        assert d == self.d
        x_ptr = swig_ptr_from_FloatTensor(x)

        code_size = self.sa_code_size()
        if codes is None:
            # Allocate next to the input. A CPU-only buffer would hand a host
            # pointer to the GPU call path whenever x is a CUDA tensor.
            codes = torch.empty(n, code_size, device=x.device, dtype=torch.uint8)
        else:
            assert type(codes) is torch.Tensor
            assert codes.shape == (n, code_size)
        codes_ptr = swig_ptr_from_UInt8Tensor(codes)

        _call_c(self, x.is_cuda, self.sa_encode_c, n, x_ptr, codes_ptr)
        return codes

    def torch_replacement_sa_decode(self, codes, x=None):
        """Decode ``(n, sa_code_size())`` ``codes`` into ``(n, d)`` vectors via ``sa_decode_c``.

        A supplied ``x`` is filled in place. Otherwise a float32 tensor is
        allocated on ``codes.device``.
        """
        if type(codes) is np.ndarray:
            # Forward to faiss __init__.py base method
            return self.sa_decode_numpy(codes, x)

        assert type(codes) is torch.Tensor
        n, cs = codes.shape
        assert cs == self.sa_code_size()
        codes_ptr = swig_ptr_from_UInt8Tensor(codes)

        if x is None:
            # Same device as the codes, so the GPU call path never receives
            # a host output pointer.
            x = torch.empty(n, self.d, device=codes.device, dtype=torch.float32)
        else:
            assert type(x) is torch.Tensor
            assert x.shape == (n, self.d)
        x_ptr = swig_ptr_from_FloatTensor(x)

        _call_c(self, codes.is_cuda, self.sa_decode_c, n, codes_ptr, x_ptr)
        return x

    torch_replace_method(the_class, 'add', torch_replacement_add)
    torch_replace_method(the_class, 'add_with_ids', torch_replacement_add_with_ids)
    torch_replace_method(the_class, 'assign', torch_replacement_assign)
    torch_replace_method(the_class, 'train', torch_replacement_train)
    torch_replace_method(the_class, 'search', torch_replacement_search)
    torch_replace_method(the_class, 'remove_ids', torch_replacement_remove_ids)
    torch_replace_method(the_class, 'reconstruct', torch_replacement_reconstruct)
    torch_replace_method(the_class, 'reconstruct_n', torch_replacement_reconstruct_n)
    torch_replace_method(the_class, 'range_search', torch_replacement_range_search)
    torch_replace_method(the_class, 'update_vectors', torch_replacement_update_vectors,
                         ignore_missing=True)
    torch_replace_method(the_class, 'search_and_reconstruct',
                         torch_replacement_search_and_reconstruct, ignore_missing=True)
    torch_replace_method(the_class, 'sa_encode', torch_replacement_sa_encode)
    torch_replace_method(the_class, 'sa_decode', torch_replacement_sa_decode)