in data_measurements/embeddings/embeddings.py [0:0]
def find_cluster_beam(self, sentence, beam_size=20):
    """
    Beam-search the cluster tree for the `beam_size` leaf clusters whose
    centroids are closest (by dot product) to the sentence embedding, and
    return each leaf's full path from the root together with its score.

    Args:
        sentence (string): input sentence for which to find clusters
        beam_size (int): number of candidate paths kept per tree level and
            the maximum number of results returned

    Returns:
        [([int], float)]: list of (path_from_root, score) sorted by
            decreasing score; each path is the list of node indices from
            the root (index 0) down to a leaf, and score is the dot
            product between the sentence embedding and the leaf centroid.
    """
    embed = self.compute_sentence_embeddings([sentence])[0].to("cpu")

    def mapped_children(node_id):
        # Children of `node_id` that are actually present in this tree
        # view (nid_map only maps the retained node ids to list indices).
        return [
            self.nid_map[nid]
            for nid in self.node_list[node_id]["children_ids"]
            if nid in self.nid_map
        ]

    def node_score(node_id):
        # Similarity between the sentence and this node's centroid.
        return torch.dot(embed, self.node_list[node_id]["centroid"]).item()

    active_paths = [([0], node_score(0))]
    finished_paths = []
    while active_paths:
        # Partition the current beam: paths whose tip has no mapped
        # children are finished leaves; the rest get expanded one level.
        # (Checking here — rather than only on freshly expanded paths —
        # also handles the degenerate single-node tree correctly.)
        expandable = []
        for path, path_score in active_paths:
            children = mapped_children(path[-1])
            if children:
                expandable.append((path, children))
            else:
                finished_paths.append((path, path_score))
        # Expand every surviving path by one child and keep the global
        # `beam_size` best continuations across all beams.
        active_paths = sorted(
            [
                (path + [child_id], node_score(child_id))
                for path, children in expandable
                for child_id in children
            ],
            key=lambda x: x[1],
            reverse=True,
        )[:beam_size]
    return sorted(
        finished_paths,
        key=lambda x: x[-1],
        reverse=True,
    )[:beam_size]