def predict_fn()

in src/serve.py [0:0]


def predict_fn(input, model_artifacts):
    """Run inference on ``input`` and return the top-1 label with its probability.

    Args:
        input: Raw request payload; passed unchanged to ``preprocess`` together
            with the preprocessor taken from ``model_artifacts``.
        model_artifacts: 3-tuple ``(preprocessor, model, label_mapper)``.
            ``label_mapper.reverse_map`` is called with an integer class index
            and its return value is used as the result key.

    Returns:
        A list containing one single-entry dict ``{label: probability}`` per
        row of the model output, where ``probability`` is a Python float.
    """
    preprocessor, model, label_mapper = model_artifacts

    # Pre-process
    input_tensor = preprocess(input, preprocessor)

    # Copy input to the GPU if available.
    device = get_device()
    input_tensor = input_tensor.to(device=device)

    # Invoke
    model.eval()
    with torch.no_grad():
        # NOTE(review): assumes the model returns a tuple/sequence whose first
        # element is a 2-D (batch, num_classes) logits tensor — confirm.
        logits = model(input_tensor)[0]
        # Explicit dim: nn.Softmax() without dim is deprecated and warns on
        # every call. dim=1 matches the class axis reduced by torch.max below.
        probabilities = torch.nn.functional.softmax(logits, dim=1)

    # Highest probability and its class index for each row.
    top_probs, top_indices = torch.max(probabilities, dim=1)

    # One device-to-host transfer per tensor instead of a sync per element
    # via .item(); tolist() yields the same Python ints/floats.
    labels = [label_mapper.reverse_map(idx) for idx in top_indices.tolist()]
    return [{label: prob} for label, prob in zip(labels, top_probs.tolist())]