def default_model_fn()

in src/sagemaker_mxnet_serving_container/default_inference_handler.py


    def default_model_fn(self, model_dir, preferred_batch_size=1):
        """Function responsible for loading the model. This implementation is designed to work with
        the default save function provided for MXNet training.

        Args:
            model_dir (str): The directory where model files are stored.
            preferred_batch_size (int): Batch size to use for the model's data shapes.
                Defaults to 1.

        Returns:
            mxnet.mod.Module: the loaded model.

        """
        for f in DEFAULT_MODEL_FILENAMES.values():
            path = os.path.join(model_dir, f)
            if not os.path.exists(path):
                raise ValueError('Failed to load model with default model_fn: missing file {}. '
                                 'Expected files: {}'.format(f, list(DEFAULT_MODEL_FILENAMES.values())))

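        # read_data_shapes returns the data names and shapes recorded at save time; the
        # PREFERRED_BATCH_SIZE_PARAM environment variable is only consulted when the argument is falsy.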
        shapes_file = os.path.join(model_dir, DEFAULT_MODEL_FILENAMES['shapes'])
        preferred_batch_size = preferred_batch_size or os.environ.get(PREFERRED_BATCH_SIZE_PARAM)
        data_names, data_shapes = read_data_shapes(shapes_file, preferred_batch_size)

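        # Load the serialized symbol and trained parameters from the epoch-0 checkpoint.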
        sym, args, aux = mx.model.load_checkpoint(os.path.join(model_dir, DEFAULT_MODEL_NAME), 0)

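        # Run on an Elastic Inference accelerator when one is attached; otherwise use the
        # default context.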
        ctx = mx.eia() if os.environ.get(INFERENCE_ACCELERATOR_PRESENT_ENV) == 'true' else get_default_context()

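        # Bind the module for inference only and load the trained parameters into it.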
        mod = mx.mod.Module(symbol=sym, context=ctx, data_names=data_names, label_names=None)
        mod.bind(for_training=False, data_shapes=data_shapes)
        mod.set_params(args, aux, allow_missing=True)

        return mod
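
As a usage sketch (not part of the source), the snippet below shows how the returned module
might be exercised locally. It assumes the enclosing class is DefaultMXNetInferenceHandler,
that '/opt/ml/model' contains the files produced by the default MXNet save function
(typically model-symbol.json, model-0000.params, and model-shapes.json), and that MXNet is
installed; the class name and directory are illustrative assumptions, not taken from this file.

    import mxnet as mx

    from sagemaker_mxnet_serving_container.default_inference_handler import (
        DefaultMXNetInferenceHandler,  # assumed class name for this handler
    )

    handler = DefaultMXNetInferenceHandler()
    module = handler.default_model_fn('/opt/ml/model')  # hypothetical model directory

    # The returned mxnet.mod.Module is already bound for inference, so it can be fed a
    # dummy batch matching the recorded data shapes.
    batch = mx.io.DataBatch([mx.nd.zeros(desc.shape) for desc in module.data_shapes])
    module.forward(batch, is_train=False)
    outputs = module.get_outputs()
    print([o.shape for o in outputs])

Because the module is bound with for_training=False and label_names=None, it accepts
inference-time data batches only; in the serving container this forward pass would normally
be driven by the predict/transform handlers rather than called directly.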