def save_bare_keras_optimizer()

in horovod/spark/keras/bare.py [0:0]


def save_bare_keras_optimizer(optimizer, h5py_file):
    """Serialize a Keras optimizer's config and weight values into `h5py_file`.

    Stores, under the key ``'training_config'``, a UTF-8 encoded JSON document
    of the form ``{'optimizer_config': {'class_name': ..., 'config': ...}}``.
    If the optimizer has weights, their values are stored inside an
    ``'optimizer_weights'`` group (opened if present, created otherwise),
    together with a ``'weight_names'`` entry that lists the UTF-8 encoded key
    of each value in the same order as ``optimizer.weights``.

    For ``optimizers.TFOptimizer`` instances nothing is written; a warning is
    emitted instead.

    # Arguments
        optimizer: the Keras optimizer instance to save.
        h5py_file: a writable mapping-like HDF5 container. Presumably an open
            ``h5py.File`` or ``h5py.Group`` — TODO confirm against callers.
            Only item assignment, item lookup, ``in`` and ``create_group``
            are used.

    # Raises
        TypeError: if the optimizer config contains a value that cannot be
            converted to JSON.
    """
    def get_json_type(obj):
        """Serialize any object to a JSON-serializable structure.

        Used as the ``default=`` hook of ``json.dumps`` for values the
        standard JSON encoder cannot handle.

        # Arguments
            obj: the object to serialize

        # Returns
            JSON-serializable structure representing `obj`.

        # Raises
            TypeError: if `obj` cannot be serialized.
        """
        # if obj is a serializable Keras class instance
        # e.g. optimizer, layer
        if hasattr(obj, 'get_config'):
            return {'class_name': obj.__class__.__name__,
                    'config': obj.get_config()}

        # if obj is any numpy type
        if type(obj).__module__ == np.__name__:
            if isinstance(obj, np.ndarray):
                return obj.tolist()
            return obj.item()

        # Named callables (e.g. a loss function) and classes (which are
        # callable too) are stored by name. A callable without `__name__`
        # (e.g. a functools.partial) falls through to the TypeError below.
        if callable(obj):
            name = getattr(obj, '__name__', None)
            if name is not None:
                return name

        raise TypeError('Not JSON Serializable: %s' % (obj,))

    if isinstance(optimizer, optimizers.TFOptimizer):
        warnings.warn(
            'TensorFlow optimizers do not '
            'make it possible to access '
            'optimizer attributes or optimizer state '
            'after instantiation. '
            'As a result, we cannot save the optimizer '
            'as part of the model save file. '
            'You will have to compile your model again '
            'after loading it. '
            'Prefer using a Keras optimizer instead '
            '(see keras.io/optimizers).')
        return

    h5py_file['training_config'] = json.dumps({
        'optimizer_config': {
            'class_name': optimizer.__class__.__name__,
            'config': optimizer.get_config()
        },
    }, default=get_json_type).encode('utf8')

    symbolic_weights = optimizer.weights
    if not symbolic_weights:
        return

    # Reuse the group if the caller already created it; otherwise create it
    # so saving into a fresh file does not fail on a missing key.
    if 'optimizer_weights' in h5py_file:
        optimizer_weights_group = h5py_file['optimizer_weights']
    else:
        optimizer_weights_group = h5py_file.create_group('optimizer_weights')

    weight_values = K.batch_get_value(symbolic_weights)
    weight_names = _unique_weight_names(symbolic_weights)
    optimizer_weights_group['weight_names'] = weight_names
    for name, val in zip(weight_names, weight_values):
        optimizer_weights_group[name] = val


def _unique_weight_names(symbolic_weights):
    """Return distinct UTF-8 encoded storage keys for `symbolic_weights`.

    Each weight's own non-empty ``name`` is used when available, otherwise
    ``'param_<index>'``. A name that was already chosen gets the first free
    ``_1``, ``_2``, ... suffix, so every weight maps to its own key.
    """
    # `seen` holds str names. The returned list holds bytes, so duplicate
    # detection must use this set, not membership in the encoded list.
    seen = set()
    encoded_names = []
    for i, w in enumerate(symbolic_weights):
        if hasattr(w, 'name') and w.name:
            name = str(w.name)
        else:
            name = 'param_' + str(i)

        if name in seen:
            suffix = 1
            while name + '_' + str(suffix) in seen:
                suffix += 1
            name = name + '_' + str(suffix)

        seen.add(name)
        encoded_names.append(name.encode('utf8'))
    return encoded_names