def find_cluster_beam()

in data_measurements/embeddings/embeddings.py [0:0]


    def find_cluster_beam(self, sentence, beam_size=20):
        """
        Find the leaf clusters closest to a sentence with a beam search.

        The sentence is embedded with `self.compute_sentence_embeddings` and
        the tree stored in `self.node_list` is explored from the root (index 0).
        At every depth all children of the current beam are scored by the dot
        product between the sentence embedding and the child's centroid, and
        only the `beam_size` best-scoring children are kept. A path is finished
        once its last node has no children registered in `self.nid_map`; this
        includes the root itself when the tree has no usable children.

        Args:
            sentence (string): input sentence for which to find clusters
            beam_size (int): maximum number of paths kept at each depth of the
                search, and maximum number of paths returned
        Returns:
            [([int], float)]: list of (path_from_root, score) sorted by
                decreasing score. Each path is a list of indices into
                `self.node_list` starting at the root, and the score is the
                dot product between the sentence embedding and the centroid
                of the last node of the path.
        """
        embed = self.compute_sentence_embeddings([sentence])[0].to("cpu")

        def centroid_score(node_idx):
            # Similarity between the sentence and a node, as a Python float.
            return torch.dot(embed, self.node_list[node_idx]["centroid"]).item()

        def mapped_children(node_idx):
            # Children ids are translated to `node_list` indices through
            # `nid_map`; ids missing from the map are skipped.
            return [
                self.nid_map[nid]
                for nid in self.node_list[node_idx]["children_ids"]
                if nid in self.nid_map
            ]

        frontier = [([0], centroid_score(0))]
        finished_paths = []
        while frontier:
            # Split the beam into finished (leaf) paths and paths that can
            # still be extended, computing each children list only once.
            expandable = []
            for path, score in frontier:
                children = mapped_children(path[-1])
                if children:
                    expandable.append((path, children))
                else:
                    finished_paths.append((path, score))

            # Score every child of every expandable path and keep the best
            # `beam_size` of them. The sort is stable, so ties keep the order
            # in which the candidates were generated.
            candidates = [
                (path + [child], centroid_score(child))
                for path, children in expandable
                for child in children
            ]
            candidates.sort(key=lambda candidate: candidate[1], reverse=True)
            frontier = candidates[:beam_size]

        finished_paths.sort(key=lambda finished: finished[1], reverse=True)
        return finished_paths[:beam_size]