def build()

in tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py [0:0]


  def build(self, input_shape):
    """Creates the clustering variables for each clusterable weight.

    For every `(weight_name, weight)` pair returned by
    `self.layer.get_clusterable_weights()` this method:
      * records the layer's original weight and its index in
        `self.layer.weights`;
      * creates a trainable `cluster_centroids_<weight_name>` variable;
      * instantiates the clustering algorithm looked up in the registry;
      * creates a non-trainable `pulling_indices_<weight_name>` variable;
      * when `self.preserve_sparsity` is set, stores a float32 mask of the
        non-zero clustered weights and the index of the centroid with the
        smallest absolute value.

    Args:
      input_shape: Shape passed on to the parent `build`; also saved in
        `self.build_input_shape`.
    """
    super(ClusterWeights, self).build(input_shape)
    self.build_input_shape = input_shape

    for weight_name, weight in self.layer.get_clusterable_weights():
      # Keep the original weight in this wrapper. The child reference will
      # be overridden in update_clustered_weights_associations.
      # Note that the weight_name used by the clustering wrapper is not
      # necessarily the name used by the wrapped layer. For cells in
      # StackedRNNCell, for example, the names become
      # 'kernel/0', 'recurrent_kernel/0', 'kernel/1', 'recurrent_kernel/1'.
      layer_weight = self.get_weight_from_layer(weight_name)
      self.original_clusterable_weights[weight_name] = layer_weight
      # Attach the variable as an attribute so that it is tracked.
      setattr(self, 'original_weight_' + weight_name, layer_weight)
      # Remember where the original weight sits in layer.weights; this is
      # needed to restore it during stripping.
      weight_index = next(
          idx for idx, candidate in enumerate(self.layer.weights)
          if candidate is layer_weight)
      self.position_original_weights[weight_index] = weight_name

      # Per-channel clustering needs the number of channels (together with
      # the per-channel and overall number of clusters) kept in the wrapper.
      if self.cluster_per_channel:
        if self.data_format == 'channels_first':
          self.num_channels = layer_weight.shape[1]
        else:
          self.num_channels = layer_weight.shape[-1]

      # Compute the initial centroid values and store them in a variable.
      initializer_cls = (
          clustering_centroids.CentroidsInitializerFactory
          .get_centroid_initializer(self.cluster_centroids_init))
      centroid_initializer = initializer_cls(
          weight,
          self.number_of_clusters,
          self.cluster_per_channel,
          self.num_channels,
          self.preserve_sparsity)
      initial_centroids = centroid_initializer.get_cluster_centroids()

      self.cluster_centroids[weight_name] = self.add_weight(
          'cluster_centroids_' + weight_name,
          shape=initial_centroids.shape,
          dtype=weight.dtype,
          trainable=True,
          initializer=tf.keras.initializers.Constant(value=initial_centroids))
      centroids_variable = self.cluster_centroids[weight_name]

      # The registry is queried without the '/<index>' suffix for stacked
      # RNN cells and for Bidirectional layers; otherwise the name is used
      # unchanged.
      if isinstance(self.layer, tf.keras.layers.RNN):
        drop_index_suffix = isinstance(
            self.layer.cell, tf.keras.layers.StackedRNNCells)
      else:
        drop_index_suffix = isinstance(
            self.layer, tf.keras.layers.Bidirectional)
      if drop_index_suffix:
        registry_weight_name = weight_name.split('/')[0]
      else:
        registry_weight_name = weight_name

      # Set up the weight clustering algorithm.
      algorithm_cls = (
          clustering_registry.ClusteringLookupRegistry().get_clustering_impl(
              self.layer, registry_weight_name, self.cluster_per_channel))
      self.clustering_algorithms[weight_name] = algorithm_cls(
          clusters_centroids=centroids_variable,
          cluster_gradient_aggregation=self.cluster_gradient_aggregation,
          data_format=self.data_format,
      )
      algorithm = self.clustering_algorithms[weight_name]

      # Set up the pulling indices (weight-to-centroid associations).
      initial_indices = algorithm.get_pulling_indices(weight)
      self.pulling_indices[weight_name] = self.add_weight(
          'pulling_indices_' + weight_name,
          shape=initial_indices.shape,
          dtype=tf.int64,
          trainable=False,
          synchronization=tf.VariableSynchronization.ON_READ,
          aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
          initializer=tf.keras.initializers.Constant(value=initial_indices))

      if not self.preserve_sparsity:
        continue

      # Build the sparsity mask from the clustered version of the weight.
      clustered = algorithm.get_clustered_weight(initial_indices, layer_weight)
      nonzero = tf.math.not_equal(clustered, 0)
      self.sparsity_masks[weight_name] = tf.cast(nonzero, dtype=tf.float32)
      # If the model is pruned (which we suppose), this centroid is
      # approximately zero.
      self.zero_idx[weight_name] = tf.argmin(
          tf.abs(self.cluster_centroids[weight_name]), axis=-1)