in src/text_clustering.py [0:0]
def infer(self, texts, top_k=1):
    """Label each text by majority vote over its nearest indexed neighbours.

    The texts are embedded with ``self.embed``, the ``top_k`` nearest
    entries are looked up in ``self.faiss_index``, and each text receives
    the most frequent ``self.cluster_labels`` value among those entries.
    On a tie, the label seen first in the neighbour list (i.e. the one
    belonging to the higher-ranked neighbour) wins.

    Args:
        texts: Inputs accepted by ``self.embed``.
        top_k: Number of neighbours that vote for each text.

    Returns:
        A tuple ``(inferred_labels, embeddings)``: a list holding one label
        per text, and the embeddings returned by ``self.embed``.

    Note:
        A text with no valid neighbour at all gets the label ``-1``.
        NOTE(review): ``-1`` is assumed to be an acceptable "unassigned"
        marker for callers (presumably it matches the clustering noise
        label) -- confirm.
    """
    embeddings = self.embed(texts)
    # Only the neighbour ids are needed; the distances are ignored.
    _, neighbours = self.faiss_index.search(embeddings, top_k)

    inferred_labels = []
    for neighbour_ids in tqdm(neighbours):
        # FAISS pads a result row with -1 when it finds fewer than top_k
        # neighbours. Indexing cluster_labels with -1 would silently pull
        # in the *last* document's label, so padded slots must not vote.
        votes = Counter(
            self.cluster_labels[doc] for doc in neighbour_ids if doc >= 0
        )
        inferred_labels.append(votes.most_common(1)[0][0] if votes else -1)
    return inferred_labels, embeddings