def get_knn_faiss()

in src/utils.py [0:0]


def get_knn_faiss(xq, xb, k, split_base=-1, distance='dot_product'):
    """Return the ``k`` nearest neighbours in ``xb`` of every query in ``xq``.

    The search itself is delegated to ``__get_knn_faiss``. When ``split_base``
    is not -1, ``xb`` is cut into ``split_base`` chunks along dim 0 with
    ``torch.chunk``. Each chunk is searched separately, and the per-chunk
    candidates are merged into a single top-``k`` result per query.

    Args:
        xq: query vectors, one query per row (assumed 2-D tensor -- TODO
            confirm against ``__get_knn_faiss``).
        xb: base vectors, chunked along dim 0 when ``split_base != -1``.
        k: number of neighbours to return per query.
        split_base: number of chunks to split ``xb`` into, or -1 to search
            the whole base in one call.
        distance: ``'dot_product'`` (larger score is closer). Any other value
            is treated as a distance where smaller is closer (e.g. L2).

    Returns:
        ``(D, I)``: scores and row indices into ``xb``, one row per query,
        ordered best first. In the split path, at most ``k`` columns are
        returned (fewer if the chunks yielded fewer candidates).
    """
    if split_base == -1:
        return __get_knn_faiss(xq, xb, k, distance)

    distances, indices = [], []
    offset = 0
    for base in torch.chunk(xb, split_base):
        D, I = __get_knn_faiss(xq, base, k, distance)
        # Shift chunk-local ids to global row ids in ``xb`` without mutating
        # the helper's output. Negative ids are left as-is: if the helper
        # follows the faiss convention of padding missing results with -1
        # (chunk smaller than k), ``-1 + offset`` would otherwise masquerade
        # as a real (wrong) neighbour.
        I = torch.where(I >= 0, I + offset, I)
        offset += base.size(0)

        distances.append(D)
        indices.append(I)

    distances = torch.cat(distances, dim=1)
    indices = torch.cat(indices, dim=1)

    # Dot products rank best-first by largest value, L2-style distances by
    # smallest. topk selects the best k candidates directly (sorted) instead
    # of fully sorting all of them.
    k = min(k, distances.size(1))
    D, order = distances.topk(k, dim=1, largest=(distance == 'dot_product'))
    # gather keeps the lookup on the same device as ``indices``.
    I = indices.gather(1, order)

    return D, I