recipes/lexicon_free/utilities/compute_lower_ppl_kenlm.py (133 lines of code) (raw):
"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the MIT-style license found in the
LICENSE file in the root directory of this source tree.
----------
Compute upper and lower limits on word perplexity for kenlm ngram models
Command : python3 compute_upper_ppl_kenlm.py --vocab_file [...] --text [...] \
--char_model [...] --word_model [...]
Replace [...] with appropriate paths
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import kenlm
import numpy
from utils import EOS, prepare_vocabs, transform_asg, transform_asg_back
LOG10 = numpy.log(10)
def compute_word_logprob(model, current_state, target_word):
word_pred = 0
if target_word == EOS:
stateOut = kenlm.State()
word_pred += model.BaseScore(current_state, str(target_word), stateOut) * LOG10
else:
stateIn = current_state
for token in list(target_word):
stateOut = kenlm.State()
word_pred += model.BaseScore(stateIn, token, stateOut) * LOG10
stateIn = stateOut
return word_pred
def compute_denominator(model, current_state, words):
preds = [compute_word_logprob(model, current_state, word) for word in words]
max_pred = numpy.max(preds)
return max_pred + numpy.log(numpy.sum(numpy.exp(preds - max_pred)))
def compute_words_model_pdf_mass(word_model, known_words, stateIn):
probs = []
for word in known_words:
stateOut = kenlm.State()
probs.append(
numpy.power(10, word_model.BaseScore(stateIn, str(word), stateOut))
)
probs_arr = numpy.array(probs)
indices = numpy.argsort(-probs_arr)
top = numpy.where(numpy.cumsum(probs_arr[indices]) > 0.95)[0][0]
return [
transform_asg(w) + "|" if w != EOS else w for w in known_words[indices[:top]]
]
def compute_ppl_lower_limit(
model, word_model, sentences, known_words, known_words_original
):
n_words = 0
unk_n_words = 0
ppl = 0.0
ppl_lower = 0.0
n_logging = len(sentences)
for n, sentence in enumerate(sentences):
stateIn = kenlm.State()
word_stateIn = kenlm.State()
model.BeginSentenceWrite(stateIn)
word_model.BeginSentenceWrite(word_stateIn)
current_word = ""
word_score = 0.0
word_state = stateIn # state for char LM ending with exactly the previous word
for token in sentence.split(" "):
stateOut = kenlm.State()
word_score += model.BaseScore(stateIn, token, stateOut) * LOG10
stateIn = stateOut
current_word += token
if token == "|":
if current_word in known_words:
n_words += 1
ppl += word_score
pdf_mass_words = set(
compute_words_model_pdf_mass(
word_model, known_words_original, word_stateIn
)
)
if current_word not in pdf_mass_words:
pdf_mass_words.add(current_word)
ppl_lower += compute_word_logprob(
model, word_state, current_word
) - compute_denominator(model, word_state, pdf_mass_words)
else:
unk_n_words += 1
word_stateOut = kenlm.State()
word_model.BaseScore(
word_stateIn, transform_asg_back(current_word), word_stateOut
)
word_stateIn = word_stateOut
current_word = ""
word_score = 0.0
word_state = stateOut
stateOut = kenlm.State()
n_words += 1
ppl += model.BaseScore(stateIn, EOS, stateOut) * LOG10
pdf_mass_words = set(
compute_words_model_pdf_mass(word_model, known_words_original, word_stateIn)
)
if EOS not in pdf_mass_words:
pdf_mass_words.add(EOS)
ppl_lower += compute_word_logprob(model, word_state, EOS) - compute_denominator(
model, word_state, pdf_mass_words
)
if n % 10 == 0:
print(
"Evaluated",
n,
"sentences among",
n_logging,
"upper limit perplexity",
numpy.exp(-ppl / n_words),
"lower limit perplexity",
numpy.exp(-ppl_lower / n_words),
"number of words",
n_words,
flush=True,
)
print("Final loss", ppl, "loss lower", ppl_lower)
print("Upper limit on perplexity:", numpy.exp(-ppl / n_words))
print("Lower limit on perplexity:", numpy.exp(-ppl_lower / n_words))
print("Total number of words:", n_words, "unknown words:", unk_n_words)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Upper and lower limit on word perplexity for kenlm char model"
)
parser.add_argument(
"--vocab_file",
help="vocabulary of known words, use file "
"from --limit_vocab_file during word kenLM training.",
)
parser.add_argument(
"--text", help="file to evaluate, prepared for char lm training"
)
parser.add_argument("--char_model", help="kenlm char model")
parser.add_argument("--word_model", help="kenlm word model")
args = parser.parse_args()
print("Evaluate file {}".format(args.text))
known_words, known_words_original = prepare_vocabs(args.vocab_file)
with open(args.text, "r") as f:
sentences = [line.strip() for line in f]
word_model = kenlm.LanguageModel(args.word_model)
char_model = kenlm.LanguageModel(args.char_model)
compute_ppl_lower_limit(
char_model, word_model, sentences, known_words, known_words_original
)