in src/sagemaker_huggingface_inference_toolkit/handler_service.py [0:0]
def load(self, model_dir, context=None):
    """
    Load the Hugging Face transformer (or diffusers) pipeline for inference.

    Task resolution precedence:
      1. The ``HF_TASK`` environment variable, when set.
      2. A ``config.json`` in ``model_dir`` — the task is inferred from the
         model architecture recorded there.
      3. A ``model_index.json`` in ``model_dir`` — treated as a diffusers
         ``text-to-image`` model.

    Can be overridden to load the model from storage.

    Args:
        model_dir (str): The directory where model files are stored.
        context (obj): Metadata on the incoming request data (default: None).
            NOTE(review): unused in this implementation — presumably kept for
            the handler-service interface; confirm against callers.

    Returns:
        Pipeline: A Hugging Face Transformers (or diffusers) pipeline.

    Raises:
        ValueError: If no task can be determined from the environment or the
            model directory contents. Raised with a 403 status code as the
            second argument, per the toolkit's error convention.
    """
    # Explicit task tag always wins over inference from model files.
    if "HF_TASK" in os.environ:
        return get_pipeline(task=os.environ["HF_TASK"], model_dir=model_dir, device=self.device)

    # List the directory once instead of once per branch.
    model_files = os.listdir(model_dir)

    if "config.json" in model_files:
        # Transformers model: infer the task from the architecture in config.json.
        # os.path.join keeps the path portable across platforms.
        task = infer_task_from_model_architecture(os.path.join(model_dir, "config.json"))
        return get_pipeline(task=task, model_dir=model_dir, device=self.device)

    if "model_index.json" in model_files:
        # Diffusers model layout: model_index.json implies a text-to-image pipeline.
        return get_pipeline(task="text-to-image", model_dir=model_dir, device=self.device)

    raise ValueError(
        f"You need to define one of the following {list(SUPPORTED_TASKS.keys())} or text-to-image as env 'HF_TASK'.",
        403,
    )