in src/sagemaker/workflow/check_job_config.py [0:0]
def _generate_model_monitor(self, mm_type: str) -> Optional[ModelMonitor]:
    """Generates a ModelMonitor object.

    Generates a ModelMonitor object with required config attributes for
    QualityCheckStep and ClarifyCheckStep.

    Args:
        mm_type (str): The subclass type of ModelMonitor object.
            A valid mm_type should be one of the following: "DefaultModelMonitor",
            "ModelQualityMonitor", "ModelBiasMonitor", "ModelExplainabilityMonitor"

    Returns:
        sagemaker.model_monitor.ModelMonitor or None if the mm_type is not valid
        (a warning listing the accepted types is logged in that case).
    """
    # Every supported monitor subclass is built from exactly the same job-config
    # attributes, so dispatch on the type name and keep a single constructor call
    # instead of four copies that must be kept in sync by hand.
    monitor_classes = {
        "DefaultModelMonitor": DefaultModelMonitor,
        "ModelQualityMonitor": ModelQualityMonitor,
        "ModelBiasMonitor": ModelBiasMonitor,
        "ModelExplainabilityMonitor": ModelExplainabilityMonitor,
    }

    monitor_class = monitor_classes.get(mm_type)
    if monitor_class is None:
        logging.warning(
            'Expected model monitor types: "DefaultModelMonitor", "ModelQualityMonitor", '
            '"ModelBiasMonitor", "ModelExplainabilityMonitor"'
        )
        return None

    # Config attributes are read only after the type has been validated, so an
    # unrecognized mm_type never touches instance state.
    return monitor_class(
        role=self.role,
        instance_count=self.instance_count,
        instance_type=self.instance_type,
        volume_size_in_gb=self.volume_size_in_gb,
        volume_kms_key=self.volume_kms_key,
        output_kms_key=self.output_kms_key,
        max_runtime_in_seconds=self.max_runtime_in_seconds,
        base_job_name=self.base_job_name,
        sagemaker_session=self.sagemaker_session,
        env=self.env,
        tags=self.tags,
        network_config=self.network_config,
    )