def cmd_rescore()

in src/mlm/cmds.py [0:0]


def cmd_rescore(args: argparse.Namespace) -> None:
    """rescore command

    Combine acoustic-model (AM) hypothesis scores with the scores from one or
    more language-model (LM) score files, write the rescored predictions to
    stdout as JSON, then evaluate them against a reference file.

    Every scores file is expected to follow this schema:
    {
        "<UTT_ID>": {
            "ref": {
                "score": 0.111
            },
            "hyp_1": {
                "score": 3.15
            },
            ...
        },
        ...
    }

    Only utterance IDs present in the AM file and in every LM file are
    rescored; the AM file's order is preserved.

    Args:
        args: parsed CLI arguments. Fields read here:
            ``model``: comma-separated pretrained model names, matched
                positionally with the LM files;
            ``file_lm``: comma-separated paths of LM score files;
            ``weight``: comma-separated floats, one scale per LM;
            ``file_am``: AM score file object (rewound with ``seek(0)`` and
                reused as the reference when ``ref_file`` is None);
            ``ref_file``: optional reference file object;
            ``max_utts``: forwarded to ``Predictions.from_file``;
            ``ln``, ``ln_type``: forwarded to ``rescore``.

    Raises:
        ValueError: if the numbers of models, LM files and weights differ.
    """

    model_list = args.model.split(',')
    file_lm_list = args.file_lm.split(',')
    weight_list = [float(x) for x in args.weight.split(',')]

    # Models, LM files and weights are matched by position (zip below and
    # ``scales=`` in rescore). A count mismatch would silently drop or
    # misalign LMs, so validate before loading any model.
    if not len(model_list) == len(file_lm_list) == len(weight_list):
        raise ValueError(
            "Expected the same number of models, LM files and weights; got "
            "{} model(s), {} LM file(s), {} weight(s)".format(
                len(model_list), len(file_lm_list), len(weight_list)))

    pretrained_tup_list = [get_pretrained([mx.cpu(0)], model) for model in model_list]

    preds_am = Predictions.from_file(args.file_am, max_utts=args.max_utts)

    # Each LM file is closed as soon as it has been parsed.
    # NOTE(review): this assumes Predictions.from_file reads the stream
    # eagerly. file_am is likewise rewound and re-read further below.
    preds_lm_list = []
    for file_lm, (_, vocab, tokenizer) in zip(file_lm_list, pretrained_tup_list):
        with Path(file_lm).open('r') as f_lm:
            preds_lm_list.append(Predictions.from_file(
                f_lm, max_utts=args.max_utts, vocab=vocab, tokenizer=tokenizer))

    # Preserves input order, but slower
    shared_keys = list(preds_am.keys())
    for preds_lm in preds_lm_list:
        # isdigit() suggests we're in automatic ID mode; cast to int
        preds_lm_keys = set(((int(key) if key.isdigit() else key) for key in preds_lm.keys()))
        shared_keys = [key for key in shared_keys if key in preds_lm_keys]

    logging.warning("{} shared keys found, rescoring these...".format(len(shared_keys)))

    preds_new = Predictions()
    for utt_id in shared_keys:
        hyps_am = preds_am[utt_id]
        # LM predictions are looked up by the string form of the ID
        # (the key check above casts digit-only keys to int).
        hyps_lm_list = [preds_lm[str(utt_id)] for preds_lm in preds_lm_list]
        new_hyps = hyps_am.rescore(hyps_lm_list, scales=weight_list, ln=args.ln, ln_type=args.ln_type)
        preds_new[utt_id] = new_hyps

    preds_new.to_json(sys.stdout)

    # Compute WER after rescoring using the first file, if possible
    ref_file = args.ref_file
    if ref_file is None:
        # file_am was consumed by Predictions.from_file above; rewind it so it
        # can be read again as the reference.
        ref_file = args.file_am
        ref_file.seek(0)

    # A .json reference is scored with WER; anything else with _mbleu
    # (presumably multi-reference BLEU; verify against _mbleu).
    if Path(ref_file.name).suffix == '.json':
        my_wer = _wer(ref_file, preds_new)
        logging.warning("WER: {}%".format(my_wer*100))
    else:
        my_bleu = _mbleu(ref_file, preds_new)