recipes/lexicon_free/utilities/compute_lower_ppl_convlm.py (212 lines of code) (raw):

"""
Copyright (c) Facebook, Inc. and its affiliates.
All rights reserved.

This source code is licensed under the MIT-style license found in the
LICENSE file in the root directory of this source tree.

----------

Compute upper and lower limits on word perplexity for convlm models

Command : python3 compute_lower_ppl_convlm.py --model [...] --dict [...] \
          --word_model [...] --word_dict [...] \
          --text [...] --model_type char14B --dataset_type ls

Replace [...] with appropriate paths
"""

from __future__ import absolute_import, division, print_function, unicode_literals

import argparse

import numpy
import torch
from convlm_utils import (
    EOS,
    EOSIDX,
    UNKIDX,
    build_token_index_correspondence,
    decodeInputText,
    load_char_model_14B,
    load_char_model_20B,
    load_word_model,
)
from fairseq.data import Dictionary
from utils import prepare_vocabs_convlm, transform_asg, transform_asg_back


# Reusing previous states for some reason is slower than reevaluating the
# full sentence.
# TODO speedup with batching and using previous state
def compute_word_logprob(model, current_state, target_word, token_index_dict):
    """Return the char-LM log-probability of ``target_word`` after ``current_state``.

    Args:
        model: char-level convlm. ``model.forward(x)[0]`` is treated as logits
            whose vocabulary axis is dim 2 (presumably (batch, time, vocab) --
            TODO confirm against the model loader).
        current_state: list of char token indices used as context; callers in
            this file always start it with ``EOSIDX``.
        target_word: ``EOS`` or a string of char-LM tokens (words end in "|").
        token_index_dict: char token -> index mapping.

    Returns:
        For ``EOS``: the log-probability of ``EOSIDX`` at the last context
        position. Otherwise: the sum of the log-probabilities of every token of
        ``target_word``, computed with a single forward pass over
        ``current_state + word_tokens[:-1]``.
    """
    if target_word == EOS:
        x = torch.LongTensor(current_state).reshape(1, len(current_state)).cuda()
        with torch.no_grad():
            y = model.forward(x)[0]
            logprobs = torch.nn.functional.log_softmax(y, 2).detach().cpu().numpy()[0]
        return logprobs[-1, EOSIDX]

    additional_state = [token_index_dict[token] for token in target_word]
    n_tokens = len(additional_state)
    with torch.no_grad():
        x = (
            torch.LongTensor(current_state + additional_state[:-1])
            .reshape(1, len(current_state) + n_tokens - 1)
            .cuda()
        )
        y = model.forward(x)[0]
        logprobs = torch.nn.functional.log_softmax(y, 2).detach().cpu().numpy()[0]
    # The indexing assumes output position i predicts input token i + 1, so the
    # last ``n_tokens`` rows score the word's tokens in order.
    return numpy.sum(
        logprobs[-n_tokens:][numpy.arange(n_tokens), additional_state]
    )


def compute_denominator(model, current_state, words, token_index_dict):
    """Return log(sum(exp(score))) of the char-LM scores of ``words``.

    Each score comes from ``compute_word_logprob`` with the same context; the
    max is subtracted before exponentiating for numerical stability.

    Raises:
        AssertionError: if ``words`` is empty.
    """
    preds = numpy.asarray(
        [
            compute_word_logprob(model, current_state, word, token_index_dict)
            for word in words
        ]
    )
    assert len(preds) != 0, "Invalid denominator"
    max_pred = numpy.max(preds)
    return max_pred + numpy.log(numpy.sum(numpy.exp(preds - max_pred)))


def compute_words_model_pdf_mass(
    word_probs,
    current_state_position,
    known_words,
    known_words_decoded,
    pdf_mass=0.95,
):
    """Return the most probable known words at one word-LM position.

    Known words are sorted by word-LM probability (descending) and kept while
    their cumulative probability has not yet exceeded ``pdf_mass`` times the
    total known-word probability at ``current_state_position``.

    Args:
        word_probs: 2-D array of word-LM probabilities indexed as
            ``[position, word_index]``.
        current_state_position: row of ``word_probs`` to use.
        known_words: word strings; must support numpy fancy indexing with an
            index array (e.g. a numpy array).
        known_words_decoded: word-LM indices aligned with ``known_words``.
        pdf_mass: fraction of the known-word mass to cover (default 0.95).

    Returns:
        List of words converted to char-LM format (``transform_asg(w) + "|"``);
        ``EOS`` is returned unchanged. The list may be empty when the single
        most probable word already exceeds the threshold.
    """
    probs = word_probs[current_state_position, known_words_decoded]
    indices = numpy.argsort(-probs)
    # The unk word is not added to this pdf mass (its probability is sometimes
    # huge), so the percentile is taken from the known-word pdf only.
    threshold = pdf_mass * numpy.sum(probs)
    top = numpy.where(numpy.cumsum(probs[indices]) > threshold)[0][0]
    # NOTE(review): ``indices[:top]`` excludes the word whose cumulative mass
    # first crosses the threshold; kept as-is to preserve reported numbers.
    return [
        transform_asg(w) + "|" if w != EOS else w for w in known_words[indices[:top]]
    ]


