def _build_clusters()

in tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_quantize_registry.py [0:0]


  def _build_clusters(self, name, layer):
    """Creates the clustering variables for one weight of a wrapped layer.

    Reads the weight called `name` from the wrapped layer, derives its
    centroids and lookup indices via `get_unique`, and registers the
    centroids, a copy of the weights and the pulling indices as new weights
    on `layer`.

    Side effect: when sparsity preservation is enabled but the weight holds
    no zero entries, `self.preserve_sparsity` is set to False on this object
    and a warning is logged.

    Args:
      name: Name of weights in layer.
      layer: Quantization wrapped keras layer.
    Returns:
      A dictionary of clustering state. It holds the sparsity mask when
      sparsity is being preserved. Unless more than 1024 centroids were
      found, it also holds the cluster centroids, pulling indices, original
      weights, weight name, clustering implementation and centroids mask
      (None when sparsity is not preserved).
    """

    def constant_init(tensor):
      # Freeze the tensor's current value into a constant initializer.
      return tf.keras.initializers.Constant(
          value=K.batch_get_value([tensor])[0])

    clusters = {}
    kernel = getattr(layer.layer, name)

    # Sparsity can only be preserved if there is at least one zero weight.
    if self.preserve_sparsity and not tf.reduce_any(kernel == 0):
      self.preserve_sparsity = False
      logging.warning(
          'Input layer does not contain zero weights, so apply CQAT instead.')

    mask_for_centroids = None
    unique_values, unique_idx = get_unique(kernel)
    centroid_count = tf.size(unique_values)

    if self.preserve_sparsity:
      # w / w is 1 for non-zero weights and 0 (instead of NaN) for zeros.
      nonzero_mask = tf.math.divide_no_nan(kernel, kernel)
      # Mask is 0 at the smallest-magnitude centroid and 1 everywhere else.
      nearest_zero = tf.argmin(tf.abs(unique_values), axis=-1)
      mask_for_centroids = 1.0 - tf.one_hot(nearest_zero, centroid_count)
      clusters[SPARSITY_MASK] = nonzero_mask

    # Prepare clustering variables for the Keras graph when clusters
    # exist, assuming we do not use number_of_clusters larger than 1024
    if centroid_count > 1024:
      return clusters

    centroids_var = layer.add_weight(
        CLUSTER_CENTROIDS,
        shape=unique_values.shape,
        initializer=constant_init(unique_values),
        dtype=unique_values.dtype,
        trainable=True)

    original_var = layer.add_weight(
        ORIGINAL_WEIGHTS,
        shape=kernel.shape,
        initializer=constant_init(kernel),
        dtype=kernel.dtype,
        trainable=True)

    # Get clustering implementation according to layer type
    registry = clustering_registry.ClusteringLookupRegistry()
    impl_cls = registry.get_clustering_impl(layer.layer, name)
    impl = impl_cls(centroids_var)

    # Cast the computed indices to the dtype of the lookup from get_unique.
    raw_indices = impl.get_pulling_indices(original_var)
    indices = tf.dtypes.cast(raw_indices, unique_idx.dtype)

    indices_var = layer.add_weight(
        PULLING_INDICES,
        shape=unique_idx.shape,
        initializer=constant_init(indices),
        dtype=unique_idx.dtype,
        trainable=False)

    clusters.update({
        CLUSTER_CENTROIDS: centroids_var,
        PULLING_INDICES: indices_var,
        ORIGINAL_WEIGHTS: original_var,
        WEIGHT_NAME: name,
        CLUSTERING_IMPL: impl,
        CENTROIDS_MASK: mask_for_centroids,
    })
    return clusters