def save()

in src/sagemaker_mxnet_container/training_utils.py [0:0]


def save(model_dir, model, current_host=None, hosts=None):
    """Save an MXNet Module to a given location if the current host is the scheduler host.

    This generates three files in the model directory:

    - model-symbol.json: The serialized module symbolic graph.
        Formed by invoking ``module.symbol.save``.
    - model-0000.params: The serialized module parameters.
        Formed by invoking ``module.save_params``.
    - model-shapes.json: The serialized module input data shapes in the form of a JSON list of
        JSON data-shape objects. Each data-shape object contains a string name and
        a list of integer dimensions.

    On any host other than the scheduler host, nothing is written.

    Args:
        model_dir (str): the directory for saving the model
        model (mxnet.mod.Module): the module to be saved
        current_host (str): the name of the host running this code. If not given
            (or falsy), it is read from the ``SM_CURRENT_HOST`` environment variable.
        hosts (list): the host names used to determine the scheduler host. If not
            given (or falsy), they are parsed as JSON from the ``SM_HOSTS``
            environment variable.

    Raises:
        KeyError: if a fallback environment variable (``SM_CURRENT_HOST`` or
            ``SM_HOSTS``) is needed but not set.
    """
    current_host = current_host or os.environ['SM_CURRENT_HOST']
    hosts = hosts or json.loads(os.environ['SM_HOSTS'])

    # Only the scheduler host writes the model artifacts; every other host returns early.
    if current_host != scheduler_host(hosts):
        return

    model.symbol.save(os.path.join(model_dir, SYMBOL_PATH))
    model.save_params(os.path.join(model_dir, PARAMS_PATH))

    # One entry per model input: its name plus its dimensions, copied into a
    # plain list so it serializes as a JSON array.
    signature = [{'name': data_desc.name, 'shape': list(data_desc.shape)}
                 for data_desc in model.data_shapes]
    with open(os.path.join(model_dir, SHAPES_PATH), 'w') as f:
        json.dump(signature, f)