def predict_fn()

in models/model-a/src/code/inference.py

predict_fn takes a raw sentence and a model_dict holding the trained model and its token-to-index dictionary, and returns the predicted class label (1-based).


import logging

import torch
from torchtext.data.utils import ngrams_iterator

logger = logging.getLogger(__name__)
# _tokenizer and _ngrams are module-level objects defined elsewhere in inference.py.


def predict_fn(sentence, model_dict):
    logger.info('predict_fn: Predicting for {}.'.format(sentence))

    # model_dict bundles the trained model and the token-to-index dictionary.
    model = model_dict['model']
    dictionary = model_dict['dictionary']

    with torch.no_grad():
        # Map the sentence's tokens (and their n-grams) to vocabulary indices.
        sentence_tensor = torch.tensor([dictionary[token]
                                        for token in ngrams_iterator(_tokenizer(sentence), _ngrams)])
        # The model takes the flat token tensor plus an offsets tensor; a single sentence starts at offset 0.
        output = model(sentence_tensor, torch.tensor([0]))
        # argmax gives a zero-based class index; labels are one-based.
        label = output.argmax(1).item() + 1
        logger.info('predict_fn: Prediction result is {}.'.format(label))
        return label
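
Below is a minimal local usage sketch. The ToyTextClassifier, the toy dictionary, and the sample sentence are illustrative stand-ins that are not part of this repository; they only show the shape of inputs predict_fn expects, assuming an EmbeddingBag-style model that takes a flat token tensor and an offsets tensor.

import torch.nn as nn
from torchtext.data.utils import get_tokenizer

class ToyTextClassifier(nn.Module):
    # Hypothetical EmbeddingBag classifier matching the (tokens, offsets) call signature.
    def __init__(self, vocab_size, embed_dim, num_classes):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, text, offsets):
        return self.fc(self.embedding(text, offsets))

# Stand-ins for the module-level objects predict_fn relies on.
_tokenizer = get_tokenizer('basic_english')
_ngrams = 1  # unigrams only, so every token appears in the toy dictionary below

# Toy vocabulary; the real dictionary is built during training.
dictionary = {'a': 0, 'great': 1, 'day': 2, 'for': 3, 'the': 4, 'home': 5, 'team': 6}

model_dict = {
    'model': ToyTextClassifier(vocab_size=len(dictionary), embed_dim=8, num_classes=4),
    'dictionary': dictionary,
}

print(predict_fn('A great day for the home team', model_dict))  # a label in 1..4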