in sockeye/model.py
def load_model(model_folder: str,
context: Union[List[mx.context.Context], mx.context.Context] = mx.cpu(),
dtype: Optional[str] = None,
checkpoint: Optional[int] = None,
hybridize: bool = True,
inference_only: bool = False,
train_decoder_only: bool = False,
mc_dropout: bool = False,
for_disk_saving: Optional[str] = None,
allow_missing: bool = False,
set_grad_req_null: bool = True,
forward_pass_cache_size: int = 0) -> Tuple[SockeyeModel, List[vocab.Vocab], List[vocab.Vocab]]:
"""
    Load a model from model_folder.

    :param model_folder: Model folder.
:param context: MXNet context to bind modules to.
    :param dtype: Optional data type to use. If None, will be inferred from the stored model.
    :param checkpoint: Checkpoint to use. If None, uses the best checkpoint.
:param hybridize: Whether to hybridize the loaded models. Default: true.
:param inference_only: Use the model only for inference, enabling optimizations.
    :param train_decoder_only: Training will only update the decoder; autograd is disabled for the
        encoder and embeddings to save memory.
    :param mc_dropout: Turn on Monte Carlo dropout during inference (inference output becomes non-deterministic).
:param for_disk_saving: For saving quantized models to disk.
None: load as usual and the model will work.
int8: The model loaded into RAM will not work, but is suitable for
writing to disk in quantized format (including scaling factors).
float32: The model loaded into RAM will not work, but is suitable
for writing to disk as float32 with precomputed scaling factors.
:param allow_missing: Allow missing parameters in the loaded model.
:param set_grad_req_null: Set grad_req to null for model parameters.
:param forward_pass_cache_size: If > 0, cache encoder and embedding calculations of forward pass.
    :return: The loaded model, source vocabularies, and target vocabularies.
"""
source_vocabs = vocab.load_source_vocabs(model_folder)
target_vocabs = vocab.load_target_vocabs(model_folder)
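    # With the vocabularies loaded, verify that the stored model version is compatible with this
    # code before reading the model configuration.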
model_version = utils.load_version(os.path.join(model_folder, C.VERSION_NAME))
logger.info("Model version: %s", model_version)
utils.check_version(model_version)
model_config = SockeyeModel.load_config(os.path.join(model_folder, C.CONFIG_NAME))
if inference_only and not mc_dropout:
logger.info("Disabling dropout layers for performance reasons")
model_config.disable_dropout()
if mc_dropout:
logger.info("Monte Carlo dropout enabled, inference output will be non-deterministic.")
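    # Resolve the parameters file: the best checkpoint by default, or an explicitly requested one.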
if checkpoint is None:
params_fname = os.path.join(model_folder, C.PARAMS_BEST_NAME)
else:
params_fname = os.path.join(model_folder, C.PARAMS_NAME % checkpoint)
if os.path.exists(params_fname + '.mx'):
        logger.warning(f"!!!!! Found '{params_fname}.mx' file, indicating that {params_fname} has been converted "
                       f"to PyTorch. Using '{params_fname}.mx' because behavior when loading PyTorch files is "
                       "undefined. !!!!!\n")
params_fname += '.mx'
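    # int8 execution requires the intgemm fully-connected operator; if it is not compiled into this
    # MXNet build, fall back to loading it from the custom operator library named in the model config.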
if (dtype == C.DTYPE_INT8 or
model_config.dtype == C.DTYPE_INT8 or
for_disk_saving is not None) and "intgemm_fully_connected" not in dir(npx):
# We're going to use int8 but it's not compiled into mxnet.
path = os.path.abspath(model_config.intgemm_custom_lib)
try:
mx.library.load(path)
except mx.base.MXNetError:
raise NotImplementedError("8-bit int inference requested but intgemm was not compiled into MXNet and a "
"custom operator library was not found in `%s`. Compile the custom "
"operator then set the path using intgemm_custom_lib in the config file." % path)
# Are we converting the model to 8-bit?
quantizing = model_config.dtype != C.DTYPE_INT8 and (dtype == C.DTYPE_INT8 or for_disk_saving is not None)
if quantizing:
model_config.dtype = C.DTYPE_INT8 # Ensure the scaling factor parameters are created.
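    # Construct the model and allocate its parameters on the requested context(s).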
model = SockeyeModel(model_config, inference_only=inference_only, train_decoder_only=train_decoder_only,
mc_dropout=mc_dropout, forward_pass_cache_size=forward_pass_cache_size)
model.initialize(ctx=context)
if model_config.dtype != C.DTYPE_INT8:
# If model_config.dtype is int8, then the above model construction
# (which also used model_config) already set everything to the correct
# mix of float32 and int8. Cast would try to make everything int8.
model.cast(model_config.dtype)
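    # Decide how the parameters stored on disk relate to the requested dtype: quantize on load,
    # keep the stored dtype, or cast to an explicitly requested dtype.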
if quantizing:
logger.info("Model dtype: quantizing from float32 to int8")
allow_missing = True # The scaling factors are missing
cast_dtype = True
dtype_source = 'saved'
    elif dtype is None or dtype == model_config.dtype:
        logger.info("Model dtype: %s", model_config.dtype)
        # Keep the caller-provided allow_missing setting.
        cast_dtype = False
        dtype_source = 'saved'
    else:
        logger.info("Model dtype: overridden to %s", dtype)
        model.cast(dtype)
        # Keep the caller-provided allow_missing setting.
        cast_dtype = True
        dtype_source = 'current'
model.load_parameters(filename=params_fname,
ctx=context,
allow_missing=allow_missing,
ignore_extra=True, # Scaling factors may be present in float32 models.
cast_dtype=cast_dtype,
dtype_source=dtype_source)
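    # Optionally set grad_req to 'null' so no gradients are computed or stored for the parameters.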
params = model.collect_params()
if set_grad_req_null:
for param in params.values():
param.grad_req = 'null'
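    # for_disk_saving: convert the weights to the requested on-disk representation (int8 values plus
    # scaling factors, or float32 with precomputed scaling factors) so the model can be written out.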
if for_disk_saving is not None:
# Saving scaling factors and possibly int8 values to disk.
if not quantizing:
raise RuntimeError("Model is already quantized and for_disk_saving is set.")
quantization.convert_weights_disk_format(params, for_disk_saving)
model.config.dtype = for_disk_saving
# TODO: check for missing parameters somehow (we allowed scaling to be missing)
if for_disk_saving is None and model_config.dtype == C.DTYPE_INT8:
# Disk format to CPU-dependent format.
quantization.convert_weights_cpu_dependent(params)
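    # Hybridize with static memory allocation so inference runs on an optimized symbolic graph.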
if hybridize:
model.hybridize(static_alloc=True)
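    # Sanity-check that the number of loaded vocabularies matches the model's source/target factor counts.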
utils.check_condition(model.num_source_factors == len(source_vocabs),
"Number of loaded source vocabularies (%d) does not match "
"number of source factors for model '%s' (%d)" % (len(source_vocabs), model_folder,
model.num_source_factors))
utils.check_condition(model.num_target_factors == len(target_vocabs),
"Number of loaded target vocabularies (%d) does not match "
"number of target factors for model '%s' (%d)" % (len(target_vocabs), model_folder,
model.num_target_factors))
return model, source_vocabs, target_vocabs
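
# Usage sketch (hypothetical paths and settings, relying only on the parameters documented above):
# load a trained model for CPU inference. Guarded so it does not run on import.
if __name__ == "__main__":
    _model, _source_vocabs, _target_vocabs = load_model(
        model_folder="/path/to/trained_model",  # placeholder model directory
        context=mx.cpu(),
        checkpoint=None,          # None selects the best checkpoint
        hybridize=True,
        inference_only=True)      # disables dropout and enables inference-only optimizations
    logger.info("Loaded model with %d source factor(s) and %d target factor(s).",
                _model.num_source_factors, _model.num_target_factors)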