in source/lambda/model-invocation/index.py [0:0]
# Lazily created SageMaker runtime client. Kept at module scope so warm Lambda
# invocations reuse one client (and its HTTP connection pool) instead of
# constructing a fresh client on every call.
_SAGEMAKER_RUNTIME_CLIENT = None


def _get_sagemaker_runtime_client():
    """Return the shared ``sagemaker-runtime`` client, creating it on first use."""
    global _SAGEMAKER_RUNTIME_CLIENT
    if _SAGEMAKER_RUNTIME_CLIENT is None:
        _SAGEMAKER_RUNTIME_CLIENT = boto3.client('sagemaker-runtime')
    return _SAGEMAKER_RUNTIME_CLIENT


def get_fraud_prediction(data, threshold=0.5):
    """Score *data* on the ``<SOLUTION_PREFIX>-xgb`` endpoint and apply *threshold*.

    Args:
        data: Request body sent to the endpoint with ``ContentType='text/csv'``
            (presumably a single CSV row of features — confirm against callers).
        threshold: Score at or above which the result is classified as
            ``prediction == 1``; scores strictly below it give ``0``.

    Returns:
        dict with ``"pred_proba"`` (the endpoint response body, UTF-8 decoded
        and parsed with ``json.loads``) and ``"prediction"`` (0 or 1).

    Errors raised by ``invoke_endpoint`` or by ``json.loads`` on an
    unparseable response propagate to the caller unchanged.
    """
    endpoint_name = "{}-xgb".format(SOLUTION_PREFIX)
    response = _get_sagemaker_runtime_client().invoke_endpoint(
        EndpointName=endpoint_name,
        ContentType='text/csv',
        Body=data,
    )
    pred_proba = json.loads(response['Body'].read().decode())
    prediction = 0 if pred_proba < threshold else 1
    # Lazy %-style args: formatting is skipped when INFO is disabled; the
    # rendered message is identical to the previous str.format output.
    logger.info("classification pred_proba: %s, prediction: %s", pred_proba, prediction)
    return {"pred_proba": pred_proba, "prediction": prediction}