def load_model()

in sockeye/model.py [0:0]


def load_model(model_folder: str,
               context: Union[List[mx.context.Context], mx.context.Context] = mx.cpu(),
               dtype: Optional[str] = None,
               checkpoint: Optional[int] = None,
               hybridize: bool = True,
               inference_only: bool = False,
               train_decoder_only: bool = False,
               mc_dropout: bool = False,
               for_disk_saving: Optional[str] = None,
               allow_missing: bool = False,
               set_grad_req_null: bool = True,
               forward_pass_cache_size: int = 0) -> Tuple[SockeyeModel, List[vocab.Vocab], List[vocab.Vocab]]:
    """
    Load a model from model_folder.

    :param model_folder: Model folder.
    :param context: MXNet context (or list of contexts) to initialize and load the parameters on.
    :param dtype: Optional data type to use. If None, the dtype stored in the model config is used.
           Requesting int8 for a model whose stored dtype is not int8 quantizes it while loading.
    :param checkpoint: Checkpoint to use. If None, uses the best checkpoint.
    :param hybridize: Whether to hybridize the loaded model. Default: true.
    :param inference_only: Use the model only for inference, enabling optimizations.
    :param train_decoder_only: Training will only update the decoder. Disable
           autograd for encoder and embeddings to save memory.
    :param mc_dropout: Turn on dropout during inference.
    :param for_disk_saving: For saving quantized models to disk.
           None: load as usual and the model will work.
           int8: The model loaded into RAM will not work, but is suitable for
               writing to disk in quantized format (including scaling factors).
           float32: The model loaded into RAM will not work, but is suitable
               for writing to disk as float32 with precomputed scaling factors.
    :param allow_missing: Allow missing parameters in the loaded model. Forced to True when quantizing,
           because the scaling factors are not present in the stored parameters.
    :param set_grad_req_null: Set grad_req to null for model parameters.
    :param forward_pass_cache_size: If > 0, cache encoder and embedding calculations of forward pass.
    :return: Tuple of (model, source vocabularies, target vocabularies).
    :raises NotImplementedError: If int8 support is needed, intgemm is not compiled into MXNet, and the
            custom operator library at ``model_config.intgemm_custom_lib`` cannot be loaded.
    :raises RuntimeError: If for_disk_saving is set but the stored model is already int8.
    """
    source_vocabs = vocab.load_source_vocabs(model_folder)
    target_vocabs = vocab.load_target_vocabs(model_folder)
    model_version = utils.load_version(os.path.join(model_folder, C.VERSION_NAME))
    logger.info("Model version: %s", model_version)
    utils.check_version(model_version)
    model_config = SockeyeModel.load_config(os.path.join(model_folder, C.CONFIG_NAME))

    if inference_only and not mc_dropout:
        logger.info("Disabling dropout layers for performance reasons")
        model_config.disable_dropout()

    if mc_dropout:
        logger.info("Monte Carlo dropout enabled, inference output will be non-deterministic.")

    if checkpoint is None:
        params_fname = os.path.join(model_folder, C.PARAMS_BEST_NAME)
    else:
        params_fname = os.path.join(model_folder, C.PARAMS_NAME % checkpoint)

    # A '<params>.mx' sibling indicates the original file was converted to PyTorch; prefer the MXNet copy.
    if os.path.exists(params_fname + '.mx'):
        logger.warning("!!!!! Found '%s.mx' file, indicating that %s has been converted to PyTorch. "
                       "Using '%s.mx' because behavior when loading PyTorch files is undefined. !!!!!\n",
                       params_fname, params_fname, params_fname)
        params_fname += '.mx'

    needs_int8 = (dtype == C.DTYPE_INT8
                  or model_config.dtype == C.DTYPE_INT8
                  or for_disk_saving is not None)
    if needs_int8 and "intgemm_fully_connected" not in dir(npx):
        # We're going to use int8 but it's not compiled into mxnet: try the custom operator library.
        path = os.path.abspath(model_config.intgemm_custom_lib)
        try:
            mx.library.load(path)
        except mx.base.MXNetError as err:
            raise NotImplementedError("8-bit int inference requested but intgemm was not compiled into MXNet and a "
                                      "custom operator library was not found in `%s`.  Compile the custom "
                                      "operator then set the path using intgemm_custom_lib in the config file."
                                      % path) from err

    # Are we converting the model to 8-bit?
    quantizing = model_config.dtype != C.DTYPE_INT8 and (dtype == C.DTYPE_INT8 or for_disk_saving is not None)
    if for_disk_saving is not None and not quantizing:
        # Fail fast, before building the model and reading the parameter file.
        raise RuntimeError("Model is already quantized and for_disk_saving is set.")
    if quantizing:
        model_config.dtype = C.DTYPE_INT8  # Ensure the scaling factor parameters are created.

    model = SockeyeModel(model_config, inference_only=inference_only, train_decoder_only=train_decoder_only,
                         mc_dropout=mc_dropout, forward_pass_cache_size=forward_pass_cache_size)
    model.initialize(ctx=context)
    if model_config.dtype != C.DTYPE_INT8:
        # If model_config.dtype is int8, then the above model construction
        # (which also used model_config) already set everything to the correct
        # mix of float32 and int8.  Cast would try to make everything int8.
        model.cast(model_config.dtype)

    if quantizing:
        logger.info("Model dtype: quantizing from float32 to int8")
        allow_missing = True  # The scaling factors are missing from the stored parameters.
        cast_dtype = True
        dtype_source = 'saved'
    elif dtype is None or dtype == model_config.dtype:
        logger.info("Model dtype: %s", model_config.dtype)
        cast_dtype = False
        dtype_source = 'saved'
    else:
        logger.info("Model dtype: overridden to %s", dtype)
        model.cast(dtype)
        cast_dtype = True
        dtype_source = 'current'

    model.load_parameters(filename=params_fname,
                          ctx=context,
                          allow_missing=allow_missing,
                          ignore_extra=True,  # Scaling factors may be present in float32 models.
                          cast_dtype=cast_dtype,
                          dtype_source=dtype_source)

    params = model.collect_params()
    if set_grad_req_null:
        for param in params.values():
            param.grad_req = 'null'

    if for_disk_saving is not None:
        # Saving scaling factors and possibly int8 values to disk (quantizing is guaranteed True here).
        quantization.convert_weights_disk_format(params, for_disk_saving)
        model.config.dtype = for_disk_saving
        # TODO: check for missing parameters somehow (we allowed scaling to be missing)
    elif model_config.dtype == C.DTYPE_INT8:
        # Disk format to CPU-dependent format.
        quantization.convert_weights_cpu_dependent(params)

    if hybridize:
        model.hybridize(static_alloc=True)

    utils.check_condition(model.num_source_factors == len(source_vocabs),
                          "Number of loaded source vocabularies (%d) does not match "
                          "number of source factors for model '%s' (%d)" % (len(source_vocabs), model_folder,
                                                                            model.num_source_factors))
    utils.check_condition(model.num_target_factors == len(target_vocabs),
                          "Number of loaded target vocabularies (%d) does not match "
                          "number of target factors for model '%s' (%d)" % (len(target_vocabs), model_folder,
                                                                            model.num_target_factors))
    return model, source_vocabs, target_vocabs