in tensorflow_model_optimization/python/core/clustering/keras/cluster.py [0:0]
def _cluster_weights(to_cluster,
                     number_of_clusters,
                     cluster_centroids_init,
                     preserve_sparsity=False,
                     cluster_per_channel=False,
                     **kwargs):
  """Modifies a keras layer or model to be clustered during training.

  This function wraps a keras model or layer with clustering functionality
  which clusters the layer's weights during training. For example, using
  this with number_of_clusters equals 8 will ensure that each weight tensor has
  no more than 8 unique values.

  Before passing to the clustering API, a model should already be trained and
  show some acceptable performance on the testing/validation sets.

  The function accepts either a single keras layer
  (subclass of `keras.layers.Layer`), list of keras layers or a keras model
  (instance of `keras.models.Model`) and handles them appropriately.

  If it encounters a layer it does not know how to handle, it will throw an
  error. While clustering an entire model, even a single unknown layer would
  lead to an error.

  Cluster a model:

  ```python
  clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
    'cluster_per_channel': False,
    'preserve_sparsity': False
  }

  clustered_model = cluster_weights(original_model, **clustering_params)
  ```

  Cluster a layer:

  ```python
  clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
    'cluster_per_channel': False,
    'preserve_sparsity': False
  }

  model = tf.keras.Sequential([
      layers.Dense(10, activation='relu', input_shape=(100,)),
      cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
  ])
  ```

  Cluster a layer with sparsity preservation (experimental):

  ```python
  clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
    'preserve_sparsity': True
  }

  model = tf.keras.Sequential([
      layers.Dense(10, activation='relu', input_shape=(100,)),
      cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
  ])
  ```

  Arguments:
    to_cluster: A single keras layer, list of keras layers, or a
      `tf.keras.Model` instance.
    number_of_clusters: the number of cluster centroids to form when
      clustering a layer/model. For example, if number_of_clusters=8 then only
      8 unique values will be used in each weight array.
    cluster_centroids_init: `tfmot.clustering.keras.CentroidInitialization`
      instance that determines how the cluster centroids will be initialized.
    preserve_sparsity (experimental): optional boolean value that determines
      whether or not sparsity preservation will be enforced during training.
      When used along with cluster_per_channel flag below, the zero centroid
      is treated separately and maintained individually for each channel.
    cluster_per_channel: optional boolean value that determines whether the
      clustering should be applied separately on the individual channels, as
      opposed to the whole kernel. Only applicable to Conv2D layers and is
      ignored otherwise. The number of clusters in this case would be
      number_of_clusters*num_channels. This is useful for the collaborative
      optimization pipeline where clustering is followed by quantization,
      since Conv2D is quantized per-channel, so we end up with
      number_of_clusters*num_channels total clusters at the end. Clustering
      per-channel from the beginning leads to better accuracy.
    **kwargs: Additional keyword arguments forwarded unchanged to the
      constructor of every clustering wrapper this function creates.

  Returns:
    Layer or model modified to include clustering related metadata. When
    `to_cluster` is a list, a new list holding the result for each element,
    in the same order.

  Raises:
    ValueError: if `cluster_centroids_init` is not a supported initialization,
      if the keras layer is unsupported, if the keras model contains an
      unsupported layer or is a nested subclassed model, or if `to_cluster`
      is not a keras model, a keras layer or a list.
  """
  if not clustering_centroids.CentroidsInitializerFactory.init_is_supported(
      cluster_centroids_init):
    raise ValueError('Cluster centroid initialization {} not supported'.format(
        cluster_centroids_init))

  def _add_clustering_wrapper(layer):
    """Returns the clustering counterpart of a single layer or nested model."""
    if isinstance(layer, tf.keras.Model):
      # Check whether the model is a subclass.
      # NB: This check is copied from keras.py file in tensorflow.
      # There is no available public API to do this check.
      # pylint: disable=protected-access
      if (not layer._is_graph_network and
          not isinstance(layer, tf.keras.models.Sequential)):
        raise ValueError('Subclassed models are not supported currently.')

      # Recurse: the clone is built by applying this same function to the
      # layers that `clone_model` hands to `clone_function`.
      return tf.keras.models.clone_model(
          layer, input_tensors=None, clone_function=_add_clustering_wrapper)

    if isinstance(layer, cluster_wrapper.ClusterWeights):
      # Already a clustering wrapper: hand it back untouched instead of
      # wrapping it a second time.
      return layer

    if isinstance(layer, InputLayer):
      # Input layers are not wrapped; a fresh instance is rebuilt from config.
      return layer.__class__.from_config(layer.get_config())

    # Note: `cluster_per_channel` is only forwarded to the generic wrapper
    # below, not to the RNN and MultiHeadAttention wrappers.
    if isinstance(
        layer, (tf.keras.layers.RNN, tf.keras.layers.Bidirectional)):
      return cluster_wrapper.ClusterWeightsRNN(
          layer,
          number_of_clusters,
          cluster_centroids_init,
          preserve_sparsity,
          **kwargs,
      )

    if isinstance(layer, tf.keras.layers.MultiHeadAttention):
      return cluster_wrapper.ClusterWeightsMHA(
          layer,
          number_of_clusters,
          cluster_centroids_init,
          preserve_sparsity,
          **kwargs,
      )

    return cluster_wrapper.ClusterWeights(layer, number_of_clusters,
                                          cluster_centroids_init,
                                          preserve_sparsity,
                                          cluster_per_channel, **kwargs)

  # A top-level model is cloned directly, with `_add_clustering_wrapper`
  # supplied to `clone_model` as its `clone_function`.
  if isinstance(to_cluster, tf.keras.Model):
    return tf.keras.models.clone_model(
        to_cluster, input_tensors=None, clone_function=_add_clustering_wrapper)
  if isinstance(to_cluster, Layer):
    return _add_clustering_wrapper(layer=to_cluster)
  if isinstance(to_cluster, list):
    return [_add_clustering_wrapper(layer) for layer in to_cluster]

  # Fail loudly rather than silently returning None for an unsupported input,
  # which would only surface later as an unrelated error at the call site.
  raise ValueError(
      '`cluster_weights` can only cluster an object of the following types: '
      'keras.models.Model, keras.layers.Layer, list of keras.layers.Layer. '
      'You passed an object of type: {}.'.format(type(to_cluster).__name__))