# word_probs()
#
# Excerpt from tseval/models/language_models.py

def word_probs(sentence, model, dictionary, device, log_probs=True, verbose=False):
    """Score each token of `sentence` under a fairseq language model.

    The sentence is tokenized by whitespace, mapped to indexes with
    `dictionary`, and fed to `model` with an EOS token prepended (fairseq
    language models condition on </s> as the start-of-sequence symbol).

    Args:
        sentence: Whitespace-tokenized input string.
        model: fairseq decoder model exposing `__call__(tokens, src_lengths)`
            and `get_normalized_probs`.
        dictionary: fairseq Dictionary used for token <-> index mapping.
        device: torch device the model lives on.
        log_probs: If True, return log-probabilities; otherwise probabilities.
        verbose: If True, additionally print the top-5 next-token predictions
            at every position. Does not affect the returned values.

    Returns:
        1-D numpy array of length len(sentence.split()) + 1 containing the
        (log-)probability of each target token, including the final EOS.
    """
    def prepend_eos(indexes):
        return torch.cat([torch.LongTensor([[dictionary.eos()]]).to(device), indexes], dim=1)

    def append_eos(indexes):
        return torch.cat([indexes, torch.LongTensor([[dictionary.eos()]]).to(device)], dim=1)

    # Forward pass
    indexes = torch.LongTensor([[dictionary.index(word) for word in sentence.split()]]).to(device)
    src_lengths = torch.LongTensor([indexes.shape[1] + 1]).to(device)  # src_lengths = [len(words) + 1]
    decoder_out = model(prepend_eos(indexes), src_lengths)

    # Compute sentence log-likelihood (probabilities at the gold target indexes)
    logprobs = model.get_normalized_probs(decoder_out,
                                          log_probs=log_probs,
                                          sample={'target': append_eos(indexes)}).squeeze()
    if verbose:
        # Print top predictions.
        # We must feed None as target to get_normalized_probs to obtain the
        # full-vocabulary distribution. BUG FIX: this previously reassigned
        # `logprobs`, so verbose=True could silently change the returned values
        # whenever get_normalized_probs is target-sensitive (e.g. adaptive
        # softmax). Keep the diagnostic probabilities in their own variable.
        full_logprobs = model.get_normalized_probs(decoder_out,
                                                   log_probs=log_probs,
                                                   sample={'target': None}).squeeze()
        for i in range(indexes.shape[1] + 1):
            print(i)
            sorted_values, sorted_indexes = full_logprobs[i].sort(descending=True)
            top_predictions = [f'{dictionary[sorted_indexes[j]]} ({sorted_values[j]:.1f})' for j in range(5)]
            # Context seen by the model when predicting position i (</s> + words so far)
            input_words = ' '.join([dictionary[idx] for idx in prepend_eos(indexes).squeeze()[:i+1]])
            print(input_words, '->', '; '.join(top_predictions))

    # Pick out, for each position, the probability assigned to the actual next token.
    return logprobs.gather(dim=1, index=append_eos(indexes).reshape(-1, 1)).squeeze().data.cpu().numpy()