in src/utils.py [0:0]
def get_nn_avg_dist(emb, query, knn):
    """
    Compute, for every query vector, the average score of its `knn`
    nearest neighbors among the embeddings.

    The "distance" is an inner product (larger means closer): the Faiss
    path uses a flat inner-product index and the fallback uses a matrix
    product followed by a top-k over the largest values.
    Use Faiss if available, otherwise fall back to batched torch matmuls.

    Args:
        emb: 2-D torch tensor whose rows are the candidate embeddings.
        query: 2-D torch tensor whose rows are the query vectors; its
            second dimension must match that of `emb`.
        knn: number of nearest neighbors to average over.

    Returns:
        A 1-D numpy array with one averaged score per row of `query`.
    """
    if FAISS_AVAILABLE:
        emb = emb.cpu().numpy()
        query = query.cpu().numpy()
        if hasattr(faiss, 'StandardGpuResources'):
            # gpu mode: the index is always placed on device 0
            res = faiss.StandardGpuResources()
            config = faiss.GpuIndexFlatConfig()
            config.device = 0
            index = faiss.GpuIndexFlatIP(res, emb.shape[1], config)
        else:
            # cpu mode
            index = faiss.IndexFlatIP(emb.shape[1])
        index.add(emb)
        distances, _ = index.search(query, knn)
        return distances.mean(1)

    # Pure torch fallback.
    if query.shape[0] == 0:
        # With no query rows the batching loop below never runs and
        # torch.cat() would raise on an empty list, so answer directly
        # with an empty 1-D result of the query's dtype.
        return query.new_empty(0).cpu().numpy()

    bs = 1024  # queries scored per matmul, bounds the size of the score matrix
    all_distances = []
    # Transpose once up front so every batch is a plain (bs, dim) x (dim, n) product.
    emb = emb.transpose(0, 1).contiguous()
    for i in range(0, query.shape[0], bs):
        distances = query[i:i + bs].mm(emb)
        best_distances, _ = distances.topk(knn, dim=1, largest=True, sorted=True)
        all_distances.append(best_distances.mean(1).cpu())
    return torch.cat(all_distances).numpy()