def load()

in src/sagemaker_huggingface_inference_toolkit/handler_service.py [0:0]


    def load(self, model_dir):
        """
        The Load handler is responsible for loading the Hugging Face transformer model.
        It can be overridden to load the model from storage.

        The pipeline task is read from the ``HF_TASK`` environment variable when it is
        set. Otherwise it is inferred from the model architecture described in the
        ``config.json`` file inside ``model_dir``.

        Args:
            model_dir: Path to the directory holding the model artifacts.

        Returns:
            hf_pipeline (Pipeline): A Hugging Face Transformer pipeline.

        Raises:
            ValueError: If ``HF_TASK`` is not set and ``model_dir`` contains no
                ``config.json`` file, so the task cannot be determined.
        """
        # An explicitly configured task always takes precedence over inference.
        if "HF_TASK" in os.environ:
            task = os.environ["HF_TASK"]
        else:
            config_path = os.path.join(model_dir, "config.json")
            # Without HF_TASK the only source for the task is the model's config file.
            if not os.path.isfile(config_path):
                # NOTE(review): 403 looks like an HTTP status code. ValueError stores it in
                # e.args, so str(e) renders as a tuple. Kept for compatibility; confirm
                # whether any caller reads e.args[1] before removing it.
                raise ValueError(
                    f"You need to define one of the following {list(SUPPORTED_TASKS.keys())} as env 'HF_TASK'.", 403
                )
            task = infer_task_from_model_architecture(config_path)
        hf_pipeline = get_pipeline(task=task, model_dir=model_dir, device=self.device)
        return hf_pipeline