in src/utils.py [0:0]
def get_knn_faiss(xq, xb, k, split_base=-1, distance='dot_product'):
    """Return the ``k`` best matches in ``xb`` for every query in ``xq``.

    When ``split_base`` is -1 the search is delegated to ``__get_knn_faiss``
    in a single call. Otherwise ``xb`` is split along dim 0 into
    ``split_base`` chunks with ``torch.chunk``, each chunk is searched on its
    own, and the per-chunk candidates are merged into one top-``k`` per query.

    Args:
        xq: query vectors (presumably one query per row -- TODO confirm
            against ``__get_knn_faiss``).
        xb: database vectors; rows are split into chunks when
            ``split_base != -1``.
        k: number of neighbours to return per query.
        split_base: number of chunks to split ``xb`` into, or -1 for no split.
        distance: ``'dot_product'`` ranks larger scores first; any other value
            ranks smaller scores first (e.g. L2 distances).

    Returns:
        ``(D, I)``: scores and row indices into ``xb``. In the split path
        there is one row per query, ordered best match first. In the unsplit
        path, this is whatever ``__get_knn_faiss`` returns.
    """
    if split_base == -1:
        return __get_knn_faiss(xq, xb, k, distance)

    distances, indices = [], []
    offset = 0
    for base in torch.chunk(xb, split_base):
        D, I = __get_knn_faiss(xq, base, k, distance)
        # Convert chunk-local ids into row ids of the full xb. Negative ids are
        # kept as-is. If __get_knn_faiss follows the FAISS convention of padding
        # with -1 when a chunk holds fewer than k vectors, shifting the sentinel
        # would turn "no result" into a valid row of the previous chunk.
        I = torch.where(I >= 0, I + offset, I)
        offset += base.size(0)
        distances.append(D)
        indices.append(I)
    distances = torch.cat(distances, dim=1)
    indices = torch.cat(indices, dim=1)

    # Dot products are similarities (higher is better); anything else is
    # treated as a distance (lower is better). Negating similarities lets one
    # ascending sort serve both cases without mutating `distances`.
    keys = -distances if distance == 'dot_product' else distances
    order = keys.argsort(dim=1)[:, :k]
    # gather keeps the row/column pairing on the same device as `order`.
    D = torch.gather(distances, 1, order)
    I = torch.gather(indices, 1, order)
    return D, I