in src/datatuner/lm/evaluate.py [0:0]
def clean_beam(self, cons_classifier, cons_cache, cons_dataset):
    """Drop inconsistent hypotheses from the beam, then trim it to ``beam_width``.

    Every entry of ``self.heap`` is read as a 4-tuple ``(prob, complete, _, payload)``:
    ``prob`` is the score used for ordering, ``complete`` marks a finished
    hypothesis (tested with ``is True``) and ``payload`` is a dict.

    When a classifier is given:

    * complete entries whose representation is not cached yet are classified
      in a single ``cons_classifier.evaluate`` call;
    * every entry with a verdict (fresh or cached) gets it stored under
      ``payload["cons_prediction"]``;
    * entries judged ``"omission"`` with probability above 0.5 are removed;
    * entries judged ``"accurate"`` with probability above 0.5 are the last
      ones to be evicted when the beam is trimmed. Their stored ``prob`` is
      left untouched.

    Finally the lowest-scoring entries are evicted until at most
    ``self.beam_width`` remain, and ``self.heap`` is left as a valid heap.

    Args:
        cons_classifier: Object exposing ``evaluate(candidates)`` that returns
            ``{"preds": [...], "preds_prob": [...]}`` aligned with
            ``candidates``; ``None`` disables consistency filtering.
        cons_cache: Dict mapping ``str(candidate)`` to
            ``{"pred": ..., "prob": ...}``. Read and updated in place.
        cons_dataset: Key into the module-level ``dataset_fields`` mapping;
            only used when ``cons_classifier`` is not ``None``.

    Returns:
        None. ``self.heap`` is replaced by the cleaned beam.
    """
    if not self.heap:
        # Nothing to classify or trim.
        return

    # Heap position -> consistency verdict; positions without one are absent.
    verdicts = {}
    # Heap position -> classifier input, kept so debug output names the right entry.
    reps = []
    if cons_classifier is not None:
        mr_key = dataset_fields[cons_dataset]["data"]
        reps = [self.get_cons_rep(entry[3], mr_key) for entry in self.heap]
        cache_keys = [str(rep) for rep in reps]

        # Only complete entries that were never classified go to the classifier.
        pending = [
            pos
            for pos, entry in enumerate(self.heap)
            if entry[1] is True and cache_keys[pos] not in cons_cache
        ]
        fresh = {}
        if pending:
            # Results format: {"preds":[...], "preds_prob":[...]}
            cons_results = cons_classifier.evaluate([reps[pos] for pos in pending])
            fresh = {
                pos: {
                    "pred": cons_results["preds"][n],
                    "prob": cons_results["preds_prob"][n],
                }
                for n, pos in enumerate(pending)
            }

        # A cached verdict applies to any entry whose representation matches,
        # whether or not the entry is complete; fresh verdicts are cached as
        # they are assigned.
        for pos, key in enumerate(cache_keys):
            if key in cons_cache:
                verdicts[pos] = cons_cache[key]
            elif pos in fresh:
                verdicts[pos] = cons_cache[key] = fresh[pos]

    survivors = []
    # Indices into ``survivors`` of entries confidently judged accurate.
    protected = set()
    for pos, item in enumerate(self.heap):
        verdict = verdicts.get(pos)
        if verdict is not None:
            pred, pred_prob = verdict["pred"], verdict["prob"]
            # Save the prediction on the payload.
            item[3]["cons_prediction"] = {"pred": pred, "prob": pred_prob}
            confident = pred_prob > 0.5
            if pred == "omission" and confident:
                if DEBUG:
                    logger.debug(
                        f"removed {reps[pos]} as the prediction was {pred} with probability {pred_prob}"
                    )
                continue
            if pred == "accurate" and confident:
                protected.add(len(survivors))
        # NOTE: pruning relative to the best entry (``self.max_ratio_in_heap``)
        # is deliberately not applied; every entry that is not an omission is kept.
        survivors.append(item)

    overflow = len(survivors) - self.beam_width
    if overflow > 0:
        # Evict the lowest scores first, but rank accurate entries above all
        # others so they are only evicted once nothing else is left. Ranking on
        # a (flag, prob) pair avoids adding a constant to the stored score.
        ranked = sorted(
            range(len(survivors)),
            key=lambda idx: (idx in protected, survivors[idx][0]),
        )
        evicted = ranked[:overflow]
        for idx in evicted:
            prob, _, _, payload = survivors[idx]
            if DEBUG:
                logger.debug("removing")
                logger.debug(self.tokenizer.decode(payload["prefix"]))
                logger.debug(prob)
                logger.debug(payload["all_probs"])
                logger.debug("\n")
        gone = set(evicted)
        survivors = [item for idx, item in enumerate(survivors) if idx not in gone]

    if len(survivors) != len(self.heap):
        # Removing entries from the middle of a heap breaks the heap invariant.
        heapq.heapify(survivors)
    self.heap = survivors