def clean_beam(self, cons_classifier, cons_cache, cons_dataset)

in src/datatuner/lm/evaluate.py [0:0]


    def clean_beam(self, cons_classifier, cons_cache, cons_dataset):
        """Prune the beam, optionally filtering with a consistency classifier.

        Completed hypotheses are (batch-)scored by ``cons_classifier``; items
        predicted as "omission" are dropped, items predicted as "accurate" get
        a large probability boost so the width-based trimming below cannot
        remove them. Finally the heap is trimmed to ``self.beam_width`` by
        popping its lowest-probability elements.

        Heap items are tuples ``(prob, is_complete, tiebreak, payload)`` where
        ``payload`` is a dict (mutated in place with ``"cons_prediction"``).

        Args:
            cons_classifier: object with ``evaluate(cands)`` returning
                ``{"preds": [...], "preds_prob": [...]}``, or ``None`` to skip
                consistency filtering entirely.
            cons_cache: dict mapping ``str(candidate)`` to a previously
                computed ``{"pred": ..., "prob": ...}``; updated in place.
            cons_dataset: key into the module-level ``dataset_fields`` used to
                find the MR ("data") field for building candidates.
        """
        new_heap = []

        # Highest probability currently in the beam (reference for ratio pruning).
        largest_prob = heapq.nlargest(1, self.heap)[0][0]

        if cons_classifier is not None:
            mr_key = dataset_fields[cons_dataset]["data"]
            # Batch of not-yet-classified candidates for the consistency classifier.
            consistency_cands = []
            # Map heap index -> index in `consistency_cands`, so classifier
            # results can be matched back to their beam elements.
            beam_ind_to_cand_ind = {}
            for i in range(len(self.heap)):
                # Only completed hypotheses are classified.
                if self.heap[i][1] is not True:
                    continue
                cand = self.get_cons_rep(self.heap[i][3], mr_key)
                # Avoid classifying what we already classified (cache hit).
                if str(cand) in cons_cache:
                    continue
                beam_ind_to_cand_ind[i] = len(consistency_cands)
                consistency_cands.append(cand)

            if len(consistency_cands) > 0:
                # Results format: {"preds":[...], "preds_prob":[...]}
                cons_results = cons_classifier.evaluate(consistency_cands)

            # Per-heap-item classifier result; None for items that were never
            # classified (i.e. incomplete hypotheses).
            all_cons_results = []
            for i in range(len(self.heap)):
                cand = self.get_cons_rep(self.heap[i][3], mr_key)

                if str(cand) in cons_cache:
                    all_cons_results.append(cons_cache[str(cand)])
                elif i in beam_ind_to_cand_ind:
                    cand_ind = beam_ind_to_cand_ind[i]
                    cons_for_item = {
                        "pred": cons_results["preds"][cand_ind],
                        "prob": cons_results["preds_prob"][cand_ind],
                    }
                    all_cons_results.append(cons_for_item)
                    cons_cache[str(cand)] = cons_for_item
                else:
                    all_cons_results.append(None)

        for i, item in enumerate(self.heap):
            if cons_classifier is not None:
                # Only if there was a consistency result for this item
                if all_cons_results[i] is not None:
                    pred, pred_prob = all_cons_results[i]["pred"], all_cons_results[i]["prob"]
                    # Save the prediction on the payload (shared dict, survives copying).
                    self.heap[i][3]["cons_prediction"] = {"pred": pred, "prob": pred_prob}
                    if pred in ["omission"] and pred_prob > 0.5:
                        # Drop hypotheses confidently classified as omissions.
                        if DEBUG:
                            # Fix: log this item's own representation; the old
                            # code reused a stale `cand_ind` from the matching
                            # loop above, naming the wrong candidate.
                            logger.debug(
                                f"removed {self.get_cons_rep(self.heap[i][3], mr_key)} as the prediction was {pred} with probability {pred_prob}"
                            )
                        continue
                    elif pred == "accurate" and pred_prob > 0.5:
                        # Add a factor to not allow correct to be removed during cleaning based on lowest probability.
                        # Fix: rebind `item` so the boosted tuple actually ends
                        # up in `new_heap`; previously only self.heap[i] was
                        # updated and the un-boosted `item` was appended,
                        # silently discarding the boost.
                        item = (1000 + item[0],) + tuple(item[1:])
                        self.heap[i] = item

            # Remove beam components with probability lower than 1/max_ratio_in_heap
            # times the highest beam component probability.
            # NOTE(review): `True or` makes this pruning a deliberate no-op;
            # kept as-is to preserve behavior.
            if True or item[0] > largest_prob / self.max_ratio_in_heap:
                new_heap.append(item)
        self.heap = new_heap

        # Boosting may have broken the heap invariant; restore it so the
        # pops below really remove the lowest-probability elements.
        heapq.heapify(self.heap)

        # Trim to beam_width by popping the minimum (lowest probability) items.
        while len(self.heap) > self.beam_width:
            prob, _, _, payload = heapq.heappop(self.heap)
            if DEBUG:
                logger.debug("removing")
                logger.debug(self.tokenizer.decode(payload["prefix"]))
                logger.debug(prob)
                logger.debug(payload["all_probs"])
                logger.debug("\n")