in src/sagemaker/serve/builder/jumpstart_builder.py [0:0]
def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
    """Wrap the JumpStart model's ``deploy`` to honor ServeBuilder modes.

    Handles a customer override of ``mode`` at deploy time (re-preparing
    model resources for the new mode), builds the matching local-mode
    predictor when deploying to a local container, and otherwise strips
    the builder-only kwargs and delegates to the original ``deploy``,
    attaching the schema builder's serializer/deserializer to the
    returned predictor.

    Args:
        *args: Positional arguments forwarded to the original ``deploy``.
        **kwargs: Keyword arguments forwarded to the original ``deploy``.
            ``mode`` and ``role`` are consumed here rather than forwarded;
            ``endpoint_logging`` defaults to ``True`` when absent.

    Returns:
        Type[PredictorBase]: A local-mode predictor for LOCAL_CONTAINER
        deployments, otherwise the predictor from the original deploy.

    Raises:
        ValueError: If an unsupported mode or model server is requested.
    """
    env = {}
    if "mode" in kwargs and kwargs.get("mode") != self.mode:
        overwrite_mode = kwargs.get("mode")
        # mode overwritten by customer during model.deploy()
        logger.warning(
            "Deploying in %s Mode, overriding existing configurations set for %s mode",
            overwrite_mode,
            self.mode,
        )
        if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
            self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
            # NOTE(review): at most one prepared_for_* flag is ever set by the
            # LOCAL_CONTAINER branch below, so this or-chain is effectively
            # always True — confirm whether `and` ("nothing prepared yet")
            # was intended. Kept as-is to preserve behavior.
            if (
                not hasattr(self, "prepared_for_djl")
                or not hasattr(self, "prepared_for_tgi")
                or not hasattr(self, "prepared_for_mms")
            ):
                if not _is_optimized(self.pysdk_model):
                    self.pysdk_model.model_data, env = self._prepare_for_mode()
        elif overwrite_mode == Mode.LOCAL_CONTAINER:
            self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
            # Re-prepare resources for whichever server this model was built
            # for; exactly one branch runs (first flag still missing wins).
            if not hasattr(self, "prepared_for_djl"):
                (
                    self.existing_properties,
                    self.js_model_config,
                    self.prepared_for_djl,
                ) = prepare_djl_js_resources(
                    model_path=self.model_path,
                    js_id=self.model,
                    dependencies=self.dependencies,
                    model_data=self.pysdk_model.model_data,
                )
            elif not hasattr(self, "prepared_for_tgi"):
                self.js_model_config, self.prepared_for_tgi = prepare_tgi_js_resources(
                    model_path=self.model_path,
                    js_id=self.model,
                    dependencies=self.dependencies,
                    model_data=self.pysdk_model.model_data,
                )
            elif not hasattr(self, "prepared_for_mms"):
                self.js_model_config, self.prepared_for_mms = prepare_mms_js_resources(
                    model_path=self.model_path,
                    js_id=self.model,
                    dependencies=self.dependencies,
                    model_data=self.pysdk_model.model_data,
                )
            self._prepare_for_mode()
        else:
            raise ValueError("Mode %s is not supported!" % overwrite_mode)

    self.pysdk_model.env.update(env)
    serializer = self.schema_builder.input_serializer
    deserializer = self.schema_builder._output_deserializer

    if self.mode == Mode.LOCAL_CONTAINER:
        if self.model_server == ModelServer.DJL_SERVING:
            predictor = DjlLocalModePredictor(
                self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
            )
        elif self.model_server == ModelServer.TGI:
            predictor = TgiLocalModePredictor(
                self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
            )
        elif self.model_server == ModelServer.MMS:
            predictor = TransformersLocalModePredictor(
                self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
            )
        else:
            # Fix: previously an unsupported model server fell through here
            # with `predictor` unbound, surfacing as an opaque
            # UnboundLocalError at create_server(); fail fast instead.
            raise ValueError(
                "Model server %s is not supported in local container mode!"
                % self.model_server
            )

        # Measure RAM attributable to loading the model; floored at 0 since
        # other processes may free memory while the server starts.
        ram_usage_before = _get_ram_usage_mb()
        self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
            self.image_uri,
            600,  # presumably a startup timeout in seconds — confirm against create_server
            None,
            predictor,
            self.pysdk_model.env,
            jumpstart=True,
        )
        ram_usage_after = _get_ram_usage_mb()
        self.ram_usage_model_load = max(ram_usage_after - ram_usage_before, 0)

        return predictor

    # SAGEMAKER_ENDPOINT path: default endpoint logging on and strip the
    # builder-only kwargs the pysdk `deploy` does not accept.
    if "endpoint_logging" not in kwargs:
        kwargs["endpoint_logging"] = True
    if hasattr(self, "nb_instance_type"):
        # Prefer the notebook-detected instance type when one was captured.
        kwargs.update({"instance_type": self.nb_instance_type})

    if "mode" in kwargs:
        del kwargs["mode"]
    if "role" in kwargs:
        self.pysdk_model.role = kwargs.get("role")
        del kwargs["role"]

    predictor = self._original_deploy(*args, **kwargs)

    predictor.serializer = serializer
    predictor.deserializer = deserializer
    return predictor