# recipes/lexicon_free/utilities/compute_lower_ppl_convlm.py
"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the MIT-style license found in the
LICENSE file in the root directory of this source tree.
----------
Compute upper and lower limits on word perplexity for convlm models
Command: python3 compute_lower_ppl_convlm.py --model [...] --dict [...] \
         --word_model [...] --word_dict [...] \
         --text [...] --model_type char14B --dataset_type ls
Replace [...] with appropriate paths
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import numpy
import torch
from convlm_utils import (
EOS,
EOSIDX,
UNKIDX,
build_token_index_correspondence,
decodeInputText,
load_char_model_14B,
load_char_model_20B,
load_word_model,
)
from fairseq.data import Dictionary
from utils import prepare_vocabs_convlm, transform_asg, transform_asg_back


# Reusing previous states is, for some reason, slower than re-evaluating the full sentence.
# TODO: speed up with batching and by reusing the previous state
def compute_word_logprob(model, current_state, target_word, token_index_dict):
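    """
    Log-probability of `target_word` given the character history `current_state`.
    For EOS a single forward pass scores the end-of-sentence token; otherwise the
    per-character log-probabilities of the word's tokens (including the trailing
    "|" separator) are summed under the character-level model.
    """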
if target_word == EOS:
x = torch.LongTensor(current_state).reshape(1, len(current_state)).cuda()
with torch.no_grad():
y = model.forward(x)[0]
logprobs = torch.nn.functional.log_softmax(y, 2).detach().cpu().numpy()[0]
return logprobs[-1, EOSIDX]
else:
additional_state = [token_index_dict[token] for token in list(target_word)]
with torch.no_grad():
x = (
torch.LongTensor(current_state + additional_state[:-1])
.reshape(1, len(current_state) + len(additional_state) - 1)
.cuda()
)
y = model.forward(x)[0]
logprobs = torch.nn.functional.log_softmax(y, 2).detach().cpu().numpy()[0]
return numpy.sum(
logprobs[-len(additional_state) :][
numpy.arange(len(additional_state)), additional_state
]
)


def compute_denominator(model, current_state, words, token_index_dict):
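    """
    Log of the total probability mass the character model assigns to the candidate
    set `words` after `current_state`, computed as a numerically stable log-sum-exp
    of the per-word log-probabilities. Used to renormalize the lower-limit score.
    """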
preds = [
compute_word_logprob(model, current_state, word, token_index_dict)
for word in words
]
assert len(preds) != 0, "Invalid denominator"
max_pred = numpy.max(preds)
return max_pred + numpy.log(numpy.sum(numpy.exp(preds - max_pred)))


def compute_words_model_pdf_mass(
word_probs, current_state_position, known_words, known_words_decoded
):
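    """
    Candidate words for the denominator: known words, sorted by the word-level
    model's probability at `current_state_position`, kept while their cumulative
    mass stays within 95% of the known-word mass. Words are returned in the
    spelling the character model expects (ASG transform plus trailing "|";
    EOS is kept as-is).
    """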
probs = word_probs[current_state_position, known_words_decoded]
indices = numpy.argsort(-probs)
    # The unk word is not added to this pdf mass: its probability can be huge,
    # so the 95% cutoff is taken over the known-word pdf only.
probs_sum = numpy.sum(probs)
top = numpy.where(numpy.cumsum(probs[indices]) > 0.95 * probs_sum)[0][0]
return [
transform_asg(w) + "|" if w != EOS else w for w in known_words[indices[:top]]
]


def compute_ppl_lower_limit(
model,
word_model,
sentences,
known_words,
known_words_original,
known_words_original_decoded,
indices_token_dict,
token_indices_dict,
):
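    """
    Walk each sentence character by character. At every word boundary ("|") the
    finished word is scored with the character model (upper-limit term) and that
    score is renormalized over the word-LM's 95% pdf-mass candidates (lower-limit
    term). Out-of-vocabulary words are counted but excluded from both sums.
    """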
n_words = 0
unk_n_words = 0
ppl = 0.0
ppl_lower = 0.0
n_logging = len(sentences)
for n, sentence in enumerate(sentences):
current_state = [EOSIDX]
current_word = ""
current_word_state_position = 0
addition_state = []
wordLM_sentence = (
"".join([indices_token_dict[idx] for idx in sentence])
.replace("|", " ")
.strip()
)
wordLM_sentence = [
transform_asg_back(word) for word in wordLM_sentence.split(" ")
]
wordLM_sentence_decoded = [EOSIDX] + [
UNKIDX if word not in word_indices_dict else word_indices_dict[word]
for word in wordLM_sentence
]
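        # Score the whole sentence with the word-level model once; words_probs
        # holds the word pdf at every position and is reused at each word boundary.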
with torch.no_grad():
x = (
torch.LongTensor(wordLM_sentence_decoded)
.reshape(1, len(wordLM_sentence_decoded))
.cuda()
)
y = word_model.forward(x)[0]
words_probs = numpy.exp(
word_model.adaptive_softmax.get_log_prob(y, None)
.detach()
.cpu()
.numpy()[0]
)
for token_idx in sentence:
current_word += indices_token_dict[token_idx]
addition_state.append(token_idx)
if indices_token_dict[token_idx] == "|":
if current_word in known_words:
n_words += 1
pdf_mass_words = set(
compute_words_model_pdf_mass(
words_probs,
current_word_state_position,
known_words_original,
known_words_original_decoded,
)
)
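                    # Always include the true word in the candidate set so the
                    # renormalized probability is well-defined.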
if current_word not in pdf_mass_words:
pdf_mass_words.add(current_word)
word_score = compute_word_logprob(
model, current_state, current_word, token_indices_dict
)
ppl += word_score
ppl_lower += word_score - compute_denominator(
model, current_state, pdf_mass_words, token_indices_dict
)
else:
unk_n_words += 1
current_word = ""
current_state += addition_state
addition_state = []
current_word_state_position += 1
# process eos
word_score = compute_word_logprob(model, current_state, EOS, token_indices_dict)
n_words += 1
ppl += word_score
pdf_mass_words = set(
compute_words_model_pdf_mass(
words_probs,
current_word_state_position,
known_words_original,
known_words_original_decoded,
)
)
if EOS not in pdf_mass_words:
pdf_mass_words.add(EOS)
ppl_lower += word_score - compute_denominator(
model, current_state, pdf_mass_words, token_indices_dict
)
if n % 10 == 0:
print(
"Evaluated",
n,
"sentences among",
n_logging,
"upper limit perplexity",
numpy.exp(-ppl / n_words),
"lower limit perplexity",
numpy.exp(-ppl_lower / n_words),
"number of words",
n_words,
flush=True,
)
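    # ppl and ppl_lower accumulate log-probabilities, so exp(-sum / n_words) gives
    # the per-word perplexity (upper and lower limits respectively).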
print("Final loss", ppl, "loss lower", ppl_lower)
print("Upper limit on perplexity:", numpy.exp(-ppl / n_words))
print("Lower limit on perplexity:", numpy.exp(-ppl_lower / n_words))
print("Total number of words:", n_words, "unknown words:", unk_n_words)


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Upper and lower limits on word perplexity for convlm models"
)
parser.add_argument("--model", help="path to convlm model")
parser.add_argument("--word_model", help="path to convlm model")
parser.add_argument("--dict", help="path to convlm dict file in data")
    parser.add_argument(
        "--word_dict", help="path to convlm word dict file in data"
    )
    parser.add_argument(
        "--text", help="file to evaluate, in the format expected by the model"
    )
parser.add_argument("--model_type", help='"char14B" or "char20B"')
parser.add_argument("--dataset_type", help='"ls" or "wsj"', default="ls")
args = parser.parse_args()
print("Evaluate file {}".format(args.text))
token_indices_dict, indices_token_dict = build_token_index_correspondence(args.dict)
word_indices_dict, indices_word_dict = build_token_index_correspondence(
args.word_dict
)
known_words, known_words_original = prepare_vocabs_convlm(args.word_dict)
known_words_original_decoded = numpy.array(
[
UNKIDX if w not in word_indices_dict else word_indices_dict[w]
for w in known_words_original
]
)
with open(args.text, "r") as f:
sentences = [line.strip() for line in f]
input_data = decodeInputText(sentences, token_indices_dict)
fairseq_dict = Dictionary.load(args.dict)
word_fairseq_dict = Dictionary.load(args.word_dict)
word_model = load_word_model(args.word_model, word_fairseq_dict, args.dataset_type)
if "14B" in args.model_type:
char_model = load_char_model_14B(args.model, fairseq_dict, args.dataset_type)
else:
char_model = load_char_model_20B(args.model, fairseq_dict, args.dataset_type)
compute_ppl_lower_limit(
char_model,
word_model,
input_data,
known_words,
known_words_original,
known_words_original_decoded,
indices_token_dict,
token_indices_dict,
)