in src/mlm/cmds.py [0:0]
def cmd_rescore(args: argparse.Namespace) -> None:
    """rescore command

    Rescore the hypotheses of a base prediction file (``args.file_am``) with
    the scores from one or more language-model score files, write the
    rescored predictions to stdout as JSON, then evaluate them against a
    reference file.

    The input files follow this schema:
    {
        "<UTT_ID>": {
            "ref": {
                "score": 0.111
            },
            "hyp_1": {
                "score": 3.15
            },
            ...
        },
        ...
    }

    Attributes read from ``args``:
        model: comma-separated model names, each passed to ``get_pretrained``.
        file_lm: comma-separated paths of LM score files, one per model.
        weight: comma-separated floats, one per model; passed to
            ``rescore`` as ``scales``.
        file_am: open file object holding the base predictions.
        max_utts: forwarded to ``Predictions.from_file``.
        ln, ln_type: forwarded to ``rescore``.
        ref_file: open file object used for evaluation; when ``None``,
            ``file_am`` is rewound and used instead.

    Raises:
        ValueError: if the numbers of models, LM files and weights differ.
    """
    model_list = args.model.split(',')
    file_lm_list = args.file_lm.split(',')
    weight_list = [float(x) for x in args.weight.split(',')]

    # Validate before loading any model: loading is the expensive step, and
    # zip() below would otherwise silently drop unmatched entries.
    if not (len(model_list) == len(file_lm_list) == len(weight_list)):
        raise ValueError(
            "Expected the same number of models, LM files and weights; "
            "got {} model(s), {} LM file(s) and {} weight(s)".format(
                len(model_list), len(file_lm_list), len(weight_list)))

    pretrained_tup_list = [get_pretrained([mx.cpu(0)], model) for model in model_list]

    preds_am = Predictions.from_file(args.file_am, max_utts=args.max_utts)

    preds_lm_list = []
    for file_lm, (_, vocab, tokenizer) in zip(file_lm_list, pretrained_tup_list):
        # Close each LM file as soon as it has been parsed.
        # NOTE(review): assumes Predictions.from_file consumes the stream
        # eagerly -- confirm against its implementation.
        with Path(file_lm).open('r') as fp_lm:
            preds_lm_list.append(Predictions.from_file(
                fp_lm, max_utts=args.max_utts, vocab=vocab, tokenizer=tokenizer))

    # Preserves input order, but slower
    shared_keys = list(preds_am.keys())
    for preds_lm in preds_lm_list:
        # isdigit() suggests we're in automatic ID mode; cast to int
        # NOTE(review): presumably the base file's keys are ints in that mode;
        # the lookup below converts back with str(), so zero-padded numeric
        # IDs (e.g. "007") would not round-trip -- verify.
        preds_lm_keys = {(int(key) if key.isdigit() else key) for key in preds_lm.keys()}
        shared_keys = [key for key in shared_keys if key in preds_lm_keys]

    logging.warning("{} shared keys found, rescoring these...".format(len(shared_keys)))

    preds_new = Predictions()
    for utt_id in shared_keys:
        hyps_am = preds_am[utt_id]
        # LM predictions are indexed by the string form of the utterance ID
        hyps_lm_list = [preds_lm[str(utt_id)] for preds_lm in preds_lm_list]
        new_hyps = hyps_am.rescore(hyps_lm_list, scales=weight_list, ln=args.ln, ln_type=args.ln_type)
        preds_new[utt_id] = new_hyps

    preds_new.to_json(sys.stdout)

    # Compute WER after rescoring using the first file, if possible
    ref_file = args.ref_file
    if ref_file is None:
        ref_file = args.file_am
        # file_am was already read above; rewind it before re-reading
        ref_file.seek(0)

    # A .json reference is scored with _wer; anything else with _mbleu
    if Path(ref_file.name).suffix == '.json':
        my_wer = _wer(ref_file, preds_new)
        logging.warning("WER: {}%".format(my_wer*100))
    else:
        my_bleu = _mbleu(ref_file, preds_new)