in source/lib/blueprints/byom/lambdas/create_baseline_job/main.py [0:0]
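# NOTE: this handler assumes the following are defined at module level in main.py:
# stdlib imports (os, json), a logger, boto3 clients (s3_client, sm_client),
# a sagemaker_session, the SolutionSageMakerBaselines helper, and the
# sagemaker.clarify config classes (DataConfig, BiasConfig, ModelConfig,
# ModelPredictedLabelConfig, SHAPConfig).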
def handler(event, context):
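    """Create a SageMaker baseline processing job for the configured monitoring type.

    The job configuration (monitoring type, baseline dataset and output S3 locations,
    instance settings, and the optional ModelBias/ModelExplainability configs) is read
    from the Lambda function's environment variables.
    """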
    # read the baseline job configuration from environment variables
assets_bucket = os.environ["ASSETS_BUCKET"]
monitoring_type = os.environ.get("MONITORING_TYPE")
baseline_job_name = os.environ["BASELINE_JOB_NAME"]
instance_type = os.environ.get("INSTANCE_TYPE", "ml.m5.large")
instance_count = int(os.environ.get("INSTANCE_COUNT", "1"))
max_runtime_seconds = os.environ.get("MAX_RUNTIME_SECONDS")
baseline_dataset_file_key = os.environ["BASELINE_DATA_LOCATION"]
baseline_data_s3_uri = f"s3://{assets_bucket}/{baseline_dataset_file_key}"
baseline_output_s3_uri = f"s3://{os.environ['BASELINE_JOB_OUTPUT_LOCATION']}"
endpoint_name = os.getenv("ENDPOINT_NAME")
    # used only for ModelBias/ModelExplainability baselines
# model_predicted_label_config is optional for regression
raw_model_predicted_label_config = os.getenv("MODEL_PREDICTED_LABEL_CONFIG")
model_predicted_label_config = (
json.loads(raw_model_predicted_label_config)
if raw_model_predicted_label_config
# set default for regression problem
else dict(label=None, probability=None, probability_threshold=None, label_headers=None)
)
# bias_config required for ModelBias
bias_config = json.loads(os.getenv("BIAS_CONFIG", "{}"))
# shap_config required for ModelExplainability
shap_config = json.loads(os.getenv("SHAP_CONFIG", "{}"))
    # if the SHAP baseline is a string, it is a file key (not an in-memory list of lists used to calculate the baseline);
# the baseline file is expected to be in the Assets Bucket
baseline = shap_config.get("baseline")
# add assets bucket if a file key is provided
if isinstance(baseline, str):
shap_config["baseline"] = f"s3://{assets_bucket}/{baseline}"
# use model scores if provided
model_scores = json.loads(os.getenv("MODEL_SCORES")) if os.getenv("MODEL_SCORES") else None
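    # MODEL_SCORES (if set) tells Clarify where to locate the predicted scores in the
    # model output (index or JSONPath); it is only used for ModelExplainability baselines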
logger.info(f"Creating {monitoring_type} baseline processing job {baseline_job_name} ...")
# get config file contents if the baseline to be created is ModelBias|ModelExplainability
# the config file should be uploaded to the Solution's Assets S3 bucket
# details on the contents for expected ModelBias|ModelExplainability config files are provided in the
# SolutionSageMakerBaselines.get_baseline_config_file function's docs
header = None
if monitoring_type in ["ModelBias", "ModelExplainability"]:
header = SolutionSageMakerBaselines.get_baseline_dataset_header(
bucket_name=assets_bucket, file_key=baseline_dataset_file_key, s3_client=s3_client
)
    # create a SolutionSageMakerBaselines instance
sagemaker_baseline = SolutionSageMakerBaselines(
        monitoring_type=monitoring_type,
instance_type=instance_type,
instance_count=instance_count,
instance_volume_size=int(os.environ.get("INSTANCE_VOLUME_SIZE", "30")),
role_arn=os.environ["ROLE_ARN"],
        baseline_job_name=baseline_job_name,
baseline_dataset=baseline_data_s3_uri,
output_s3_uri=baseline_output_s3_uri,
max_runtime_in_seconds=int(max_runtime_seconds) if max_runtime_seconds else None,
kms_key_arn=os.environ.get("KMS_KEY_ARN"),
problem_type=os.environ.get("PROBLEM_TYPE"),
ground_truth_attribute=os.environ.get("GROUND_TRUTH_ATTRIBUTE"),
inference_attribute=os.environ.get("INFERENCE_ATTRIBUTE"),
probability_attribute=os.environ.get("PROBABILITY_ATTRIBUTE"),
probability_threshold_attribute=os.environ.get("PROBABILITY_THRESHOLD_ATTRIBUTE"),
sagemaker_session=sagemaker_session,
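        # the Clarify-specific configs below are only created for ModelBias/ModelExplainability baselines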
data_config=DataConfig(
s3_data_input_path=baseline_data_s3_uri,
s3_output_path=baseline_output_s3_uri,
label=header[0], # the label is expected to be the first column in the baseline dataset
headers=header,
dataset_type="text/csv",
)
if monitoring_type in ["ModelBias", "ModelExplainability"]
else None,
bias_config=BiasConfig(**bias_config) if monitoring_type == "ModelBias" else None,
model_config=ModelConfig(
model_name=SolutionSageMakerBaselines.get_model_name(endpoint_name, sm_client),
instance_type=instance_type,
instance_count=instance_count,
accept_type="text/csv",
)
if monitoring_type in ["ModelBias", "ModelExplainability"]
else None,
model_predicted_label_config=ModelPredictedLabelConfig(**model_predicted_label_config)
if monitoring_type == "ModelBias"
else None,
explainability_config=SHAPConfig(**shap_config) if monitoring_type == "ModelExplainability" else None,
model_scores=model_scores if monitoring_type == "ModelExplainability" else None,
tags=[{"Key": "stack_name", "Value": os.environ["STACK_NAME"]}],
)
# create the SageMaker Baseline Job
baseline_job = sagemaker_baseline.create_baseline_job()
logger.info(f"Started {monitoring_type} baseline processing job. Job info: {baseline_job.describe()}")
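
# Illustrative (minimal) environment for a DataQuality baseline job; values are placeholders:
#   MONITORING_TYPE=DataQuality
#   ASSETS_BUCKET=<assets-bucket-name>
#   BASELINE_DATA_LOCATION=<baseline-dataset-file-key>.csv
#   BASELINE_JOB_OUTPUT_LOCATION=<bucket>/<output-prefix>
#   BASELINE_JOB_NAME=<baseline-job-name>
#   ROLE_ARN=<sagemaker-execution-role-arn>
#   STACK_NAME=<stack-name>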