in src/text_clustering.py [0:0]
def fit(self, texts, embeddings=None):
    """Embed, index, project and cluster ``texts``; optionally summarize clusters.

    Args:
        texts: The documents to cluster. Passed to ``self.embed`` (when no
            embeddings are supplied) and to ``self.summarize``.
        embeddings: Optional precomputed embeddings, one per entry in
            ``texts``. When ``None`` they are computed with ``self.embed``.

    Returns:
        A ``(embeddings, cluster_labels, cluster_summaries)`` tuple, where
        ``cluster_summaries`` is ``None`` unless ``self.summary_create`` is
        truthy.

    Raises:
        ValueError: If precomputed ``embeddings`` do not have the same length
            as ``texts``.

    Side effects:
        Sets ``texts``, ``embeddings``, ``faiss_index``, ``projections``,
        ``umap_mapper``, ``cluster_labels``, ``id2cluster``, ``label2docs``,
        ``cluster_centers`` and ``cluster_summaries`` on ``self``.
    """
    if embeddings is not None and len(embeddings) != len(texts):
        # Fail fast: cluster labels are produced one per embedding, so a
        # mismatch would silently misalign them with the texts handed to
        # summarize() below.
        raise ValueError(
            f"got {len(embeddings)} precomputed embeddings for {len(texts)} texts"
        )

    self.texts = texts
    if embeddings is None:
        logging.info("embedding texts...")
        self.embeddings = self.embed(texts)
    else:
        logging.info("using precomputed embeddings...")
        self.embeddings = embeddings

    logging.info("building faiss index...")
    self.faiss_index = self.build_faiss_index(self.embeddings)
    logging.info("projecting with umap...")
    self.projections, self.umap_mapper = self.project(self.embeddings)
    logging.info("dbscan clustering...")
    self.cluster_labels = self.cluster(self.projections)

    # Document index -> cluster label, and cluster label -> document indices.
    self.id2cluster = dict(enumerate(self.cluster_labels))
    self.label2docs = defaultdict(list)
    for doc_idx, label in enumerate(self.cluster_labels):
        self.label2docs[label].append(doc_idx)

    # Centroid of each cluster over the first two projected coordinates.
    # Every label gets an entry, including any noise label the clusterer
    # may emit (e.g. -1 for DBSCAN) — TODO confirm that is intended.
    projections = np.asarray(self.projections)
    self.cluster_centers = {}
    for label, docs in self.label2docs.items():
        # Fancy indexing gathers each column in one vectorized step instead
        # of a per-document Python loop; mean is still taken per column.
        x = np.mean(projections[docs, 0])
        y = np.mean(projections[docs, 1])
        self.cluster_centers[label] = (x, y)

    if self.summary_create:
        logging.info("summarizing cluster centers...")
        self.cluster_summaries = self.summarize(self.texts, self.cluster_labels)
    else:
        self.cluster_summaries = None
    return self.embeddings, self.cluster_labels, self.cluster_summaries