def __init__()

in tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py [0:0]


  def __init__(self,
               layer,
               number_of_clusters,
               cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,
               preserve_sparsity=False,
               cluster_per_channel=False,
               cluster_gradient_aggregation=GradientAggregation.SUM,
               **kwargs):
    """Wraps `layer` so that its weights can be clustered.

    Args:
      layer: The Keras `Layer` to wrap. It must either be a
        `ClusterableLayer` instance or be supported by `ClusteringRegistry`.
      number_of_clusters: Number of cluster centroids. Must be an `int`
        greater than 1, or greater than 2 when `preserve_sparsity` is True.
      cluster_centroids_init: How the cluster centroids are initialized.
      preserve_sparsity: Whether to apply sparsity preservation.
      cluster_per_channel: Whether to cluster Conv2D kernels per channel.
        Forced to False unless `layer` is a `tf.keras.layers.Conv2D`.
      cluster_gradient_aggregation: How the gradient of each cluster
        centroid is aggregated.
      **kwargs: Forwarded to the parent class constructor. If no `name` is
        given, one is generated from `layer`.

    Raises:
      ValueError: If `layer` is not a `Layer`; if `number_of_clusters` is
        not an `int` or is not above the minimum; or if `layer` is neither
        a `ClusterableLayer` nor supported by `ClusteringRegistry`.
    """
    if not isinstance(layer, Layer):
      raise ValueError(
          'Please initialize `Cluster` layer with a '
          '`Layer` instance. You passed: {input}'.format(input=layer))

    # Validate the scalar arguments before the parent constructor runs and
    # before `layer` is handed to `make_clusterable` below (which may modify
    # it -- NOTE(review): confirm). Invalid arguments then fail before any of
    # that work happens.
    if not isinstance(number_of_clusters, int):
      raise ValueError(
          'number_of_clusters must be an integer. Given: {}'.format(
              number_of_clusters.__class__))

    # With sparsity preservation the minimum is raised by one, presumably
    # because one centroid is reserved for zero (see `zero_idx` below).
    limit_number_of_clusters = 2 if preserve_sparsity else 1
    if number_of_clusters <= limit_number_of_clusters:
      raise ValueError(
          'number_of_clusters must be greater than {}. Given: {}'.format(
              limit_number_of_clusters, number_of_clusters))

    if 'name' not in kwargs:
      kwargs['name'] = self._make_layer_name(layer)

    if isinstance(layer, clusterable_layer.ClusterableLayer):
      # A user-defined custom layer
      super(ClusterWeights, self).__init__(layer, **kwargs)
    elif clustering_registry.ClusteringRegistry.supports(layer):
      super(ClusterWeights, self).__init__(
          clustering_registry.ClusteringRegistry.make_clusterable(layer),
          **kwargs)
    else:
      raise ValueError(
          'Please initialize `Cluster` with a supported layer. Layers should '
          'either be a `ClusterableLayer` instance, or should be supported by '
          'the ClusteringRegistry. You passed: {input}'.format(
              input=layer.__class__))

    self._track_trackable(layer, name='layer')

    # The way how cluster centroids will be initialized
    self.cluster_centroids_init = cluster_centroids_init

    # The number of cluster centroids
    self.number_of_clusters = number_of_clusters

    # Whether to cluster Conv2D kernels per-channel.
    # In case the layer isn't a Conv2D, this isn't applicable
    self.cluster_per_channel = (
        cluster_per_channel if isinstance(layer, tf.keras.layers.Conv2D)
        else False)

    # Number of channels in a Conv2D layer, to be used in the case of
    # per-channel clustering.
    self.num_channels = None

    # Whether to apply sparsity preservation or not
    self.preserve_sparsity = preserve_sparsity

    # The way to aggregate the gradient of each cluster centroid
    self.cluster_gradient_aggregation = cluster_gradient_aggregation

    # Stores the pairs of weight names and their respective sparsity masks
    self.sparsity_masks = {}

    # Stores the pairs of weight names and the zero centroids
    self.zero_idx = {}

    # Map weight names to original clusterable weights variables
    # Those weights will still be updated during backpropagation
    self.original_clusterable_weights = {}

    # Map the position of the original weight variable in the
    # child layer to the weight name
    self.position_original_weights = {}

    # Map weight names to corresponding clustering algorithms
    self.clustering_algorithms = {}

    # Map weight names to corresponding indices lookup tables
    self.pulling_indices = {}

    # Map weight names to corresponding cluster centroid variables
    self.cluster_centroids = {}

    # If the input shape was specified, then we need to preserve this
    # information in the layer. If this info is not preserved, then the `built`
    # state will not be preserved between serializations.
    if (not hasattr(self, '_batch_input_shape') and
        hasattr(layer, '_batch_input_shape')):
      self._batch_input_shape = self.layer._batch_input_shape

    # In the case of Conv2D layer, the data_format needs to be preserved to be
    # used for per-channel clustering
    if hasattr(layer, 'data_format'):
      self.data_format = self.layer.data_format
    else:
      self.data_format = None

    # Save the input shape specified in the build
    self.build_input_shape = None