in tensorflow_model_optimization/python/core/sparsity/keras/prune.py [0:0]
def prune_low_magnitude(to_prune,
pruning_schedule=pruning_sched.ConstantSparsity(0.5, 0),
block_size=(1, 1),
block_pooling_type='AVG',
pruning_policy=None,
sparsity_m_by_n=None,
**kwargs):
"""Modify a tf.keras layer or model to be pruned during training.
This function wraps a tf.keras model or layer with pruning functionality which
sparsifies the layer's weights during training. For example, using this with
50% sparsity will ensure that 50% of the layer's weights are zero.
The function accepts either a single keras layer
(subclass of `tf.keras.layers.Layer`), list of keras layers or a Sequential
or Functional tf.keras model and handles them appropriately.
If it encounters a layer it does not know how to handle, it will throw an
error. While pruning an entire model, even a single unknown layer would lead
to an error.
Prune a model:
```python
pruning_params = {
'pruning_schedule': ConstantSparsity(0.5, 0),
'block_size': (1, 1),
'block_pooling_type': 'AVG'
}
model = prune_low_magnitude(
keras.Sequential([
layers.Dense(10, activation='relu', input_shape=(100,)),
layers.Dense(2, activation='sigmoid')
]), **pruning_params)
```
Prune a layer:
```python
pruning_params = {
'pruning_schedule': PolynomialDecay(initial_sparsity=0.2,
final_sparsity=0.8, begin_step=1000, end_step=2000),
'block_size': (2, 3),
'block_pooling_type': 'MAX'
}
model = keras.Sequential([
layers.Dense(10, activation='relu', input_shape=(100,)),
prune_low_magnitude(layers.Dense(2, activation='tanh'), **pruning_params)
])
```
Pretrained models: you must first load the weights and then apply the
prune API:
```python
model.load_weights(...)
model = prune_low_magnitude(model)
```
Optimizer: this function removes the optimizer. The user is expected to
compile the model
again. It's easiest to rely on the default (step starts at 0) and then
use that to determine the desired begin_step for the pruning_schedules.
Checkpointing: checkpointing should include the optimizer, not just the
weights. Pruning supports
checkpointing though
upon inspection, the weights of checkpoints are not sparse
(https://github.com/tensorflow/model-optimization/issues/206).
Arguments:
to_prune: A single keras layer, list of keras layers, or a
`tf.keras.Model` instance.
pruning_schedule: A `PruningSchedule` object that controls pruning rate
throughout training.
block_size: (optional) The dimensions (height, weight) for the block
sparse pattern in rank-2 weight tensors.
block_pooling_type: (optional) The function to use to pool weights in the
block. Must be 'AVG' or 'MAX'.
pruning_policy: (optional) The object that controls to which layers
`PruneLowMagnitude` wrapper will be applied. This API is experimental
and is subject to change.
sparsity_m_by_n: default None, otherwise a tuple of 2 integers, indicates
pruning with m_by_n sparsity, e.g., (2, 4): 2 zeros out of 4 consecutive
elements. It check whether we can do pruning with m_by_n sparsity.
If this type of sparsity is not applicable, then an error is thrown.
**kwargs: Additional keyword arguments to be passed to the keras layer.
Ignored when to_prune is not a keras layer.
Returns:
Layer or model modified with pruning wrappers. Optimizer is removed.
Raises:
ValueError: if the keras layer is unsupported, or the keras model contains
an unsupported layer.
"""
def _prune_list(layers, **params):
wrapped_layers = []
for layer in layers:
# Allow layer that is already wrapped by the pruning wrapper
# to be used as is.
# No need to wrap the input layer either.
if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
wrapped_layers.append(layer)
elif isinstance(layer, keras.layers.InputLayer):
# TODO(yunluli): Replace with a clone function in keras.
wrapped_layers.append(layer.__class__.from_config(layer.get_config()))
else:
wrapped_layers.append(
pruning_wrapper.PruneLowMagnitude(layer, **params))
return wrapped_layers
def _add_pruning_wrapper(layer):
if isinstance(layer, keras.Model):
# Check whether the model is a subclass model.
if (not layer._is_graph_network and
not isinstance(layer, keras.models.Sequential)):
raise ValueError('Subclassed models are not supported currently.')
return keras.models.clone_model(
layer, input_tensors=None, clone_function=_add_pruning_wrapper)
if isinstance(layer, pruning_wrapper.PruneLowMagnitude):
return layer
if pruning_policy and not pruning_policy.allow_pruning(layer):
return layer
else:
return pruning_wrapper.PruneLowMagnitude(layer, **params)
params = {
'pruning_schedule': pruning_schedule,
'block_size': block_size,
'block_pooling_type': block_pooling_type,
'sparsity_m_by_n': sparsity_m_by_n,
}
is_sequential_or_functional = isinstance(
to_prune, keras.Model) and (isinstance(to_prune, keras.Sequential) or
to_prune._is_graph_network)
# A subclassed model is also a subclass of keras.layers.Layer.
is_keras_layer = isinstance(
to_prune, keras.layers.Layer) and not isinstance(to_prune, keras.Model)
if isinstance(to_prune, list):
return _prune_list(to_prune, **params)
elif is_sequential_or_functional:
if pruning_policy:
pruning_policy.ensure_model_supports_pruning(to_prune)
return _add_pruning_wrapper(to_prune)
elif is_keras_layer:
params.update(kwargs)
return pruning_wrapper.PruneLowMagnitude(to_prune, **params)
else:
raise ValueError(
'`prune_low_magnitude` can only prune an object of the following '
'types: tf.keras.models.Sequential, tf.keras functional model, '
'tf.keras.layers.Layer, list of tf.keras.layers.Layer. You passed '
'an object of type: {input}.'.format(input=to_prune.__class__.__name__))