in tensorflow_model_optimization/python/core/clustering/keras/cluster_wrapper.py [0:0]
def build(self, input_shape):
  """Creates the clustering state for every clusterable weight of the layer.

  For each `(weight_name, weight)` pair returned by
  `self.layer.get_clusterable_weights()` this method:

  * keeps a reference to the layer's original variable (in
    `self.original_clusterable_weights` and as the tracked attribute
    `original_weight_<weight_name>`) and records its position in
    `self.layer.weights`, which is used to restore it during stripping;
  * initializes the cluster centroids with the configured centroid
    initializer and registers them as a trainable weight of this wrapper;
  * instantiates the clustering algorithm looked up in
    `ClusteringLookupRegistry`;
  * computes the initial pulling indices (weight-to-centroid associations)
    and registers them as a non-trainable int64 weight;
  * when `self.preserve_sparsity` is set, builds a float32 mask of the
    non-zero clustered weights and stores the index of the centroid with the
    smallest absolute value.

  The weight names used by this wrapper are not necessarily the ones of the
  wrapped layer: for cells in `StackedRNNCells`, for example, they become
  'kernel/0', 'recurrent_kernel/0', 'kernel/1', 'recurrent_kernel/1'. For
  `StackedRNNCells` and `Bidirectional` layers the '/<index>' suffix is
  removed before looking up the clustering implementation.

  Args:
    input_shape: The input shape passed to the Keras `build` call; it is
      stored in `self.build_input_shape`.

  Raises:
    ValueError: If the original variable of a clusterable weight cannot be
      found (by identity) in `self.layer.weights`.
  """
  super(ClusterWeights, self).build(input_shape)
  self.build_input_shape = input_shape

  # For every clusterable weight, create the clustering logic.
  for weight_name, weight in self.layer.get_clusterable_weights():
    # Store the original weight in this wrapper. The child reference will be
    # overridden in update_clustered_weights_associations.
    original_weight = self.get_weight_from_layer(weight_name)
    self.original_clusterable_weights[weight_name] = original_weight
    # Track the variable.
    setattr(self, 'original_weight_' + weight_name, original_weight)

    # Store the position in layer.weights of original_weight to restore
    # during stripping. An explicit default turns a missing weight into a
    # descriptive error instead of a bare StopIteration escaping build().
    position_original_weight = next(
        (i for i, w in enumerate(self.layer.weights) if w is original_weight),
        None)
    if position_original_weight is None:
      raise ValueError(
          'Could not find the clusterable weight {!r} in the weights of '
          'layer {!r}.'.format(weight_name, self.layer.name))
    self.position_original_weights[position_original_weight] = weight_name

    # In the case of per-channel clustering, the number of channels,
    # per-channel number of clusters, as well as the overall number
    # of clusters all need to be preserved in the wrapper.
    # NOTE(review): when cluster_per_channel is False, self.num_channels keeps
    # whatever value was set before build (presumably in __init__) - confirm.
    if self.cluster_per_channel:
      # NOTE(review): channels_first reads shape[1]; confirm this matches the
      # weight layout expected by the centroid initializer and the registry.
      self.num_channels = (
          original_weight.shape[1] if self.data_format == 'channels_first'
          else original_weight.shape[-1])

    # Init the cluster centroids.
    centroid_init_factory = clustering_centroids.CentroidsInitializerFactory
    centroid_init = centroid_init_factory.get_centroid_initializer(
        self.cluster_centroids_init)(weight, self.number_of_clusters,
                                     self.cluster_per_channel,
                                     self.num_channels,
                                     self.preserve_sparsity)
    cluster_centroids = centroid_init.get_cluster_centroids()
    self.cluster_centroids[weight_name] = self.add_weight(
        'cluster_centroids_' + weight_name,
        shape=cluster_centroids.shape,
        dtype=weight.dtype,
        trainable=True,
        initializer=tf.keras.initializers.Constant(value=cluster_centroids))

    # Init the weight clustering algorithm. StackedRNNCells and Bidirectional
    # layers use indexed weight names (e.g. 'kernel/0'); the registry lookup
    # is done with the name stripped of its '/<index>' suffix.
    if isinstance(self.layer, tf.keras.layers.RNN):
      strip_index = isinstance(self.layer.cell,
                               tf.keras.layers.StackedRNNCells)
    else:
      strip_index = isinstance(self.layer, tf.keras.layers.Bidirectional)
    weight_name_no_index = (
        weight_name.split('/')[0] if strip_index else weight_name)
    self.clustering_algorithms[weight_name] = (
        clustering_registry.ClusteringLookupRegistry().get_clustering_impl(
            self.layer, weight_name_no_index, self.cluster_per_channel)(
                clusters_centroids=self.cluster_centroids[weight_name],
                cluster_gradient_aggregation=self.cluster_gradient_aggregation,
                data_format=self.data_format,
            ))

    # Init the pulling_indices (weights associations).
    pulling_indices = (
        self.clustering_algorithms[weight_name].get_pulling_indices(weight))
    self.pulling_indices[weight_name] = self.add_weight(
        'pulling_indices_' + weight_name,
        shape=pulling_indices.shape,
        dtype=tf.int64,
        trainable=False,
        synchronization=tf.VariableSynchronization.ON_READ,
        aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        initializer=tf.keras.initializers.Constant(value=pulling_indices))

    if self.preserve_sparsity:
      # Init the sparsity mask: 1.0 where the clustered weight is non-zero,
      # 0.0 elsewhere.
      clustered_weights = (
          self.clustering_algorithms[weight_name].get_clustered_weight(
              pulling_indices, original_weight))
      self.sparsity_masks[weight_name] = (
          tf.cast(tf.math.not_equal(clustered_weights, 0), dtype=tf.float32))
      # Index of the centroid with the smallest magnitude along the last axis.
      # If the model is pruned (which we suppose), this is approximately zero.
      self.zero_idx[weight_name] = tf.argmin(
          tf.abs(self.cluster_centroids[weight_name]), axis=-1)