def lambda_handler()

in lambda/api/lambda_register.py [0:0]


def lambda_handler(event, context):
    """Register or delete a SageMaker endpoint in response to a state-change event.

    Expects an event whose ``source`` is ``aws.sagemaker`` and whose
    ``detail-type`` is ``SageMaker Endpoint State Change``. The endpoint is
    only acted on when all of the following hold:

    * its name starts with ``ENDPOINT_PREFIX``;
    * its ``ab-testing:enabled`` tag equals ``"true"`` (case-insensitive);
    * its ``sagemaker:deployment-stage`` tag equals ``STAGE_NAME``.

    Otherwise a 304 (Not Modified) response is returned.

    Args:
        event: The incoming event payload.
        context: The Lambda context object (not used by this handler).

    Returns:
        A dict with ``statusCode`` and ``body`` keys.

        * For 304 responses, ``body`` is a plain-text reason.
        * For ``DELETING`` / ``IN_SERVICE`` statuses, ``body`` is the
          JSON-encoded result from ``handle_delete`` / ``handle_register``,
          paired with the status code they return.
        * For any other endpoint status, the status code is 400 and ``body``
          is a JSON-encoded ``{"message": ...}``.

    Raises:
        Exception: If the event is not a SageMaker Endpoint State Change.
            Also raised, carrying the boto3 error message and chained to
            the original, when a ``ClientError`` occurs. Any other exception
            is logged and re-raised unchanged.
    """
    try:
        logger.debug(json.dumps(event))

        if not (
            event.get("source") == "aws.sagemaker"
            and event.get("detail-type") == "SageMaker Endpoint State Change"
        ):
            raise Exception(
                "Expect CloudWatch Event for SageMaker Endpoint State Change"
            )

        detail = event["detail"]
        endpoint_name = detail["EndpointName"]
        # An endpoint with no tags cannot have opted in to A/B testing, so
        # treat missing/null tags as "not enabled" rather than failing.
        endpoint_tags = detail.get("Tags") or {}

        # If this endpoint does not match the prefix or is not enabled,
        # return Not Modified (304).
        endpoint_enabled = endpoint_tags.get("ab-testing:enabled", "").lower() == "true"
        if not (endpoint_name.startswith(ENDPOINT_PREFIX) and endpoint_enabled):
            error_message = (
                f"Endpoint: {endpoint_name} not enabled for prefix: {ENDPOINT_PREFIX}"
            )
            logger.warning(error_message)
            return {"statusCode": 304, "body": error_message}

        # If the API stage name doesn't match the deployment stage name,
        # return Not Modified (304).
        deployment_stage = endpoint_tags.get("sagemaker:deployment-stage")
        if deployment_stage != STAGE_NAME:
            error_message = f"Endpoint: {endpoint_name} deployment stage: {deployment_stage} not equal to API stage: {STAGE_NAME}"
            logger.warning(error_message)
            return {"statusCode": 304, "body": error_message}

        # Delete or register the endpoint depending on the status change.
        endpoint_status = detail["EndpointStatus"]
        if endpoint_status == "DELETING":
            logger.info(f"Deleting Endpoint: {endpoint_name}")
            result, status_code = handle_delete(endpoint_name)
        elif endpoint_status == "IN_SERVICE":
            # Fall back to default strategy parameters when the optional
            # ab-testing:* tags are absent.
            strategy = endpoint_tags.get("ab-testing:strategy", "ThompsonSampling")
            epsilon = float(endpoint_tags.get("ab-testing:epsilon", 0.1))
            warmup = int(endpoint_tags.get("ab-testing:warmup", 0))
            logger.info(
                f"Registering Endpoint: {endpoint_name} with strategy: {strategy}, epsilon: {epsilon}, warmup: {warmup}"
            )
            result, status_code = handle_register(
                endpoint_name, strategy, epsilon, warmup
            )
        else:
            error_message = (
                f"Endpoint: {endpoint_name} Status: {endpoint_status} not supported."
            )
            logger.warning(error_message)
            result = {"message": error_message}
            status_code = 400

        # Log the outcome and return it.
        logger.debug(json.dumps(result))
        return {"statusCode": status_code, "body": json.dumps(result)}
    except ClientError as e:
        logger.error(e)
        # Surface the boto3-specific error message. Fall back to str(e) if
        # the response is missing it, so the original failure isn't masked
        # by a KeyError.
        error_message = e.response.get("Error", {}).get("Message", str(e))
        logger.error(error_message)
        raise Exception(error_message) from e
    except Exception as e:
        logger.error(e)
        raise