in src/sagemaker_huggingface_inference_toolkit/mms_model_server.py [0:0]
def start_model_server(handler_service=DEFAULT_HANDLER_SERVICE):
    """Configure and start the model server.

    Three start-up modes are handled:

    - Multi-model (``ENABLE_MULTI_MODEL``): the handler is exported through the
      ``SAGEMAKER_HANDLER`` environment variable (unless it is already set) and
      models are served from ``DEFAULT_MODEL_STORE``.
    - Hugging Face Hub (``HF_MODEL_ID`` present in the environment): the model
      is fetched into ``DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY``, adapted to the
      MMS layout, and served from that directory.
    - Single model (otherwise): the model in ``environment.model_dir`` is
      passed to the server through ``--models``.

    When optimum-neuron is available and ``SAGEMAKER_MODEL_SERVER_WORKERS`` is
    not set, the worker count defaults to the number of detected Neuron cores,
    falling back to a single worker (with a logged warning) if detection fails.

    Args:
        handler_service (str): python path pointing to a module that defines
            a class with the following:

            - A ``handle`` method, which is invoked for all incoming inference
              requests to the model server.
            - A ``initialize`` method, which is invoked at model server start up
              for loading the model.

            Defaults to ``sagemaker_huggingface_inference_toolkit.handler_service``.

    This call blocks until the model server process exits.
    """
    use_hf_hub = "HF_MODEL_ID" in os.environ
    model_store = DEFAULT_MODEL_STORE

    if ENABLE_MULTI_MODEL:
        # Respect a handler that was already configured externally.
        if not os.getenv("SAGEMAKER_HANDLER"):
            os.environ["SAGEMAKER_HANDLER"] = handler_service
        _set_python_path()
    elif use_hf_hub:
        # Hub models are exported to (and served from) their own model store directory.
        model_store = DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY
        storage_dir = _load_model_from_hub(
            model_id=os.environ["HF_MODEL_ID"],
            model_dir=DEFAULT_HF_HUB_MODEL_EXPORT_DIRECTORY,
            revision=HF_MODEL_REVISION,
            use_auth_token=HF_API_TOKEN,
        )
        _adapt_to_mms_format(handler_service, storage_dir)
    else:
        _set_python_path()

    env = environment.Environment()

    # Set the number of workers to the available Neuron core count if optimum-neuron
    # is available and the user has not configured the worker count explicitly.
    if is_optimum_neuron_available() and "SAGEMAKER_MODEL_SERVER_WORKERS" not in os.environ:
        try:
            # Imported inside the try so a missing/moved submodule also falls back
            # to a single worker instead of aborting server start-up.
            from optimum.neuron.utils.cache_utils import get_num_neuron_cores

            env._model_server_workers = str(get_num_neuron_cores())
        except Exception as err:  # best-effort detection; never block start-up
            logger.warning(
                "Could not determine the number of Neuron cores (%s); defaulting to 1 model server worker.",
                err,
            )
            env._model_server_workers = "1"

    # Note: multi-model default config already sets default_service_handler
    handler_service_for_config = None if ENABLE_MULTI_MODEL else handler_service
    _create_model_server_config_file(env, handler_service_for_config)

    if os.path.exists(REQUIREMENTS_PATH):
        _install_requirements()

    multi_model_server_cmd = [
        "multi-model-server",
        "--start",
        "--model-store",
        model_store,
        "--mms-config",
        MMS_CONFIG_FILE,
        "--log-config",
        DEFAULT_MMS_LOG_FILE,
    ]
    # Only the single-model mode registers a model up front; multi-model loads
    # models on demand and the Hub mode serves from the exported model store.
    if not ENABLE_MULTI_MODEL and not use_hf_hub:
        multi_model_server_cmd += ["--models", DEFAULT_MMS_MODEL_NAME + "=" + environment.model_dir]

    logger.info(multi_model_server_cmd)
    subprocess.Popen(multi_model_server_cmd)
    # The launcher command above starts the server in the background; locate the
    # actual server process, retrying for the configured start-up timeout.
    mms_process = _retry_retrieve_mms_server_process(env.startup_timeout)

    _add_sigterm_handler(mms_process)
    _add_sigchild_handler()

    # Block until the model server process exits.
    mms_process.wait()