pipeline/translate/extract_best.py (144 lines of code) (raw):
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from __future__ import absolute_import, division, print_function
import argparse
import collections
import math
import re
import sys
def main():
    """Select the best hypothesis per sentence from an n-best list.

    Chooses the scoring function from ``--metric`` and the n-best input
    format from ``--toolkit``, then streams through the references and
    n-best files, writing one best hypothesis per line to the output.
    """
    args = parse_args()
    if args.metric == "bleu":
        score_function = compute_bleu
    elif args.metric == "sacrebleu":
        # Imported lazily (and bound at module level for the compute_*
        # helpers) so the built-in BLEU works without sacrebleu installed.
        global sacrebleu
        import sacrebleu
        score_function = compute_sacrebleu
    elif args.metric == "chrf":
        global sacrebleu
        import sacrebleu
        score_function = compute_chrf
    else:
        # Bug fix: previously this fell through with score_function unbound,
        # crashing later with a NameError; fail fast instead.
        sys.stderr.write("Unrecognized metric: {}\n".format(args.metric))
        sys.exit(1)
    if args.toolkit == "marian":
        marian_best_bleu(args, score_function)
    elif args.toolkit == "t2t":
        t2t_best_bleu(args, score_function)
    else:
        # Bug fix: an unknown toolkit used to exit silently with no output.
        sys.stderr.write("Unrecognized toolkit: {}\n".format(args.toolkit))
        sys.exit(1)
def t2t_best_bleu(args, score_function):
    """Pick the best hypothesis from T2T-style n-best output.

    Each line of ``args.nbest`` holds the tab-separated candidates for the
    reference on the same line of ``args.references``.  The candidate with
    the highest sentence-level score is written to ``args.output``.
    """
    for line_no, reference_line in enumerate(args.references):
        references = reference_line.strip().split("\n")
        hypotheses = next(args.nbest).strip().split("\t")
        if args.debpe:
            # Undo BPE segmentation ("@@ " joiners) before scoring.
            references = [re.sub(r"@@ +", "", ref) for ref in references]
            hypotheses = [re.sub(r"@@ +", "", hyp) for hyp in hypotheses]
        tokenized_refs = [ref.split() for ref in references]
        scores = [score_function(tokenized_refs, hyp.split()) for hyp in hypotheses]
        # Ties go to the earliest candidate (index of the first maximum).
        winner = hypotheses[scores.index(max(scores))]
        args.output.write("{}\n".format(winner))
        if args.debug:
            sys.stderr.write("{}: {}\n".format(line_no, scores))
        # Progress marker on stderr every 100k sentences.
        if line_no > 0 and line_no % 100000 == 0:
            sys.stderr.write("[{}]\n".format(line_no))
def marian_best_bleu(args, score_function):
    """Pick the best hypothesis from Marian/CTranslate2-style n-best output.

    The n-best file has lines of the form ``idx ||| text [||| scores ...]``,
    with all candidates for one sentence sharing the same ``idx``.  For each
    reference line, the candidate group with matching index is collected and
    the highest-scoring candidate is written to ``args.output``.
    """
    # One-line lookahead: the line read past the end of the current group
    # belongs to the next sentence and is kept here for the next iteration.
    prev_line = None
    for i, ref_line in enumerate(args.references):
        refs = ref_line.strip().split("\n")
        if args.debpe:
            # Undo BPE segmentation ("@@ " joiners) before scoring.
            refs = [re.sub(r"@@ +", "", r) for r in refs]
        texts = []
        while True:
            if prev_line:
                # CTranslate2 can output empty text, for example:
                # 10181 ||| .GDFMAKERPROJECTファイルを開くには?
                # 10181 ||| .GDMAKERPROJECTファイルを開くには?
                # 10181 ||| .GDFMAKERPROJECTファイルを開くには?
                # 10181 ||| .GDFMakerPROJECTファイルを開くには?
                # 10181 ||| .GDFAKERPROJECTファイルを開くには?
                # 10181 ||| .GDMakerPROJECTファイルを開くには?
                # 10181 ||| .GDFMAKERPROJECTファイルを開くには。
                # 10181 |||
                # Marian also outputs scores, for example:
                # 0 ||| Реформа, направленная на выдвижение условий, идет слишком медленно. ||| F0= -9.21191 F1= -11.53 ||| -1.22059
                fields = prev_line.rstrip("\n").split(" ||| ")
                if len(fields) == 1:
                    # handle "10181 |||"
                    fields = fields[0].split()[0], ""
                idx = int(fields[0])
                if idx == i:
                    texts.append(fields[1])
                else:
                    # First line of the next group: leave it in prev_line
                    # and score the candidates collected so far.
                    break
            prev_line = next(args.nbest, None)
            if not prev_line:
                # n-best input exhausted.
                break
        if args.debpe:
            texts = [re.sub(r"@@ +", "", t) for t in texts]
        refs = [r.split() for r in refs]
        scores = [score_function(refs, t.split()) for t in texts]
        # Ties go to the earliest candidate (index of the first maximum).
        best_txt = texts[scores.index(max(scores))]
        args.output.write("{}\n".format(best_txt))
        if args.debug:
            sys.stderr.write("{}: {}\n".format(i, scores))
        # Progress marker on stderr every 100k sentences.
        if i % 100000 == 0 and i > 0:
            sys.stderr.write("[{}]\n".format(i))
