in src/sagemaker/workflow/_utils.py [0:0]
def arguments(self) -> RequestType:
    """The arguments dict that are used to call `create_model_package`.

    When no ``container_def_list`` is configured, ``model_data`` (and, if
    needed, ``image_uri``) are resolved from the estimator first:

    * if ``compile_model_family`` is set, the model cached in
      ``estimator._compiled_models`` under that family supplies
      ``model_data``;
    * otherwise a model is built via ``estimator.create_model`` and used to
      fill in an ``image_uri`` / ``model_data`` that was not provided; if the
      image URI is still unknown it is looked up with
      ``image_uris.retrieve(..., image_scope="inference")``.

    Side effects: may assign ``self.model_data`` and ``self.image_uri``, and
    sets ``self.estimator._current_job_name`` to this step's name. The
    estimator's ``output_path`` is temporarily replaced while the model is
    built and is always restored afterwards, even if model creation fails.

    Returns:
        The request dict produced by ``get_create_model_package_request``,
        with the ``CertifyForMarketplace`` and ``Description`` keys removed.
    """
    model_name = self.name
    if self.container_def_list is None:
        if self.compile_model_family:
            model = self.estimator._compiled_models[self.compile_model_family]
            self.model_data = model.model_data
        else:
            # create_model wants the estimator to have a model_data attribute...
            self.estimator._current_job_name = model_name

            # placeholder. replaced with model_data later
            output_path = self.estimator.output_path
            self.estimator.output_path = "/tmp"
            try:
                # create the model, but custom framework logic applies in some places
                if self.image_uri:
                    model = self.estimator.create_model(image_uri=self.image_uri, **self.kwargs)
                else:
                    model = self.estimator.create_model(**self.kwargs)
                    self.image_uri = model.image_uri

                if self.model_data is None:
                    self.model_data = model.model_data
            finally:
                # Restore the caller's output path even if create_model raised,
                # so the shared estimator is never left on the placeholder.
                self.estimator.output_path = output_path

            # Some framework models do not expose an image URI; resolve the
            # inference image from the framework name/version instead.
            if self.image_uri is None:
                region_name = self.estimator.sagemaker_session.boto_session.region_name
                self.image_uri = image_uris.retrieve(
                    model._framework_name,
                    region_name,
                    version=model.framework_version,
                    py_version=getattr(model, "py_version", None),
                    instance_type=self.kwargs.get("instance_type", self.estimator.instance_type),
                    accelerator_type=self.kwargs.get("accelerator_type"),
                    image_scope="inference",
                )
            model.name = model_name
            model.model_data = self.model_data

    model_package_args = get_model_package_args(
        content_types=self.content_types,
        response_types=self.response_types,
        inference_instances=self.inference_instances,
        transform_instances=self.transform_instances,
        model_package_group_name=self.model_package_group_name,
        model_data=self.model_data,
        image_uri=self.image_uri,
        model_metrics=self.model_metrics,
        drift_check_baselines=self.drift_check_baselines,
        metadata_properties=self.metadata_properties,
        approval_status=self.approval_status,
        description=self.description,
        tags=self.tags,
        container_def_list=self.container_def_list,
    )
    request_dict = get_create_model_package_request(**model_package_args)
    # these are not available in the workflow service and will cause rejection
    request_dict.pop("CertifyForMarketplace", None)
    request_dict.pop("Description", None)
    return request_dict