def _word_model_probs(word_model, sentence, indices_token_dict, word_indices_dict):
    """Run the word LM on ``sentence`` and return its per-position probabilities.

    The char-token sentence is turned back into words ("|" separates words),
    mapped to word-LM indices (``UNKIDX`` for out-of-vocabulary words) and
    prefixed with ``EOSIDX``. Row i of the result presumably holds the
    distribution of the word following the first i words -- TODO confirm.
    """
    text = (
        "".join([indices_token_dict[idx] for idx in sentence])
        .replace("|", " ")
        .strip()
    )
    words = [transform_asg_back(word) for word in text.split(" ")]
    decoded = [EOSIDX] + [
        UNKIDX if word not in word_indices_dict else word_indices_dict[word]
        for word in words
    ]
    with torch.no_grad():
        x = torch.LongTensor(decoded).reshape(1, len(decoded)).cuda()
        y = word_model.forward(x)[0]
        log_probs = word_model.adaptive_softmax.get_log_prob(y, None)
        return numpy.exp(log_probs.detach().cpu().numpy()[0])


def _candidate_words(
    target_word,
    word_probs,
    position,
    known_words_original,
    known_words_original_decoded,
):
    """Return the word-LM top-mass candidates at ``position`` plus ``target_word``.

    The result is deduplicated and sorted so the denominator is always summed
    in the same order (set iteration order of strings varies between runs).
    """
    candidates = set(
        compute_words_model_pdf_mass(
            word_probs,
            position,
            known_words_original,
            known_words_original_decoded,
        )
    )
    candidates.add(target_word)
    return sorted(candidates)