def compute_chrf(references, translation):
    """Sentence-level chrF (via sacrebleu) for a tokenized hypothesis
    against one or more tokenized references."""
    hypothesis = " ".join(translation)
    reference_strings = [" ".join(tokens) for tokens in references]
    result = sacrebleu.sentence_chrf(hypothesis, reference_strings)
    return result.score
def compute_sacrebleu(references, translation):
    """Sentence-level BLEU (via sacrebleu) for a tokenized hypothesis
    against one or more tokenized references."""
    hypothesis = " ".join(translation)
    reference_strings = [" ".join(tokens) for tokens in references]
    result = sacrebleu.sentence_bleu(hypothesis, reference_strings)
    return result.score
def compute_bleu(references, translation, max_order=4):
    """Smoothed sentence-level BLEU: geometric mean of n-gram precisions
    (orders 1..max_order) times the brevity penalty.

    ``references`` is a list of token lists; ``translation`` a token list.
    """
    precisions = get_ngram_precisions(references, translation, max_order)
    if min(precisions) <= 0:
        # Any zero precision collapses the geometric mean to zero.
        geometric_mean = 0
    else:
        log_sum = sum((1.0 / max_order) * math.log(p) for p in precisions)
        geometric_mean = math.exp(log_sum)
    return geometric_mean * get_brevity_penalty(references, translation)
def get_brevity_penalty(references, translation):
    """BLEU brevity penalty against the shortest reference.

    Hypotheses longer than the shortest reference get no penalty (1.0),
    as do empty hypotheses (guarding the division by zero); shorter ones
    are penalized by exp(1 - ref_len / hyp_len).
    """
    shortest_reference = min(len(tokens) for tokens in references)
    length_ratio = len(translation) / float(shortest_reference)
    if 0.0 < length_ratio <= 1.0:
        return math.exp(1 - 1.0 / length_ratio)
    return 1.0
def get_ngram_precisions(references, translation, max_order=4):
    """Add-one smoothed modified n-gram precisions for orders 1..max_order.

    Translation n-gram counts are clipped to the maximum count of that
    n-gram in any single reference (Counter ``|`` keeps per-key maxima,
    ``&`` keeps per-key minima).
    """
    clipped_reference_counts = collections.Counter()
    for reference in references:
        clipped_reference_counts |= get_ngrams(reference, max_order)
    overlap = get_ngrams(translation, max_order) & clipped_reference_counts

    matches = [0] * max_order
    for ngram, count in overlap.items():
        matches[len(ngram) - 1] += count

    # Number of n-gram slots in the translation per order (0 when the
    # translation is shorter than the order).
    possible = [
        max(len(translation) - order + 1, 0) for order in range(1, max_order + 1)
    ]

    precisions = []
    for matched, candidates in zip(matches, possible):
        if matched == 0 and candidates == 0:
            # Order exceeds the translation length: no smoothing applied.
            precisions.append(0.0)
        else:
            precisions.append((matched + 1.0) / (candidates + 1.0))
    return precisions
def get_ngrams(segment, max_order):
    """Count every n-gram of order 1..max_order in a token sequence.

    Returns a Counter mapping n-gram tuples to occurrence counts.
    """
    counts = collections.Counter()
    for order in range(1, max_order + 1):
        counts.update(
            tuple(segment[start : start + order])
            for start in range(len(segment) - order + 1)
        )
    return counts
def parse_args():
    """Build and evaluate the command-line interface.

    Defaults: n-best from stdin, best hypotheses to stdout, built-in BLEU
    as the metric, and Marian-style n-best format.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--nbest", type=argparse.FileType("r"), default=sys.stdin)
    parser.add_argument("-r", "--references", type=argparse.FileType("r"), required=True)
    parser.add_argument("-o", "--output", type=argparse.FileType("w"), default=sys.stdout)
    parser.add_argument("-m", "--metric", default="bleu")
    parser.add_argument("--debpe", action="store_true")
    parser.add_argument("-d", "--debug", action="store_true")
    parser.add_argument("-t", "--toolkit", default="marian", help="Toolkit: 'marian' or 't2t'")
    return parser.parse_args()
# Script entry point.
if __name__ == "__main__":
    main()