def bleu()

in src/datatuner/lm/metrics.py [0:0]


def bleu(original, current, all_outputs, final, case_insensitive=True, all_keys=None):
    """Compute corpus BLEU between the `current` and `original` fields of the dicts in `all_outputs`.

    Items that agree on every other key (all keys except `original` and `current`) are
    grouped together.

    - Each group contributes one prediction: the `current` value of the group's last item.
    - Each group contributes one reference per item: that item's `original` value.
    - Groups with fewer references than the largest group are padded with empty strings,
      so every reference stream passed to ``corpus_bleu`` has one entry per prediction.

    Args:
        original: key holding the reference text in each item.
        current: key holding the predicted text. If the value is a list, only its first
            element is used.
        all_outputs: list of dicts, one per output. It is not modified.
        final: when true, also merge in the metrics returned by
            ``get_e2e_metrics(all_predictions, all_references)``.
        case_insensitive: lowercase predictions and references before scoring.
        all_keys: keys considered when grouping. Defaults to the keys of the first item.

    Returns:
        A dict containing:

        - ``"value"``: the corpus BLEU score.
        - ``"count"``: the number of predictions (groups).
        - any e2e metrics, when ``final`` is true.

        For an empty ``all_outputs``, returns ``{"value": 0, "count": 0}``.
    """
    if not all_outputs:
        return {"value": 0, "count": 0}

    from sacrebleu import corpus_bleu

    def process(s):
        return s.lower() if case_insensitive else s

    keys = all_outputs[0].keys() if all_keys is None else all_keys
    # Deduplicate while keeping a stable order (a set's iteration order depends on the hash seed).
    other_keys = list(dict.fromkeys(key for key in keys if key not in (original, current)))

    group = {}
    for item in all_outputs:
        # Identify the group by the other inputs. Pair each value with its key name so
        # items missing different keys cannot produce the same search key.
        search_key = str([(x, item[x]) for x in other_keys if x in item])

        # Unwrap list-valued predictions into a local instead of mutating the input item.
        prediction = item[current]
        if isinstance(prediction, list):
            prediction = prediction[0]

        current_val = process(prediction)
        original_val = process(item[original])

        entry = group.setdefault(search_key, {"references": [], "prediction": None})
        entry["references"].append(original_val)
        # As before, the last prediction seen for a group is the one that is scored.
        entry["prediction"] = current_val

    # Every group holds at least one reference because all_outputs is non-empty.
    max_refs = max(len(entry["references"]) for entry in group.values())

    all_predictions = [entry["prediction"] for entry in group.values()]
    # One reference stream per reference slot. Groups with fewer references are padded with "".
    all_references = [
        [entry["references"][i] if i < len(entry["references"]) else "" for entry in group.values()]
        for i in range(max_refs)
    ]

    e2e_metrics = get_e2e_metrics(all_predictions, all_references) if final else {}
    e2e_metrics.update({"value": corpus_bleu(all_predictions, all_references).score, "count": len(all_predictions)})
    return e2e_metrics