def default_predict_fn()

in src/sagemaker_pytorch_serving_container/default_pytorch_inference_handler.py


    def default_predict_fn(self, data, model):
        """A default predict_fn for PyTorch. Calls a model on data deserialized in input_fn.
        Runs prediction on GPU if cuda is available.

        Args:
            data: input data (torch.Tensor) for prediction deserialized by input_fn
            model: PyTorch model loaded in memory by model_fn

        Returns: a prediction
        """
        with torch.no_grad():
            if os.getenv(INFERENCE_ACCELERATOR_PRESENT_ENV) == "true":
                # An Elastic Inference accelerator is attached: keep the model and
                # input on CPU and let TorchScript route execution to the EI device
                # via the optimized_execution context.
                device = torch.device("cpu")
                model = model.to(device)
                input_data = data.to(device)
                model.eval()
                with torch.jit.optimized_execution(True, {"target_device": "eia:0"}):
                    output = model(input_data)
            else:
                # No accelerator: run on the first GPU when CUDA is available,
                # otherwise fall back to CPU.
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                model = model.to(device)
                input_data = data.to(device)
                model.eval()
                output = model(input_data)

        return output
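
For illustration, the sketch below exercises the non-accelerator path by calling default_predict_fn directly with a toy model and tensor. It assumes DefaultPytorchInferenceHandler and INFERENCE_ACCELERATOR_PRESENT_ENV are both defined at module level in this file (the method references the latter); the toy Linear model and the random input stand in for whatever model_fn and input_fn would normally produce.

    # A minimal sketch, assuming the module-level names shown in the import below.
    import os

    import torch
    from sagemaker_pytorch_serving_container.default_pytorch_inference_handler import (
        DefaultPytorchInferenceHandler,
        INFERENCE_ACCELERATOR_PRESENT_ENV,
    )

    # Make sure the Elastic Inference branch is not taken for this example.
    os.environ.pop(INFERENCE_ACCELERATOR_PRESENT_ENV, None)

    handler = DefaultPytorchInferenceHandler()
    model = torch.nn.Linear(4, 2)   # stand-in for the model loaded by model_fn
    data = torch.randn(1, 4)        # stand-in for the tensor deserialized by input_fn

    output = handler.default_predict_fn(data, model)
    print(output.shape)             # torch.Size([1, 2])

Because the method moves both the model and the input to the selected device itself, the caller does not need to place the tensor on GPU beforehand; on a CUDA host the same call would transparently run on "cuda".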