in src/datatuner/lm/metrics.py [0:0]
def bleu(original, current, all_outputs, final, case_insensitive=True, all_keys=None):
    """Compute the corpus-level BLEU score of `current` against `original`.

    Items of `all_outputs` whose values agree on every key other than
    `original` and `current` are grouped together: their `original` values
    become multiple references for one prediction.

    Args:
        original: Key holding the reference text in each item.
        current: Key holding the predicted text. If the value is a list, its
            first element is used.
        all_outputs: List of dictionaries to score. It is not modified.
        final: When true, the result also contains whatever
            `get_e2e_metrics` returns for the same predictions/references.
        case_insensitive: Lowercase predictions and references before scoring.
        all_keys: Keys used to build the grouping key. Defaults to the keys
            of the first item in `all_outputs`.

    Returns:
        A dict with "value" (the `score` of sacrebleu's `corpus_bleu`) and
        "count" (number of grouped predictions), merged into the e2e metrics
        when `final` is true. Empty input yields `{"value": 0, "count": 0}`.
    """
    if not all_outputs:
        return {"value": 0, "count": 0}

    # Imported lazily; the empty-input path above never needs sacrebleu.
    from sacrebleu import corpus_bleu

    def process(s):
        return s.lower() if case_insensitive else s

    keys = all_outputs[0].keys() if all_keys is None else all_keys
    other_keys = list({key for key in keys if key not in (original, current)})

    # Group items by the concatenation of all the other inputs.
    groups = {}
    for item in all_outputs:
        search_key = str([item[key] for key in other_keys if key in item])

        # Work on a local value instead of rewriting the item, so the caller's
        # data stays untouched without having to deep-copy the whole list.
        prediction = item[current]
        if isinstance(prediction, list):
            prediction = prediction[0]
        prediction = process(prediction)
        reference = process(item[original])

        entry = groups.get(search_key)
        if entry is None:
            groups[search_key] = {"references": [reference], "prediction": prediction}
        else:
            entry["references"].append(reference)
            # The most recent prediction seen for a group is the one scored.
            entry["prediction"] = prediction

    entries = list(groups.values())
    all_predictions = [entry["prediction"] for entry in entries]

    # Build one reference stream per reference position; groups with fewer
    # references than the maximum are padded with empty strings.
    max_refs = max(len(entry["references"]) for entry in entries)
    all_references = [
        [entry["references"][i] if i < len(entry["references"]) else "" for entry in entries]
        for i in range(max_refs)
    ]

    e2e_metrics = get_e2e_metrics(all_predictions, all_references) if final else {}
    e2e_metrics.update(
        {"value": corpus_bleu(all_predictions, all_references).score, "count": len(all_predictions)}
    )
    return e2e_metrics