def compute_ppl_upper_limit_word_convlm(model, input_wordlm)

in recipes/lexicon_free/utilities/compute_upper_ppl_convlm.py

Computes the upper-limit word perplexity of a convolutional LM over a tokenized corpus, reporting separate perplexities for known words, unknown (<unk>) words, and all words, where perplexity is exp(-mean log-probability).


import numpy
import torch

# EOSIDX and UNKIDX are the indices of the end-of-sentence and <unk>
# tokens in the LM dictionary, defined at module level in the recipe.
def compute_ppl_upper_limit_word_convlm(model, input_wordlm):
    # Running log-probability sums and word counts, kept separately
    # for known words and unknown (<unk>) words.
    sum_logp_known = 0
    n_known_words = 0
    sum_logp_unknown = 0
    n_unknown_words = 0

    for sentence in input_wordlm:
        # Prepend EOS so the first word is predicted from a sentence
        # boundary; input shape is (1, len(sentence) + 1).
        x = torch.LongTensor([EOSIDX] + sentence).reshape(1, len(sentence) + 1).cuda()
        with torch.no_grad():
            y = model.forward(x)[0]
            # Per-position log-probabilities over the full vocabulary.
            logprobs = (
                model.adaptive_softmax.get_log_prob(y, None).detach().cpu().numpy()[0]
            )

        # logprobs[index] scores the token following position index in x,
        # i.e. sentence[index].
        for index, token_id in enumerate(sentence):
            if token_id != UNKIDX:
                sum_logp_known += logprobs[index, token_id]
                n_known_words += 1
            else:
                sum_logp_unknown += logprobs[index, token_id]
                n_unknown_words += 1

        # Score the EOS token predicted after the last word; it counts
        # toward the known-word totals.
        sum_logp_known += logprobs[-1, EOSIDX]
        n_known_words += 1

    # Perplexity is exp(-mean log-probability). Note: this assumes the
    # corpus contains at least one <unk>; otherwise ppl_word_unk divides
    # by zero.
    ppl_word_no_unk = numpy.exp(-sum_logp_known / n_known_words)
    ppl_word_unk = numpy.exp(-sum_logp_unknown / n_unknown_words)
    ppl_word = numpy.exp(
        -(sum_logp_known + sum_logp_unknown) / (n_known_words + n_unknown_words)
    )

    print("Word perplexity for all words: {}".format(ppl_word))
    print("Word perplexity for unknown words: {}".format(ppl_word_unk))
    print(
        "(Reported in the paper) "
        "Word perplexity for known words: {}".format(ppl_word_no_unk)
    )
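
Usage sketch: the snippet below shows how this function might be driven end to end. It is a minimal sketch, not the recipe's actual entry point; load_trained_convlm, build_word2idx, the token indices, and the file names are hypothetical stand-ins, and the word-to-index mapping must match the dictionary the model was trained with.

EOSIDX = 2  # assumed index of </s> in the LM dictionary
UNKIDX = 3  # assumed index of <unk> in the LM dictionary

def load_sentences(path, word2idx):
    # Tokenize each line on whitespace and map words to indices,
    # falling back to <unk> for out-of-vocabulary words.
    sentences = []
    with open(path) as f:
        for line in f:
            words = line.split()
            if words:
                sentences.append([word2idx.get(w, UNKIDX) for w in words])
    return sentences

model = load_trained_convlm("checkpoint_best.pt")  # hypothetical loader
model.cuda()
model.eval()

word2idx = build_word2idx("dict.txt")  # hypothetical dictionary reader
input_wordlm = load_sentences("dev.txt", word2idx)
compute_ppl_upper_limit_word_convlm(model, input_wordlm)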