in src/sagemaker/serve/builder/tgi_builder.py
def _tgi_model_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
    """Wrap ``deploy()`` for TGI models built through ModelBuilder.

    Applies mode overrides, environment configuration, and schema-derived
    (de)serializers before delegating to the original deploy path, then
    returns a predictor for the active mode.
    """
    timeout = kwargs.get("model_data_download_timeout")
    if timeout:
        self.pysdk_model.env.update({"MODEL_LOADING_TIMEOUT": str(timeout)})
if "mode" in kwargs and kwargs.get("mode") != self.mode:
overwrite_mode = kwargs.get("mode")
# mode overwritten by customer during model.deploy()
logger.warning(
"Deploying in %s Mode, overriding existing configurations set for %s mode",
overwrite_mode,
self.mode,
)
if overwrite_mode == Mode.SAGEMAKER_ENDPOINT:
self.mode = self.pysdk_model.mode = Mode.SAGEMAKER_ENDPOINT
elif overwrite_mode == Mode.LOCAL_CONTAINER:
self._prepare_for_mode()
self.mode = self.pysdk_model.mode = Mode.LOCAL_CONTAINER
else:
raise ValueError("Mode %s is not supported!" % overwrite_mode)
    serializer = self.schema_builder.input_serializer
    deserializer = self.schema_builder._output_deserializer
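    # IN_PROCESS mode: start the in-process server and return immediately;
    # nothing is deployed to SageMaker.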
    if self.mode == Mode.IN_PROCESS:
        predictor = InProcessModePredictor(
            self.modes[str(Mode.IN_PROCESS)], serializer, deserializer
        )
        self.modes[str(Mode.IN_PROCESS)].create_server(predictor)
        return predictor
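    # LOCAL_CONTAINER mode: launch the TGI container locally and record the
    # RAM delta attributable to loading the model.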
    if self.mode == Mode.LOCAL_CONTAINER:
        timeout = kwargs.get("model_data_download_timeout")

        predictor = TgiLocalModePredictor(
            self.modes[str(Mode.LOCAL_CONTAINER)], serializer, deserializer
        )

        ram_usage_before = _get_ram_usage_mb()
        self.modes[str(Mode.LOCAL_CONTAINER)].create_server(
            self.image_uri,
            timeout if timeout else 1800,
            None,
            predictor,
            self.pysdk_model.env,
            jumpstart=False,
        )
        ram_usage_after = _get_ram_usage_mb()
        self.ram_usage_model_load = max(ram_usage_after - ram_usage_before, 0)
        return predictor
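    # From here on we are deploying to a SageMaker endpoint; drop kwargs that
    # the underlying deploy() call does not accept.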
if "mode" in kwargs:
del kwargs["mode"]
if "role" in kwargs:
self.pysdk_model.role = kwargs.get("role")
del kwargs["role"]
    if not _is_optimized(self.pysdk_model):
        env_vars = {}
        if str(Mode.LOCAL_CONTAINER) in self.modes:
            # upload model artifacts to S3 if LOCAL_CONTAINER -> SAGEMAKER_ENDPOINT
            self.pysdk_model.model_data, env_vars = self._prepare_for_mode(
                model_path=self.model_path, should_upload_artifacts=True
            )
        else:
            _, env_vars = self._prepare_for_mode()
        self.env_vars.update(env_vars)
        self.pysdk_model.env.update(self.env_vars)

    # if the weights have been cached via local container mode -> set to offline
    if str(Mode.LOCAL_CONTAINER) in self.modes:
        self.pysdk_model.env.update({"HF_HUB_OFFLINE": "1"})
    else:
        # if the model has not been built for local container mode, we must
        # use a cache that hosting has write access to.
        self.pysdk_model.env["HF_HOME"] = "/tmp"
        self.pysdk_model.env["HUGGINGFACE_HUB_CACHE"] = "/tmp"
if "endpoint_logging" not in kwargs:
kwargs["endpoint_logging"] = True
if self.nb_instance_type and "instance_type" not in kwargs:
kwargs.update({"instance_type": self.nb_instance_type})
elif not self.nb_instance_type and "instance_type" not in kwargs:
raise ValueError(
"Instance type must be provided when deploying " "to SageMaker Endpoint mode."
)
else:
try:
tot_gpus = _get_gpu_info(kwargs.get("instance_type"), self.sagemaker_session)
except Exception: # pylint: disable=W0703
tot_gpus = _get_gpu_info_fallback(kwargs.get("instance_type"))
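        # Derive a default tensor-parallel degree from the model config and
        # the GPU count of the instance type; TGI reads it via NUM_SHARD, and
        # SHARDED records whether sharding is enabled.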
        default_num_shard = _get_default_tensor_parallel_degree(self.hf_model_config, tot_gpus)
        self.pysdk_model.env.update(
            {
                "NUM_SHARD": str(default_num_shard),
                "SHARDED": "true" if default_num_shard > 1 else "false",
            }
        )

    if "initial_instance_count" not in kwargs:
        kwargs.update({"initial_instance_count": 1})
    predictor = self._original_deploy(*args, **kwargs)

    if "HF_HUB_OFFLINE" in self.pysdk_model.env:
        self.pysdk_model.env.update({"HF_HUB_OFFLINE": "0"})

    predictor.serializer = serializer
    predictor.deserializer = deserializer
    return predictor
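
# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the module): ModelBuilder binds this
# wrapper over model.deploy() during build(). The model id, sample payloads,
# and instance type below are assumptions for the example, not values taken
# from this file.
#
#   from sagemaker.serve.builder.model_builder import ModelBuilder
#   from sagemaker.serve.builder.schema_builder import SchemaBuilder
#   from sagemaker.serve.mode.function_pointers import Mode
#
#   model_builder = ModelBuilder(
#       model="HuggingFaceH4/zephyr-7b-beta",  # hypothetical HF model id
#       schema_builder=SchemaBuilder(sample_input="Hello", sample_output="Hi"),
#       mode=Mode.LOCAL_CONTAINER,
#   )
#   model = model_builder.build()
#   # Passing a different mode here exercises the overwrite branch at the top
#   # of the wrapper; instance_type is required for endpoint mode.
#   predictor = model.deploy(
#       mode=Mode.SAGEMAKER_ENDPOINT,
#       instance_type="ml.g5.12xlarge",
#   )
# ---------------------------------------------------------------------------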