in lambda/api/lambda_register.py [0:0]
def lambda_handler(event, context):
    """Register or delete an A/B-testing endpoint from a SageMaker state-change event.

    The event must have ``source == "aws.sagemaker"`` and
    ``detail-type == "SageMaker Endpoint State Change"``. The endpoint is only
    acted on when all of these hold:

    - its name starts with ``ENDPOINT_PREFIX``;
    - its ``ab-testing:enabled`` tag equals ``"true"`` (case-insensitive);
    - its ``sagemaker:deployment-stage`` tag equals ``STAGE_NAME``.

    Args:
        event: The incoming event dict.
        context: The Lambda context object (not used by this handler).

    Returns:
        A dict with ``statusCode`` and ``body``:

        - 304 and a plain-text message when the endpoint is skipped
          (prefix, enabled flag, or deployment stage does not match);
        - the status code and JSON-encoded result returned by
          ``handle_delete`` (status ``DELETING``) or ``handle_register``
          (status ``IN_SERVICE``);
        - 400 and a JSON ``{"message": ...}`` body for any other status.

    Raises:
        Exception: If the event is not a SageMaker endpoint state change.
            Also raised for a boto3 ``ClientError``; in that case the message
            comes from the error response and the original error is chained
            as ``__cause__``.
        Other exceptions (e.g. ``ValueError`` from a non-numeric
        ``ab-testing:epsilon`` / ``ab-testing:warmup`` tag) are logged and
        re-raised unchanged.
    """
    try:
        logger.debug(json.dumps(event))
        if not (
            event.get("source") == "aws.sagemaker"
            and event.get("detail-type") == "SageMaker Endpoint State Change"
        ):
            raise Exception(
                "Expect CloudWatch Event for SageMaker Endpoint State Change"
            )

        detail = event["detail"]
        endpoint_name = detail["EndpointName"]
        # Defensive: treat a missing or null Tags field as "no tags", so the
        # endpoint is skipped with 304 below instead of failing with KeyError.
        endpoint_tags = detail.get("Tags") or {}

        # If this endpoint does not match the prefix or is not enabled, return Not Modified (304)
        endpoint_enabled = endpoint_tags.get("ab-testing:enabled", "").lower() == "true"
        if not (endpoint_name.startswith(ENDPOINT_PREFIX) and endpoint_enabled):
            error_message = (
                f"Endpoint: {endpoint_name} not enabled for prefix: {ENDPOINT_PREFIX}"
            )
            logger.warning(error_message)
            return {"statusCode": 304, "body": error_message}

        # If the API stage name doesn't match the deployment stage name, return Not Modified (304)
        deployment_stage = endpoint_tags.get("sagemaker:deployment-stage")
        if deployment_stage != STAGE_NAME:
            error_message = f"Endpoint: {endpoint_name} deployment stage: {deployment_stage} not equal to API stage: {STAGE_NAME}"
            logger.warning(error_message)
            return {"statusCode": 304, "body": error_message}

        # Delete or register the endpoint depending on the status change
        endpoint_status = detail["EndpointStatus"]
        if endpoint_status == "DELETING":
            logger.info(f"Deleting Endpoint: {endpoint_name}")
            result, status_code = handle_delete(endpoint_name)
        elif endpoint_status == "IN_SERVICE":
            # Use defaults when the endpoint is enabled without further tags.
            # Non-numeric epsilon/warmup tag values raise ValueError here.
            strategy = endpoint_tags.get("ab-testing:strategy", "ThompsonSampling")
            epsilon = float(endpoint_tags.get("ab-testing:epsilon", 0.1))
            warmup = int(endpoint_tags.get("ab-testing:warmup", 0))
            logger.info(
                f"Registering Endpoint: {endpoint_name} with strategy: {strategy}, epsilon: {epsilon}, warmup: {warmup}"
            )
            result, status_code = handle_register(
                endpoint_name, strategy, epsilon, warmup
            )
        else:
            error_message = (
                f"Endpoint: {endpoint_name} Status: {endpoint_status} not supported."
            )
            logger.warning(error_message)
            result = {"message": error_message}
            status_code = 400

        # Log the result and return it
        logger.debug(json.dumps(result))
        return {"statusCode": status_code, "body": json.dumps(result)}
    except ClientError as e:
        logger.error(e)
        # Surface the boto3-specific error message. Fall back to str(e) so a
        # malformed response cannot raise KeyError and mask the original error.
        error_message = e.response.get("Error", {}).get("Message", str(e))
        logger.error(error_message)
        raise Exception(error_message) from e
    except Exception as e:
        logger.error(e)
        raise