def _cluster_weights()

in tensorflow_model_optimization/python/core/clustering/keras/cluster.py [0:0]


def _cluster_weights(to_cluster,
                     number_of_clusters,
                     cluster_centroids_init,
                     preserve_sparsity=False,
                     cluster_per_channel=False,
                     **kwargs):
  """Modifies a keras layer or model to be clustered during training.

  This function wraps a keras model or layer with clustering functionality
  which clusters the layer's weights during training. For example, using
  this with number_of_clusters equals 8 will ensure that each weight tensor has
  no more than 8 unique values.

  Before passing to the clustering API, a model should already be trained and
  show some acceptable performance on the testing/validation sets.

  The function accepts either a single keras layer
  (subclass of `keras.layers.Layer`), list of keras layers or a keras model
  (instance of `keras.models.Model`) and handles them appropriately.

  If it encounters a layer it does not know how to handle, it will throw an
  error. While clustering an entire model, even a single unknown layer would
  lead to an error.

  Cluster a model:

  ```python
  clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
    'cluster_per_channel': False,
    'preserve_sparsity': False
  }

  clustered_model = cluster_weights(original_model, **clustering_params)
  ```

  Cluster a layer:

  ```python
  clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
    'cluster_per_channel': False,
    'preserve_sparsity': False
  }

  model = tf.keras.Sequential([
      layers.Dense(10, activation='relu', input_shape=(100,)),
      cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
  ])
  ```

  Cluster a layer with sparsity preservation (experimental):

  ```python
  clustering_params = {
    'number_of_clusters': 8,
    'cluster_centroids_init': CentroidInitialization.DENSITY_BASED,
    'preserve_sparsity': True
  }

  model = tf.keras.Sequential([
      layers.Dense(10, activation='relu', input_shape=(100,)),
      cluster_weights(layers.Dense(2, activation='tanh'), **clustering_params)
  ])
  ```

  Arguments:
      to_cluster: A single keras layer, list of keras layers, or a
        `tf.keras.Model` instance.
      number_of_clusters: the number of cluster centroids to form when
        clustering a layer/model. For example, if number_of_clusters=8 then only
        8 unique values will be used in each weight array.
      cluster_centroids_init: `tfmot.clustering.keras.CentroidInitialization`
        instance that determines how the cluster centroids will be initialized.
      preserve_sparsity (experimental): optional boolean value that determines
        whether or not sparsity preservation will be enforced during training.
        When used along with cluster_per_channel flag below, the zero centroid
        is treated separately and maintained individually for each channel.
      cluster_per_channel: optional boolean value that determines whether the
        clustering should be applied separately on the individual channels, as
        opposed to the whole kernel. Only applicable to Conv2D layers and is
        ignored otherwise. The number of clusters in this case would be
        num_clusters*num_channels. This is useful for the collaborative
        optimization pipeline where clustering is followed by quantization,
        since Conv2D is quantized per-channel, so we end up with
        num_clusters*num_channels total clusters at the end. Clustering
        per-channel from the beginning leads to better accuracy.
      **kwargs: Additional keyword arguments forwarded to the clustering
        wrapper that is constructed around each wrapped keras layer.

  Returns:
    Layer or model modified to include clustering related metadata. When
    `to_cluster` is a list, a list with one result per input element.

  Raises:
    ValueError: if `cluster_centroids_init` is not a supported initialization,
    if the keras layer is unsupported, if the keras model contains an
    unsupported layer or a nested subclassed model, or if `to_cluster` is
    not a keras model, a keras layer or a list.
  """
  if not clustering_centroids.CentroidsInitializerFactory.init_is_supported(
      cluster_centroids_init):
    raise ValueError('Cluster centroid initialization {} not supported'.format(
        cluster_centroids_init))

  def _add_clustering_wrapper(layer):
    """Returns the clustering-enabled counterpart of a single layer/model."""
    if isinstance(layer, tf.keras.Model):
      # Check whether the model is a subclass.
      # NB: This check is copied from keras.py file in tensorflow.
      # There is no available public API to do this check.
      # pylint: disable=protected-access
      if (not layer._is_graph_network and
          not isinstance(layer, tf.keras.models.Sequential)):
        raise ValueError('Subclassed models are not supported currently.')

      # Nested models are cloned recursively so their layers get wrapped too.
      return tf.keras.models.clone_model(
          layer, input_tensors=None, clone_function=_add_clustering_wrapper)
    if isinstance(layer, cluster_wrapper.ClusterWeights):
      # Already a clustering wrapper: hand it back unchanged.
      return layer
    if isinstance(layer, InputLayer):
      # Input layers are rebuilt from their config instead of being wrapped.
      return layer.__class__.from_config(layer.get_config())
    if isinstance(layer,
                  (tf.keras.layers.RNN, tf.keras.layers.Bidirectional)):
      # The RNN wrapper is not given cluster_per_channel.
      return cluster_wrapper.ClusterWeightsRNN(
          layer,
          number_of_clusters,
          cluster_centroids_init,
          preserve_sparsity,
          **kwargs,
      )
    if isinstance(layer, tf.keras.layers.MultiHeadAttention):
      # The MHA wrapper is not given cluster_per_channel either.
      return cluster_wrapper.ClusterWeightsMHA(
          layer,
          number_of_clusters,
          cluster_centroids_init,
          preserve_sparsity,
          **kwargs,
      )

    return cluster_wrapper.ClusterWeights(layer, number_of_clusters,
                                          cluster_centroids_init,
                                          preserve_sparsity,
                                          cluster_per_channel, **kwargs)

  def _wrap_list(layers):
    """Applies `_add_clustering_wrapper` to every element, keeping order."""
    return [_add_clustering_wrapper(layer) for layer in layers]

  # The model check must come before the layer check, since a keras model is
  # itself a keras layer.
  if isinstance(to_cluster, tf.keras.Model):
    return tf.keras.models.clone_model(
        to_cluster, input_tensors=None, clone_function=_add_clustering_wrapper)
  if isinstance(to_cluster, Layer):
    return _add_clustering_wrapper(layer=to_cluster)
  if isinstance(to_cluster, list):
    return _wrap_list(to_cluster)

  # Fail loudly instead of silently returning None for unsupported inputs.
  raise ValueError(
      '`cluster_weights` can only cluster an object of the following types: '
      'keras.models.Sequential, keras functional model, keras.layers.Layer, '
      'list of keras.layers.Layer. You passed an object of type: '
      '{input}.'.format(input=to_cluster.__class__.__name__))