def get_fraud_prediction()

in source/lambda/model-invocation/index.py [0:0]


# Cached 'sagemaker-runtime' client. Building a boto3 client is comparatively
# expensive, and a warm Lambda execution environment can reuse one across
# invocations instead of recreating it on every call.
_sagemaker_runtime_client = None


def _get_sagemaker_runtime():
    """Return the shared 'sagemaker-runtime' boto3 client, creating it on first use."""
    global _sagemaker_runtime_client
    if _sagemaker_runtime_client is None:
        _sagemaker_runtime_client = boto3.client('sagemaker-runtime')
    return _sagemaker_runtime_client


def get_fraud_prediction(data, threshold=0.5):
    """Score ``data`` with the "<SOLUTION_PREFIX>-xgb" SageMaker endpoint and classify it.

    Args:
        data: Request body sent to the endpoint with ContentType ``text/csv``
            (presumably one CSV-encoded feature row -- confirm against callers).
        threshold: Scores greater than or equal to this value are labeled 1;
            scores below it are labeled 0.

    Returns:
        dict with ``pred_proba`` (the endpoint response body decoded as JSON)
        and ``prediction`` (0 or 1).
    """
    sagemaker_endpoint_name = "{}-xgb".format(SOLUTION_PREFIX)
    response = _get_sagemaker_runtime().invoke_endpoint(
        EndpointName=sagemaker_endpoint_name, ContentType='text/csv', Body=data)
    # NOTE(review): assumes the endpoint replies with a single scalar score;
    # a multi-value CSV reply such as "0.1,0.2" would not parse as JSON here.
    pred_proba = json.loads(response['Body'].read().decode())
    prediction = 0 if pred_proba < threshold else 1

    # Lazy %-style arguments: the message is only formatted if INFO is enabled.
    logger.info("classification pred_proba: %s, prediction: %s", pred_proba, prediction)

    return {"pred_proba": pred_proba, "prediction": prediction}