def compute_ppl_lower_limit(
    model,
    word_model,
    sentences,
    known_words,
    known_words_original,
    known_words_original_decoded,
    indices_token_dict,
    token_indices_dict,
    word_indices_dict=None,
):
    """Compute, print and return upper and lower limits on word perplexity.

    For every known word (and the end of every sentence) the char-LM
    log-probability of the word is added to the upper-limit loss. The
    lower-limit loss adds the same score renormalized over a candidate set:
    the word LM's top-probability known words at that position plus the target
    itself. Progress is printed every 10 sentences.

    Args:
        model: char-level convlm.
        word_model: word-level convlm exposing ``adaptive_softmax.get_log_prob``.
        sentences: sized sequence of char token-index sequences; words are
            delimited by the "|" token.
        known_words: char-LM-formatted words (ending in "|") that are scored;
            any other word is counted as unknown and skipped.
        known_words_original: word-LM vocabulary strings (numpy array).
        known_words_original_decoded: word-LM indices of
            ``known_words_original``.
        indices_token_dict: char index -> token mapping.
        token_indices_dict: char token -> index mapping.
        word_indices_dict: word -> word-LM index mapping. Optional only for
            backward compatibility; when omitted, a module-level
            ``word_indices_dict`` global is used if one exists.

    Returns:
        ``(upper, lower)`` perplexity limits, or ``None`` if no word was scored.

    Raises:
        ValueError: if ``word_indices_dict`` is neither passed nor available
            as a module-level global.
    """
    if word_indices_dict is None:
        # Callers written against the earlier signature relied on a
        # module-level ``word_indices_dict`` global.
        word_indices_dict = globals().get("word_indices_dict")
        if word_indices_dict is None:
            raise ValueError("word_indices_dict must be provided")

    n_words = 0
    unk_n_words = 0
    ppl = 0.0
    ppl_lower = 0.0
    n_logging = len(sentences)
    for n, sentence in enumerate(sentences):
        current_state = [EOSIDX]
        current_word = ""
        current_word_state_position = 0
        addition_state = []
        words_probs = _word_model_probs(
            word_model, sentence, indices_token_dict, word_indices_dict
        )

        for token_idx in sentence:
            token = indices_token_dict[token_idx]
            current_word += token
            addition_state.append(token_idx)
            if token != "|":
                continue
            if current_word in known_words:
                n_words += 1
                candidates = _candidate_words(
                    current_word,
                    words_probs,
                    current_word_state_position,
                    known_words_original,
                    known_words_original_decoded,
                )
                word_score = compute_word_logprob(
                    model, current_state, current_word, token_indices_dict
                )
                ppl += word_score
                ppl_lower += word_score - compute_denominator(
                    model, current_state, candidates, token_indices_dict
                )
            else:
                unk_n_words += 1
            # The word is complete (known or not): extend the char-LM context
            # with its tokens and advance to the next word-LM position.
            current_word = ""
            current_state += addition_state
            addition_state = []
            current_word_state_position += 1

        # Score the end of the sentence as one more word.
        word_score = compute_word_logprob(model, current_state, EOS, token_indices_dict)
        n_words += 1
        ppl += word_score
        candidates = _candidate_words(
            EOS,
            words_probs,
            current_word_state_position,
            known_words_original,
            known_words_original_decoded,
        )
        ppl_lower += word_score - compute_denominator(
            model, current_state, candidates, token_indices_dict
        )
        if n % 10 == 0:
            print(
                "Evaluated",
                n,
                "sentences among",
                n_logging,
                "upper limit perplexity",
                numpy.exp(-ppl / n_words),
                "lower limit perplexity",
                numpy.exp(-ppl_lower / n_words),
                "number of words",
                n_words,
                flush=True,
            )

    if n_words == 0:
        # Only reachable with no sentences; avoids dividing by zero below.
        print("No sentences to evaluate")
        return None

    upper = numpy.exp(-ppl / n_words)
    lower = numpy.exp(-ppl_lower / n_words)
    print("Final loss", ppl, "loss lower", ppl_lower)
    print("Upper limit on perplexity:", upper)
    print("Lower limit on perplexity:", lower)
    print("Total number of words:", n_words, "unknown words:", unk_n_words)
    return upper, lower


def main():
    """Command-line entry point: load vocabularies and models, evaluate ``--text``."""
    parser = argparse.ArgumentParser(
        description="Upper and lower limits on word perplexity for convlm models"
    )
    parser.add_argument("--model", required=True, help="path to convlm model")
    parser.add_argument(
        "--word_model", required=True, help="path to word convlm model"
    )
    parser.add_argument(
        "--dict", required=True, help="path to convlm dict file in data"
    )
    parser.add_argument(
        "--word_dict", required=True, help="path to word convlm dict file in data"
    )
    parser.add_argument(
        "--text",
        required=True,
        help="file to evaluate, in necessary format for model",
    )
    # Any value containing "14B" selects the 14B char model; any other value
    # loads the 20B one.
    parser.add_argument(
        "--model_type", required=True, help='"char14B" or "char20B"'
    )
    parser.add_argument("--dataset_type", help='"ls" or "wsj"', default="ls")
    args = parser.parse_args()

    print("Evaluate file {}".format(args.text))

    token_indices_dict, indices_token_dict = build_token_index_correspondence(
        args.dict
    )
    word_indices_dict, _ = build_token_index_correspondence(args.word_dict)
    known_words, known_words_original = prepare_vocabs_convlm(args.word_dict)
    known_words_original_decoded = numpy.array(
        [
            UNKIDX if w not in word_indices_dict else word_indices_dict[w]
            for w in known_words_original
        ]
    )
    with open(args.text, "r") as f:
        sentences = [line.strip() for line in f]
    input_data = decodeInputText(sentences, token_indices_dict)

    fairseq_dict = Dictionary.load(args.dict)
    word_fairseq_dict = Dictionary.load(args.word_dict)
    word_model = load_word_model(args.word_model, word_fairseq_dict, args.dataset_type)
    if "14B" in args.model_type:
        char_model = load_char_model_14B(args.model, fairseq_dict, args.dataset_type)
    else:
        char_model = load_char_model_20B(args.model, fairseq_dict, args.dataset_type)

    compute_ppl_lower_limit(
        char_model,
        word_model,
        input_data,
        known_words,
        known_words_original,
        known_words_original_decoded,
        indices_token_dict,
        token_indices_dict,
        word_indices_dict=word_indices_dict,
    )


if __name__ == "__main__":
    main()