in is21_deep_bias/score.py [0:0]
def _read_refs(path):
    """Read the reference file into ``{uttid: {"text": str, "biasing_words": set}}``.

    Each non-blank line must hold three tab-separated fields: the utterance id,
    the reference transcript, and a JSON-encoded list of biasing words.

    Raises:
        ValueError: if a non-blank line has fewer than three fields.
    """
    refs = {}
    with open(path, "r", encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # crashing with an IndexError while unpacking the fields.
                continue
            ary = line.split("\t")
            if len(ary) < 3:
                raise ValueError(
                    f"{path}:{lineno}: expected 3 tab-separated fields "
                    f"(uttid, text, biasing words), got {len(ary)}"
                )
            uttid, ref, biasing_words = ary[0], ary[1], set(json.loads(ary[2]))
            refs[uttid] = {"text": ref, "biasing_words": biasing_words}
    return refs


def _read_hyps(path):
    """Read the hypothesis file into ``{uttid: hypothesis_text}``.

    Each non-blank line is ``uttid<TAB>hypothesis``. The hypothesis may be
    missing, in which case it is recorded as the empty string.
    """
    hyps = {}
    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank line carries no utterance; do not record a "" uttid.
                continue
            ary = line.split("\t")
            # May have empty hypo: strip() removes the trailing tab, leaving
            # only the uttid field.
            hyps[ary[0]] = ary[1] if len(ary) >= 2 else ""
    return hyps


def _accumulate_errors(ref_tokens, hyp_tokens, biasing_words, wer, u_wer, b_wer):
    """Align one utterance and add its counts to the WER / U-WER / B-WER accumulators.

    Every aligned position updates the overall ``wer``. It also updates exactly one
    of ``b_wer`` (the word is in ``biasing_words``) or ``u_wer`` (it is not).
    """
    result = EditDistance().align(ref_tokens, hyp_tokens)
    for code, ref_idx, hyp_idx in zip(result.codes, result.refs, result.hyps):
        if code == Code.insertion:
            # An insertion consumes no reference word, so it is attributed to
            # B-WER or U-WER according to the inserted hypothesis token, and
            # it does not add to the reference word count.
            split = b_wer if hyp_tokens[hyp_idx] in biasing_words else u_wer
            for acc in (wer, split):
                acc.errors[Code.insertion] += 1
        elif code in (Code.match, Code.substitution, Code.deletion):
            # These codes consume one reference word; classify by that word.
            split = b_wer if ref_tokens[ref_idx] in biasing_words else u_wer
            for acc in (wer, split):
                acc.ref_words += 1
                if code != Code.match:
                    acc.errors[code] += 1


def main(args):
    """Score hypotheses against references and print WER, U-WER and B-WER.

    Args:
        args: parsed CLI arguments with these attributes:
            ``refs`` is the path to the reference file, with tab-separated
            uttid, text and JSON list of biasing words.
            ``hyps`` is the path to the hypothesis file, with tab-separated
            uttid and text.
            ``lenient`` is a flag; if true, reference utterances without a
            hypothesis are skipped instead of raising.

    Raises:
        ValueError: if a reference utterance has no hypothesis and
            ``args.lenient`` is false, or if a reference line is malformed.
    """
    refs = _read_refs(args.refs)
    logger.info("Loaded %d reference utts from %s", len(refs), args.refs)
    hyps = _read_hyps(args.hyps)
    logger.info("Loaded %d hypothesis utts from %s", len(hyps), args.hyps)

    missing = [uttid for uttid in refs if uttid not in hyps]
    if missing:
        if not args.lenient:
            raise ValueError(
                f"{missing[0]} missing in hyps! Set `--lenient` flag to ignore this error."
            )
        # Lenient mode: skipped utterances do not contribute to any metric,
        # so make the number visible rather than dropping them silently.
        logger.warning(
            "Skipping %d reference utts with no hypothesis", len(missing)
        )

    # Calculate WER, U-WER (non-biasing words), and B-WER (biasing words)
    wer = WordError()
    u_wer = WordError()
    b_wer = WordError()
    for uttid, ref in refs.items():
        if uttid not in hyps:
            continue
        _accumulate_errors(
            ref["text"].split(),
            hyps[uttid].split(),
            ref["biasing_words"],
            wer,
            u_wer,
            b_wer,
        )

    # Report results
    print(f"WER: {wer.get_result_string()}")
    print(f"U-WER: {u_wer.get_result_string()}")
    print(f"B-WER: {b_wer.get_result_string()}")