def default_model_fn(self, model_dir)

in src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py [0:0]


    def default_model_fn(self, model_dir):
        """Loads a model. For PyTorch, a default function to load a model only if Elastic Inference is used.
        In other cases, users should provide customized model_fn() in script.

        Args:
            model_dir: a directory where the model is saved.

        Returns: A PyTorch model.
        """
        if os.getenv(INFERENCE_ACCELERATOR_PRESENT_ENV) == "true":
            model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME)
            if not os.path.exists(model_path):
                raise FileNotFoundError("Failed to load model with default model_fn: missing file {}."
                                        .format(DEFAULT_MODEL_FILENAME))
            # The client framework runs on CPU only; the model executes with CUDA on the Elastic Inference accelerator.
            try:
                return torch.jit.load(model_path, map_location=torch.device('cpu'))
            except RuntimeError as e:
                raise ModelLoadError(
                    "Failed to load {}. Please ensure model is saved using torchscript.".format(model_path)
                ) from e
        else:
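            # No accelerator attached: load the model onto GPU when available, otherwise CPU.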
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            model_path = os.path.join(model_dir, DEFAULT_MODEL_FILENAME)
            if not os.path.exists(model_path):
                model_files = [f for f in os.listdir(model_dir) if self._is_model_file(f)]
                if len(model_files) != 1:
                    raise ValueError(
                        "Exactly one .pth or .pt file is required for PyTorch models, found: {}".format(model_files)
                    )
                model_path = os.path.join(model_dir, model_files[0])
            try:
                model = torch.jit.load(model_path, map_location=device)
            except RuntimeError as e:
                raise ModelLoadError(
                    "Failed to load {}. Please ensure model is saved using torchscript.".format(model_path)
                ) from e
            return model.to(device)
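
For this default loader to succeed, the model artifact must be a TorchScript program saved under DEFAULT_MODEL_FILENAME. A minimal sketch of producing such an artifact, assuming DEFAULT_MODEL_FILENAME resolves to "model.pt" (the tiny nn.Sequential network is illustrative only):

    import torch
    import torch.nn as nn

    # Any nn.Module works; this toy network stands in for a real model.
    model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
    model.eval()

    # Trace with a representative input to produce a TorchScript module,
    # then save it under the filename the default handler looks for.
    traced = torch.jit.trace(model, torch.rand(1, 4))
    torch.jit.save(traced, "model.pt")

    # Loading mirrors what default_model_fn does on a CPU-only host.
    loaded = torch.jit.load("model.pt", map_location=torch.device("cpu"))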
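
When the artifact is not TorchScript (for example, a plain state_dict checkpoint), the docstring directs users to supply their own model_fn() in the entry-point script, which bypasses this default. A hedged sketch of such an override; MyModel and the "model.pth" filename are hypothetical stand-ins for the user's own class and checkpoint name:

    import os

    import torch
    import torch.nn as nn

    class MyModel(nn.Module):
        # Hypothetical network; replace with your own architecture.
        def __init__(self):
            super().__init__()
            self.net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

        def forward(self, x):
            return self.net(x)

    def model_fn(model_dir):
        # Defining model_fn in the script takes precedence over default_model_fn.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = MyModel()
        state = torch.load(os.path.join(model_dir, "model.pth"), map_location=device)
        model.load_state_dict(state)
        return model.to(device).eval()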