# recipes/lexicon_free/utilities/compute_upper_ppl_convlm.py
"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.
This source code is licensed under the MIT-style license found in the
LICENSE file in the root directory of this source tree.
----------
Compute upper limit on word perplexity for convlm models
Command (for word): python3 compute_upper_ppl_convlm.py --model [...] \
    --dict [...] --text [...] --model_type word --dataset_type ls
Command (for char, 14B): python3 compute_upper_ppl_convlm.py --model [...] \
    --dict [...] --word_dict [...] --text [...] \
    --model_type char14B --dataset_type ls
Command (for char, 20B): python3 compute_upper_ppl_convlm.py --model [...] \
    --dict [...] --word_dict [...] --text [...] \
    --model_type char20B --dataset_type ls
Replace [...] with appropriate paths
"""
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import numpy
import torch
from convlm_utils import (
EOSIDX,
UNKIDX,
build_token_index_correspondence,
decodeInputText,
load_char_model_14B,
load_char_model_20B,
load_word_model,
)
from fairseq.data import Dictionary
from utils import transform_asg
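

# A character-level convlm scores individual letters; an upper bound on word
# perplexity then follows by summing the letter log-probabilities inside each
# word (including its trailing "|" boundary token) and exponentiating the
# negative mean over the word count, e.g. "cat|" contributes
# lp(c) + lp(a) + lp(t) + lp(|) as one word.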
def compute_ppl_upper_limit_char_convlm(
model,
input_charlm,
charLM_indices_token_dict,
charLM_token_indices_dict,
known_words,
):
sum_logp = 0
n_words = 0
sum_logp_known = 0
n_known_words = 0
sum_logp_unknown = 0
n_unknown_words = 0
n_letters = 0
for sentence in input_charlm:
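        # Prepend EOS as a begin-of-sentence token: logprobs[i] then scores
        # sentence[i], and logprobs[-1] scores the closing EOS.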
x = torch.LongTensor([EOSIDX] + sentence).reshape(1, len(sentence) + 1).cuda()
with torch.no_grad():
y = model.forward(x)[0]
logprobs = torch.nn.functional.log_softmax(y, 2).detach().cpu().numpy()[0]
current_word = ""
word_ppl = 0.0
for index, token_id in enumerate(sentence):
n_letters += 1
current_word += charLM_indices_token_dict[token_id]
word_ppl += logprobs[index, token_id]
if charLM_indices_token_dict[token_id] == "|":
if current_word in known_words:
sum_logp_known += word_ppl
n_known_words += 1
else:
sum_logp_unknown += word_ppl
n_unknown_words += 1
current_word = ""
word_ppl = 0
sum_logp += numpy.sum(logprobs[numpy.arange(len(sentence)), sentence])
n_words += numpy.sum(numpy.array(sentence) == charLM_token_indices_dict["|"])
        # Score the closing EOS: logprobs[-1] predicts EOS after the final letter.
sum_logp += logprobs[-1, EOSIDX]
n_words += 1
sum_logp_known += logprobs[-1, EOSIDX]
n_known_words += 1
n_letters += 1
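    # Perplexity is the exponential of the negative mean log-likelihood.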
loss_letter = -(sum_logp + sum_logp_unknown) / n_letters
ppl_word_no_unk = numpy.exp(-sum_logp_known / n_known_words)
ppl_word_unk = numpy.exp(-sum_logp_unknown / n_unknown_words)
    assert n_known_words + n_unknown_words == n_words, "Error in word counting"
assert numpy.allclose(sum_logp, sum_logp_known + sum_logp_unknown), "Error in loss"
ppl_word = numpy.exp(-sum_logp / n_words)
print(
"Letter loss: {}, letter perplexity: {}".format(
loss_letter, numpy.exp(loss_letter)
)
)
print("Upper word perplexity for all words: {}".format(ppl_word))
print("Upper word perplexity for unknown words: {}".format(ppl_word_unk))
print(
"(Reported in the paper) "
"Upper word perplexity for known words: {}".format(ppl_word_no_unk)
)
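

# For the word-level convlm the computation is direct: the adaptive softmax
# yields one log-probability per word, and words mapped to <unk> are tracked
# separately so that a known-words-only perplexity can be reported.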
def compute_ppl_upper_limit_word_convlm(model, input_wordlm):
sum_logp_known = 0
n_known_words = 0
sum_logp_unknown = 0
n_unknown_words = 0
for sentence in input_wordlm:
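        # As in the character case, prepend EOS so position i predicts sentence[i].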
x = torch.LongTensor([EOSIDX] + sentence).reshape(1, len(sentence) + 1).cuda()
with torch.no_grad():
y = model.forward(x)[0]
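        # The word model ends in an adaptive softmax, so normalized
        # log-probabilities come from it rather than from log_softmax.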
logprobs = (
model.adaptive_softmax.get_log_prob(y, None).detach().cpu().numpy()[0]
)
for index, token_id in enumerate(sentence):
if token_id != UNKIDX:
sum_logp_known += logprobs[index, token_id]
n_known_words += 1
else:
sum_logp_unknown += logprobs[index, token_id]
n_unknown_words += 1
        # Score the closing EOS; it counts as one more known word.
sum_logp_known += logprobs[-1, EOSIDX]
n_known_words += 1
ppl_word_no_unk = numpy.exp(-sum_logp_known / n_known_words)
ppl_word_unk = numpy.exp(-sum_logp_unknown / n_unknown_words)
ppl_word = numpy.exp(
-(sum_logp_known + sum_logp_unknown) / (n_known_words + n_unknown_words)
)
print("Word perplexity for all words: {}".format(ppl_word))
print("Word perplexity for unknown words: {}".format(ppl_word_unk))
print(
"(Reported in the paper) "
"Word perplexity for known words: {}".format(ppl_word_no_unk)
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Upper limit on word perplexity for convlm models"
)
parser.add_argument("--model", help="path to convlm model")
parser.add_argument("--dict", help="path to convlm dict file in data")
    parser.add_argument(
        "--text", help="text file to evaluate, in the format the model expects"
    )
parser.add_argument("--model_type", help='"word" or "char14B" or "char20B"')
parser.add_argument("--dataset_type", help='"ls" or "wsj"', default="ls")
    parser.add_argument(
        "--word_dict",
        help="path to convlm word dict file in data (ignored for word model eval)",
        default=None,
    )
args = parser.parse_args()
print("Evaluate file {}".format(args.text))
token_indices_dict, indices_token_dict = build_token_index_correspondence(args.dict)
with open(args.text, "r") as f:
sentences = [line.strip() for line in f]
input_data = decodeInputText(sentences, token_indices_dict)
fairseq_dict = Dictionary.load(args.dict)
if args.model_type == "word":
model = load_word_model(args.model, fairseq_dict, args.dataset_type)
compute_ppl_upper_limit_word_convlm(model, input_data)
else:
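        # Known words are re-encoded with transform_asg and given a trailing "|"
        # so they match the words assembled from character tokens above.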
with open(args.word_dict, "r") as f:
known_words = set(
[transform_asg(line.strip().split(" ")[0]) + "|" for line in f]
)
if "14B" in args.model_type:
model = load_char_model_14B(args.model, fairseq_dict, args.dataset_type)
else:
model = load_char_model_20B(args.model, fairseq_dict, args.dataset_type)
compute_ppl_upper_limit_char_convlm(
model, input_data, indices_token_dict, token_indices_dict, known_words
)