recipes/lexicon_free/utilities/compute_lower_ppl_kenlm.py

""" Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. This source code is licensed under the MIT-style license found in the LICENSE file in the root directory of this source tree. ---------- Compute upper and lower limits on word perplexity for kenlm ngram models Command : python3 compute_upper_ppl_kenlm.py --vocab_file [...] --text [...] \ --char_model [...] --word_model [...] Replace [...] with appropriate paths """ from __future__ import absolute_import, division, print_function, unicode_literals import argparse import kenlm import numpy from utils import EOS, prepare_vocabs, transform_asg, transform_asg_back LOG10 = numpy.log(10) def compute_word_logprob(model, current_state, target_word): word_pred = 0 if target_word == EOS: stateOut = kenlm.State() word_pred += model.BaseScore(current_state, str(target_word), stateOut) * LOG10 else: stateIn = current_state for token in list(target_word): stateOut = kenlm.State() word_pred += model.BaseScore(stateIn, token, stateOut) * LOG10 stateIn = stateOut return word_pred def compute_denominator(model, current_state, words): preds = [compute_word_logprob(model, current_state, word) for word in words] max_pred = numpy.max(preds) return max_pred + numpy.log(numpy.sum(numpy.exp(preds - max_pred))) def compute_words_model_pdf_mass(word_model, known_words, stateIn): probs = [] for word in known_words: stateOut = kenlm.State() probs.append( numpy.power(10, word_model.BaseScore(stateIn, str(word), stateOut)) ) probs_arr = numpy.array(probs) indices = numpy.argsort(-probs_arr) top = numpy.where(numpy.cumsum(probs_arr[indices]) > 0.95)[0][0] return [ transform_asg(w) + "|" if w != EOS else w for w in known_words[indices[:top]] ] def compute_ppl_lower_limit( model, word_model, sentences, known_words, known_words_original ): n_words = 0 unk_n_words = 0 ppl = 0.0 ppl_lower = 0.0 n_logging = len(sentences) for n, sentence in enumerate(sentences): stateIn = kenlm.State() word_stateIn = kenlm.State() model.BeginSentenceWrite(stateIn) word_model.BeginSentenceWrite(word_stateIn) current_word = "" word_score = 0.0 word_state = stateIn # state for char LM ending with exactly the previous word for token in sentence.split(" "): stateOut = kenlm.State() word_score += model.BaseScore(stateIn, token, stateOut) * LOG10 stateIn = stateOut current_word += token if token == "|": if current_word in known_words: n_words += 1 ppl += word_score pdf_mass_words = set( compute_words_model_pdf_mass( word_model, known_words_original, word_stateIn ) ) if current_word not in pdf_mass_words: pdf_mass_words.add(current_word) ppl_lower += compute_word_logprob( model, word_state, current_word ) - compute_denominator(model, word_state, pdf_mass_words) else: unk_n_words += 1 word_stateOut = kenlm.State() word_model.BaseScore( word_stateIn, transform_asg_back(current_word), word_stateOut ) word_stateIn = word_stateOut current_word = "" word_score = 0.0 word_state = stateOut stateOut = kenlm.State() n_words += 1 ppl += model.BaseScore(stateIn, EOS, stateOut) * LOG10 pdf_mass_words = set( compute_words_model_pdf_mass(word_model, known_words_original, word_stateIn) ) if EOS not in pdf_mass_words: pdf_mass_words.add(EOS) ppl_lower += compute_word_logprob(model, word_state, EOS) - compute_denominator( model, word_state, pdf_mass_words ) if n % 10 == 0: print( "Evaluated", n, "sentences among", n_logging, "upper limit perplexity", numpy.exp(-ppl / n_words), "lower limit perplexity", numpy.exp(-ppl_lower / n_words), "number of words", n_words, flush=True, ) print("Final loss", 
ppl, "loss lower", ppl_lower) print("Upper limit on perplexity:", numpy.exp(-ppl / n_words)) print("Lower limit on perplexity:", numpy.exp(-ppl_lower / n_words)) print("Total number of words:", n_words, "unknown words:", unk_n_words) if __name__ == "__main__": parser = argparse.ArgumentParser( description="Upper and lower limit on word perplexity for kenlm char model" ) parser.add_argument( "--vocab_file", help="vocabulary of known words, use file " "from --limit_vocab_file during word kenLM training.", ) parser.add_argument( "--text", help="file to evaluate, prepared for char lm training" ) parser.add_argument("--char_model", help="kenlm char model") parser.add_argument("--word_model", help="kenlm word model") args = parser.parse_args() print("Evaluate file {}".format(args.text)) known_words, known_words_original = prepare_vocabs(args.vocab_file) with open(args.text, "r") as f: sentences = [line.strip() for line in f] word_model = kenlm.LanguageModel(args.word_model) char_model = kenlm.LanguageModel(args.char_model) compute_ppl_lower_limit( char_model, word_model, sentences, known_words, known_words_original )