in src/sagemaker_mxnet_serving_container/default_inference_handler.py [0:0]
def default_model_fn(self, model_dir, preferred_batch_size=1):
    """Load an MXNet model saved with the default MXNet training save function.

    Verifies that every expected artifact (symbol, params, shapes file) is
    present, reads the data shapes, loads checkpoint epoch 0, and binds a
    Module for inference.

    Args:
        model_dir (str): The directory where model files are stored.
        preferred_batch_size (int): Preferred batch size of the model's data
            shape. Defaults to 1. If a falsy value is passed, the
            PREFERRED_BATCH_SIZE_PARAM environment variable is consulted
            instead.

    Returns:
        mxnet.mod.Module: the loaded model, bound for inference.

    Raises:
        ValueError: if any of the expected model files is missing.
    """
    expected_files = list(DEFAULT_MODEL_FILENAMES.values())
    for f in expected_files:
        path = os.path.join(model_dir, f)
        if not os.path.exists(path):
            # Original message lacked a space between sentences
            # ("...missing file x.Expected files: ...").
            raise ValueError('Failed to load model with default model_fn: missing file {}. '
                             'Expected files: {}'.format(f, expected_files))

    shapes_file = os.path.join(model_dir, DEFAULT_MODEL_FILENAMES['shapes'])

    # NOTE(review): with the default of 1 (truthy) this fallback never fires;
    # the env var is only read when the caller explicitly passes a falsy
    # value (e.g. None or 0). Preserved for backward compatibility.
    if not preferred_batch_size:
        env_batch_size = os.environ.get(PREFERRED_BATCH_SIZE_PARAM)
        # The environment yields a string; coerce so downstream shape math
        # receives an int as it does on the default path.
        preferred_batch_size = int(env_batch_size) if env_batch_size is not None else None

    data_names, data_shapes = read_data_shapes(shapes_file, preferred_batch_size)

    # Epoch 0 checkpoint is what the default save function writes.
    sym, args, aux = mx.model.load_checkpoint(os.path.join(model_dir, DEFAULT_MODEL_NAME), 0)

    # Use the Elastic Inference accelerator context when one is attached.
    ctx = mx.eia() if os.environ.get(INFERENCE_ACCELERATOR_PRESENT_ENV) == 'true' \
        else get_default_context()

    mod = mx.mod.Module(symbol=sym, context=ctx, data_names=data_names, label_names=None)
    mod.bind(for_training=False, data_shapes=data_shapes)
    # allow_missing=True: auxiliary/arg params absent from the checkpoint are
    # left to their initializers rather than raising.
    mod.set_params(args, aux, allow_missing=True)

    return mod