in src/sagemaker/serve/save_retrive/version_1_0_0/save/save_handler.py [0:0]
def save(self) -> Type[Model]:
    """Save the model and its metadata locally, upload them to S3 and build a model object.

    Steps:
        1. Ensure ``self.save_path`` exists (created with parents if missing).
        2. Detect the framework, framework version and Python version from
           ``self.model`` or, when that is falsy, from the model returned by
           ``self.inference_spec.load(self.model_loader_path)``.
        3. Select the framework-specific save handler. This happens before any
           dependency/optimization capture, so an unsupported framework fails
           without writing requirements or capturing metadata.
        4. Capture dependencies into ``self.requirements_path`` and optimization
           metadata for ``self.model``.
        5. Have the handler write the model and metadata, upload
           ``self.save_path`` to ``self.s3_path`` and return the handler's
           PySDK model.

    Side effects:
        Sets ``self.framework``, ``self.framework_version``, ``self.py_version``
        and ``self.optimizer_metadata``.

    Returns:
        The value of ``handler.get_pysdk_model(s3_model_url, self.role_arn,
        self.sagemaker_session)``.
        NOTE(review): annotated as ``Type[Model]``, but presumably a model
        instance rather than a class — confirm.

    Raises:
        ValueError: If the detected framework is neither ``"pytorch"`` nor
            ``"xgboost"``.
    """
    logger.info("Saving model to %s", self.save_path)
    save_dir = Path(self.save_path)
    if not save_dir.exists():
        save_dir.mkdir(parents=True, exist_ok=True)

    # Detect on the in-memory model if we have one; otherwise load it through
    # the inference spec so detection still has a concrete model object.
    model_for_detection = (
        self.model if self.model else self.inference_spec.load(self.model_loader_path)
    )
    inferred = detect_framework_and_its_versions(model_for_detection)
    self.framework = inferred[0][0]
    self.framework_version = inferred[0][1]
    self.py_version = inferred[1]

    # Both handlers share one constructor signature, so a mapping plus a single
    # call site replaces duplicated if/elif construction. Validating here (before
    # the capture steps below) avoids side effects for unsupported frameworks.
    handler_classes = {
        "pytorch": PyTorchHandler,
        "xgboost": XGBoostHandler,
    }
    handler_cls = handler_classes.get(self.framework)
    if handler_cls is None:
        raise ValueError(f"Unknown framework type {self.framework}")

    capture_dependencies(self.requirements_path)
    # NOTE(review): when self.model is falsy, the model loaded from the inference
    # spec above is used for detection only; optimization metadata is still
    # captured from self.model. Confirm this is intended.
    self.optimizer_metadata = capture_optimization_metadata(self.model, self.framework)

    handler = handler_cls(
        VERION,  # module-level constant defined elsewhere (spelling as in the file)
        self.py_version,
        self.framework,
        self.framework_version,
        self.model,
        self.model_path,
        self.requirements_path,
        self.schema_builder,
        self.schema_path,
        self.schema_format,
        self.task,
        self.optimizer_metadata,
        self.inference_spec,
        self.inference_spec_path,
        self.inference_spec_format,
        self.metadata_path,
    )

    # Persist the model artifact and its metadata locally.
    handler.save_model()
    handler.save_metadata()

    # Upload the local save directory, then build the model object from its S3 URL.
    s3_model_url = upload_to_s3(self.s3_path, self.save_path, self.sagemaker_session)
    return handler.get_pysdk_model(s3_model_url, self.role_arn, self.sagemaker_session)