in tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py [0:0]
def __init__(self,
             layer,
             number_of_clusters,
             cluster_centroids_init=CentroidInitialization.KMEANS_PLUS_PLUS,
             preserve_sparsity=False,
             cluster_per_channel=False,
             cluster_gradient_aggregation=GradientAggregation.SUM,
             **kwargs):
  """Wraps `layer` so that its weights can be clustered.

  Args:
    layer: The `Layer` to wrap. It must either be a
      `clusterable_layer.ClusterableLayer` (wrapped as-is) or be supported by
      `clustering_registry.ClusteringRegistry` (in which case the result of
      `ClusteringRegistry.make_clusterable(layer)` is wrapped).
    number_of_clusters: Python `int` number of cluster centroids. Must be
      greater than 1, or greater than 2 when `preserve_sparsity` is True.
    cluster_centroids_init: How the cluster centroids are initialized.
    preserve_sparsity: Whether sparsity preservation is applied.
    cluster_per_channel: Whether to cluster Conv2D kernels per channel. Only
      honored when `layer` is a `tf.keras.layers.Conv2D`; for any other layer
      type the stored value is False.
    cluster_gradient_aggregation: How the gradient of each cluster centroid
      is aggregated.
    **kwargs: Forwarded to the parent wrapper constructor. When no `name` key
      is given, one is produced by `self._make_layer_name(layer)`.

  Raises:
    ValueError: If `layer` is not a `Layer`; if it is neither a
      `ClusterableLayer` nor supported by the `ClusteringRegistry`; if
      `number_of_clusters` is not an `int`; or if `number_of_clusters` does
      not exceed the minimum described above.
  """
  # --- Validate the layer and initialize the Keras wrapper. ---
  if not isinstance(layer, Layer):
    raise ValueError(
        'Please initialize `Cluster` layer with a '
        '`Layer` instance. You passed: {input}'.format(input=layer))

  if 'name' not in kwargs:
    kwargs['name'] = self._make_layer_name(layer)

  registry = clustering_registry.ClusteringRegistry
  if isinstance(layer, clusterable_layer.ClusterableLayer):
    # User-defined custom layer: wrapped without registry adaptation.
    wrapped_layer = layer
  elif registry.supports(layer):
    # Built-in layer: adapted by the registry before being wrapped.
    wrapped_layer = registry.make_clusterable(layer)
  else:
    raise ValueError(
        'Please initialize `Cluster` with a supported layer. Layers should '
        'either be a `ClusterableLayer` instance, or should be supported by '
        'the ClusteringRegistry. You passed: {input}'.format(
            input=layer.__class__))
  super(ClusterWeights, self).__init__(wrapped_layer, **kwargs)

  # --- Validate the requested number of clusters. ---
  if not isinstance(number_of_clusters, int):
    raise ValueError(
        'number_of_clusters must be an integer. Given: {}'.format(
            number_of_clusters.__class__))
  # Sparsity preservation raises the (exclusive) lower bound from 1 to 2.
  min_exclusive = 2 if preserve_sparsity else 1
  if number_of_clusters <= min_exclusive:
    raise ValueError(
        'number_of_clusters must be greater than {}. Given: {}'.format(
            min_exclusive, number_of_clusters))

  self._track_trackable(layer, name='layer')

  # --- Clustering hyper-parameters. ---
  self.cluster_centroids_init = cluster_centroids_init
  self.number_of_clusters = number_of_clusters
  # Per-channel clustering is only meaningful for Conv2D kernels.
  per_channel = False
  if isinstance(layer, tf.keras.layers.Conv2D):
    per_channel = cluster_per_channel
  self.cluster_per_channel = per_channel
  # Channel count for per-channel clustering; left unset at construction.
  self.num_channels = None
  self.preserve_sparsity = preserve_sparsity
  self.cluster_gradient_aggregation = cluster_gradient_aggregation

  # --- Bookkeeping dictionaries, all empty at construction. ---
  # Weight name -> sparsity mask.
  self.sparsity_masks = {}
  # Weight name -> zero centroid.
  self.zero_idx = {}
  # Weight name -> original clusterable weight variable. These variables
  # are still updated during backpropagation.
  self.original_clusterable_weights = {}
  # Position of the original weight variable in the child layer -> weight
  # name.
  self.position_original_weights = {}
  # Weight name -> clustering algorithm.
  self.clustering_algorithms = {}
  # Weight name -> indices lookup table.
  self.pulling_indices = {}
  # Weight name -> cluster centroid variable.
  self.cluster_centroids = {}

  # --- State carried over from the wrapped layer. ---
  # Keep a user-specified input shape; without it the `built` state is lost
  # between serializations.
  needs_input_shape = (not hasattr(self, '_batch_input_shape') and
                       hasattr(layer, '_batch_input_shape'))
  if needs_input_shape:
    self._batch_input_shape = self.layer._batch_input_shape
  # Conv2D's data_format is needed later for per-channel clustering.
  self.data_format = (
      self.layer.data_format if hasattr(layer, 'data_format') else None)
  # Input shape specified in `build`; unset at construction.
  self.build_input_shape = None