def predict_fn()

in sagemaker/18_inferentia_inference/code/inference.py


import torch


def predict_fn(data, model_tokenizer_model_config):
    # unpack the model, tokenizer, and model config
    model, tokenizer, model_config = model_tokenizer_model_config

    # tokenize the inputs; fall back to the raw payload if there is no "inputs" key
    inputs = data.pop("inputs", data)
    # Neuron-traced models expect fixed input shapes, so every request is
    # padded/truncated to the sequence length the model was traced with
    embeddings = tokenizer(
        inputs,
        return_tensors="pt",
        max_length=model_config.traced_sequence_length,
        padding="max_length",
        truncation=True,
    )
    # convert to tuple for neuron model
    neuron_inputs = tuple(embeddings.values())

    # run prediction
    with torch.no_grad():
        predictions = model(*neuron_inputs)[0]
        scores = torch.nn.functional.softmax(predictions, dim=1)

    # return a list of dictionaries, which is JSON serializable
    return [
        {"label": model_config.id2label[item.argmax().item()], "score": item.max().item()}
        for item in scores
    ]
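
For context, here is a minimal sketch of the companion model_fn that produces the (model, tokenizer, model_config) tuple SageMaker hands to predict_fn, followed by a local invocation. The neuron_model.pt filename, the torch.neuron import, and the traced_sequence_length attribute on the config are assumptions about how the traced artifact is packaged, not a verbatim copy of the example's loader.

import os

import torch
import torch.neuron  # registers the Neuron backend; only available on Inferentia hosts
from transformers import AutoConfig, AutoTokenizer


def model_fn(model_dir):
    # load the tokenizer and config saved next to the traced model (assumed layout)
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model_config = AutoConfig.from_pretrained(model_dir)
    # load the TorchScript model traced for Neuron; the filename is an assumption
    model = torch.jit.load(os.path.join(model_dir, "neuron_model.pt"))
    return model, tokenizer, model_config


# SageMaker calls model_fn once at startup, then predict_fn once per request
model_tokenizer_model_config = model_fn("/opt/ml/model")
print(predict_fn({"inputs": "I love using Inferentia!"}, model_tokenizer_model_config))
# -> e.g. [{'label': 'POSITIVE', 'score': 0.99}]

Returning the three objects as a single tuple keeps the handler interface simple: predict_fn receives whatever model_fn returned as its second argument, so the two functions only have to agree on the tuple's order.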