recipes/lexicon_free/utilities/compute_upper_ppl_convlm.py (144 lines of code) (raw):

"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.

This source code is licensed under the MIT-style license found in the
LICENSE file in the root directory of this source tree.

----------

Compute upper limit on word perplexity for convlm models

Command (for word) : python3 compute_upper_ppl_convlm.py --model [...] \
                        --dict [...] --text [...] --model_type word --dataset_type ls
Command (for char 14B) : python3 compute_upper_ppl_convlm.py --model [...] \
                        --dict [...] --word_dict [...] --text [...] \
                        --model_type char14B --dataset_type ls
Command (for char 20B) : python3 compute_upper_ppl_convlm.py --model [...] \
                        --dict [...] --word_dict [...] --text [...] \
                        --model_type char20B --dataset_type ls

Replace [...] with appropriate paths
"""

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse

import numpy
import torch
from convlm_utils import (
    EOSIDX,
    UNKIDX,
    build_token_index_correspondence,
    decodeInputText,
    load_char_model_14B,
    load_char_model_20B,
    load_word_model,
)
from fairseq.data import Dictionary
from utils import transform_asg


def _mean_neg_logp(sum_logp, count):
    """Return the average negative log-probability, or NaN if ``count`` is 0.

    Returning NaN instead of dividing keeps the script from crashing with a
    ZeroDivisionError when a category is empty (e.g. the evaluated text
    contains no unknown words).
    """
    if count == 0:
        return float("nan")
    return -sum_logp / count


def _perplexity(sum_logp, count):
    """Return ``exp`` of the average negative log-probability (NaN if empty)."""
    return numpy.exp(_mean_neg_logp(sum_logp, count))


def compute_ppl_upper_limit_char_convlm(
    model,
    input_charlm,
    charLM_indices_token_dict,
    charLM_token_indices_dict,
    known_words,
):
    """Print letter loss and upper word perplexities for a char-level convlm.

    Each sentence is scored letter by letter; letter log-probabilities are
    summed per word, where a word ends at the "|" token. A word is counted as
    known when its spelling (including the trailing "|") is in ``known_words``.
    The end-of-sentence prediction is counted as one extra known word.

    Args:
        model: convlm model; ``model.forward(x)[0]`` is passed through
            ``log_softmax`` over dimension 2. The input is moved to CUDA.
        input_charlm: iterable of sentences, each a list of token indices.
        charLM_indices_token_dict: mapping token index -> token string.
        charLM_token_indices_dict: mapping token string -> token index; must
            contain "|".
        known_words: container of known word spellings, each ending with "|".

    Returns:
        None. Results are printed to stdout. Perplexities of empty categories
        are printed as ``nan``.

    Raises:
        AssertionError: if known + unknown word counts (or log-probabilities)
            do not add up to the totals, e.g. when a sentence does not end
            with the "|" token.
    """
    sum_logp = 0.0
    n_words = 0
    sum_logp_known = 0.0
    n_known_words = 0
    sum_logp_unknown = 0.0
    n_unknown_words = 0
    n_letters = 0

    for sentence in input_charlm:
        # The input is the sentence prefixed with EOS, so output position i
        # scores sentence[i] and the last position scores the closing EOS.
        x = torch.LongTensor([EOSIDX] + sentence).reshape(1, len(sentence) + 1).cuda()
        with torch.no_grad():
            y = model.forward(x)[0]
            logprobs = torch.nn.functional.log_softmax(y, 2).detach().cpu().numpy()[0]

        # Gather the log-probability of every target letter once and promote
        # it to float64 so corpus-level sums do not depend on the model dtype.
        token_ids = numpy.asarray(sentence, dtype=numpy.int64)
        token_logp = logprobs[numpy.arange(len(sentence)), token_ids].astype(
            numpy.float64
        )
        eos_logp = float(logprobs[-1, EOSIDX])

        word_chars = []
        word_logp = 0.0
        for token_id, logp in zip(sentence, token_logp):
            token = charLM_indices_token_dict[token_id]
            word_chars.append(token)
            word_logp += float(logp)
            if token == "|":
                if "".join(word_chars) in known_words:
                    sum_logp_known += word_logp
                    n_known_words += 1
                else:
                    sum_logp_unknown += word_logp
                    n_unknown_words += 1
                word_chars = []
                word_logp = 0.0

        # Totals are computed independently of the per-word loop so the
        # assertions below can cross-check the bookkeeping.
        sum_logp += float(token_logp.sum())
        n_words += int(numpy.count_nonzero(token_ids == charLM_token_indices_dict["|"]))
        n_letters += len(sentence)

        # add eos
        sum_logp += eos_logp
        n_words += 1
        sum_logp_known += eos_logp
        n_known_words += 1
        n_letters += 1

    assert n_known_words + n_unknown_words == n_words, "Error in words counting"
    assert numpy.allclose(sum_logp, sum_logp_known + sum_logp_unknown), "Error in loss"

    # sum_logp already covers known and unknown words (asserted above), so it
    # alone is the total letter log-probability.
    loss_letter = _mean_neg_logp(sum_logp, n_letters)
    ppl_word_no_unk = _perplexity(sum_logp_known, n_known_words)
    ppl_word_unk = _perplexity(sum_logp_unknown, n_unknown_words)
    ppl_word = _perplexity(sum_logp, n_words)

    print(
        "Letter loss: {}, letter perplexity: {}".format(
            loss_letter, numpy.exp(loss_letter)
        )
    )
    print("Upper word perplexity for all words: {}".format(ppl_word))
    print("Upper word perplexity for unknown words: {}".format(ppl_word_unk))
    print(
        "(Reported in the paper) "
        "Upper word perplexity for known words: {}".format(ppl_word_no_unk)
    )


def compute_ppl_upper_limit_word_convlm(model, input_wordlm):
    """Print word perplexities (all / unknown / known) for a word-level convlm.

    A word is counted as unknown when its token index equals ``UNKIDX``. The
    end-of-sentence prediction is counted as one extra known word.

    Args:
        model: convlm model exposing ``forward`` and
            ``adaptive_softmax.get_log_prob``. The input is moved to CUDA.
        input_wordlm: iterable of sentences, each a list of token indices.

    Returns:
        None. Results are printed to stdout. Perplexities of empty categories
        are printed as ``nan``.
    """
    sum_logp_known = 0.0
    n_known_words = 0
    sum_logp_unknown = 0.0
    n_unknown_words = 0

    for sentence in input_wordlm:
        # The input is the sentence prefixed with EOS, so output position i
        # scores sentence[i] and the last position scores the closing EOS.
        x = torch.LongTensor([EOSIDX] + sentence).reshape(1, len(sentence) + 1).cuda()
        with torch.no_grad():
            y = model.forward(x)[0]
            logprobs = (
                model.adaptive_softmax.get_log_prob(y, None).detach().cpu().numpy()[0]
            )

        for index, token_id in enumerate(sentence):
            # float() keeps the running sums in double precision.
            logp = float(logprobs[index, token_id])
            if token_id != UNKIDX:
                sum_logp_known += logp
                n_known_words += 1
            else:
                sum_logp_unknown += logp
                n_unknown_words += 1

        # add eos
        sum_logp_known += float(logprobs[-1, EOSIDX])
        n_known_words += 1

    ppl_word_no_unk = _perplexity(sum_logp_known, n_known_words)
    ppl_word_unk = _perplexity(sum_logp_unknown, n_unknown_words)
    ppl_word = _perplexity(
        sum_logp_known + sum_logp_unknown, n_known_words + n_unknown_words
    )

    print("Word perplexity for all words: {}".format(ppl_word))
    print("Word perplexity for unknown words: {}".format(ppl_word_unk))
    print(
        "(Reported in the paper) "
        "Word perplexity for known words: {}".format(ppl_word_no_unk)
    )


def _build_arg_parser():
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description="Upper limit on word perplexity for convlm models"
    )
    parser.add_argument("--model", help="path to convlm model", required=True)
    parser.add_argument(
        "--dict", help="path to convlm dict file in data", required=True
    )
    parser.add_argument(
        "--text", help="file to evaluate, in necessary format for model", required=True
    )
    parser.add_argument(
        "--model_type", help='"word" or "char14B" or "char20B"', required=True
    )
    parser.add_argument("--dataset_type", help='"ls" or "wsj"', default="ls")
    parser.add_argument(
        "--word_dict",
        help="path to convlm word dict file in data (ignored for word model eval)",
        default=None,
    )
    return parser


def main():
    """Parse arguments, load the model and data, and print the perplexities."""
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.model_type != "word" and args.word_dict is None:
        # Fail with a usage message instead of a TypeError from open(None).
        parser.error("--word_dict is required for char models")

    print("Evaluate file {}".format(args.text))

    token_indices_dict, indices_token_dict = build_token_index_correspondence(args.dict)

    with open(args.text, "r") as f:
        sentences = [line.strip() for line in f]

    input_data = decodeInputText(sentences, token_indices_dict)
    fairseq_dict = Dictionary.load(args.dict)

    if args.model_type == "word":
        model = load_word_model(args.model, fairseq_dict, args.dataset_type)
        compute_ppl_upper_limit_word_convlm(model, input_data)
    else:
        # Known words are the first column of the word dictionary, transformed
        # the same way as the evaluated text and terminated with "|".
        with open(args.word_dict, "r") as f:
            known_words = {
                transform_asg(line.strip().split(" ")[0]) + "|" for line in f
            }
        # Any model type containing "14B" selects the 14B loader; everything
        # else falls back to the 20B loader.
        if "14B" in args.model_type:
            model = load_char_model_14B(args.model, fairseq_dict, args.dataset_type)
        else:
            model = load_char_model_20B(args.model, fairseq_dict, args.dataset_type)
        compute_ppl_upper_limit_char_convlm(
            model, input_data, indices_token_dict, token_indices_dict, known_words
        )


if __name__ == "__main__":
    main()