in src/sagemaker/workflow/_utils.py [0:0]
def arguments(self) -> RequestType:
    """The arguments dict that are used to call `create_model_package`.

    If ``self.step_args`` is set, it is used directly as the request dict.
    Otherwise the request is assembled from this step's attributes via
    ``get_model_package_args`` and ``get_create_model_package_request``. When
    ``self.container_def_list`` is None, the model data and image URI are
    first resolved from ``self.estimator``: either from the compiled model
    registered under ``self.compile_model_family``, or from a model built with
    ``self.estimator.create_model``.

    Before returning, ``CertifyForMarketplace`` and ``Description`` are removed
    from the request (each removal is logged as a warning), and the request is
    passed through ``trim_request_dict`` for ``ModelPackageName`` together with
    the pipeline config.

    Side effects: may set ``self.model_data``, ``self.image_uri`` and
    ``self.estimator._current_job_name``. When ``self.step_args`` is used, the
    key removals are applied to that same dict object.

    Returns:
        RequestType: the request dict used to register the model package.
    """
    from sagemaker.workflow.utilities import _pipeline_config

    model_name = self.name

    if self.step_args:
        # NOTE(review): this aliases self.step_args, so the pops below mutate it
        # in place and repeated calls only warn once. Kept as-is to preserve the
        # existing behavior.
        request_dict = self.step_args
    else:
        if self.container_def_list is None:
            if self.compile_model_family:
                model = self.estimator._compiled_models[self.compile_model_family]
                self.model_data = model.model_data
            else:
                # create_model wants the estimator to have a model_data attribute...
                self.estimator._current_job_name = model_name

                # Placeholder output path used only while create_model runs; the
                # model's real data is assigned from self.model_data below. The
                # original path is restored in ``finally`` so a failing
                # create_model cannot leave the estimator pointing at "/tmp".
                output_path = self.estimator.output_path
                self.estimator.output_path = "/tmp"
                try:
                    # create the model, but custom funky framework stuff going on in some places
                    if self.image_uri:
                        model = self.estimator.create_model(
                            image_uri=self.image_uri, **self.kwargs
                        )
                    else:
                        model = self.estimator.create_model(**self.kwargs)
                        self.image_uri = model.image_uri
                finally:
                    self.estimator.output_path = output_path

                # If neither the caller nor create_model supplied an image URI,
                # look up the framework's inference image for this region.
                if self.image_uri is None:
                    region_name = self.estimator.sagemaker_session.boto_session.region_name
                    self.image_uri = image_uris.retrieve(
                        model._framework_name,
                        region_name,
                        version=model.framework_version,
                        py_version=getattr(model, "py_version", None),
                        instance_type=self.kwargs.get(
                            "instance_type", self.estimator.instance_type
                        ),
                        accelerator_type=self.kwargs.get("accelerator_type"),
                        image_scope="inference",
                    )
                model.name = model_name
                model.model_data = self.model_data

        model_package_args = get_model_package_args(
            content_types=self.content_types,
            response_types=self.response_types,
            inference_instances=self.inference_instances,
            transform_instances=self.transform_instances,
            model_package_group_name=self.model_package_group_name,
            model_data=self.model_data,
            image_uri=self.image_uri,
            model_metrics=self.model_metrics,
            drift_check_baselines=self.drift_check_baselines,
            metadata_properties=self.metadata_properties,
            approval_status=self.approval_status,
            description=self.description,
            tags=self.tags,
            container_def_list=self.container_def_list,
            customer_metadata_properties=self.customer_metadata_properties,
            domain=self.domain,
            sample_payload_url=self.sample_payload_url,
            task=self.task,
            skip_model_validation=self.skip_model_validation,
            source_uri=self.source_uri,
            model_card=self.model_card,
            model_life_cycle=self.model_life_cycle,
        )
        request_dict = get_create_model_package_request(**model_package_args)

    # these are not available in the workflow service and will cause rejection
    warn_msg_template = (
        "Popping out '%s' from the pipeline definition "
        "since it will be overridden in pipeline execution time."
    )
    for key in ("CertifyForMarketplace", "Description"):
        if key in request_dict:
            request_dict.pop(key)
            logger.warning(warn_msg_template, key)

    # Continue to pop job name if not explicitly opted-in via config
    request_dict = trim_request_dict(request_dict, "ModelPackageName", _pipeline_config)
    return request_dict