in tensorflow_model_optimization/python/core/quantization/keras/collaborative_optimizations/cluster_preserve/cluster_preserve_quantize_registry.py [0:0]
def _build_clusters(self, name, layer):
    """Creates the cluster-preserving variables for one weight of a layer.

    Reads the weight called `name` from the wrapped layer, obtains its
    centroids and lookup indices from `get_unique`, and registers the
    clustering variables on `layer`.

    Args:
      name: Name of weights in layer.
      layer: Quantization wrapped keras layer.

    Returns:
      A dictionary. It contains the sparsity mask when sparsity is being
      preserved. Unless more than 1024 centroids were found, it also
      contains the cluster centroids, pulling indices and original weights
      variables, the weight name, the clustering implementation and the
      centroids mask (None when sparsity is not preserved).
    """

    def constant_initializer(tensor):
        # Seed a new variable with the tensor's current concrete value.
        return tf.keras.initializers.Constant(
            value=K.batch_get_value([tensor])[0])

    layer_weights = getattr(layer.layer, name)

    # Sparsity can only be preserved when there are zeros to preserve.
    nothing_to_preserve = (
        self.preserve_sparsity and not tf.reduce_any(layer_weights == 0))
    if nothing_to_preserve:
        # NOTE: the flag is cleared on the instance itself, so it stays
        # disabled for any later call made on this same object.
        self.preserve_sparsity = False
        logging.warning(
            'Input layer does not contain zero weights, so apply CQAT instead.')

    unique_centroids, lookup_indices = get_unique(layer_weights)
    centroid_count = tf.size(unique_centroids)

    cluster_vars = {}
    zero_centroid_mask = None
    if self.preserve_sparsity:
        # w / w is 1 for non-zero weights; divide_no_nan gives 0 where w == 0.
        cluster_vars[SPARSITY_MASK] = tf.math.divide_no_nan(
            layer_weights, layer_weights)
        # Zero out the centroid with the smallest magnitude, keep the others.
        nearest_zero = tf.argmin(tf.abs(unique_centroids), axis=-1)
        zero_centroid_mask = 1.0 - tf.one_hot(nearest_zero, centroid_count)

    # Clustering variables are only prepared for the Keras graph when
    # clusters exist, assuming number_of_clusters is never larger than 1024.
    if centroid_count > 1024:
        return cluster_vars

    # The variables are added in a fixed order: centroids, original weights,
    # then pulling indices.
    centroids_var = layer.add_weight(
        CLUSTER_CENTROIDS,
        shape=unique_centroids.shape,
        initializer=constant_initializer(unique_centroids),
        dtype=unique_centroids.dtype,
        trainable=True)
    original_weights_var = layer.add_weight(
        ORIGINAL_WEIGHTS,
        shape=layer_weights.shape,
        initializer=constant_initializer(layer_weights),
        dtype=layer_weights.dtype,
        trainable=True)

    # The clustering implementation depends on the layer type.
    lookup_registry = clustering_registry.ClusteringLookupRegistry()
    impl_cls = lookup_registry.get_clustering_impl(layer.layer, name)
    clustering_impl = impl_cls(centroids_var)

    # Indices are recomputed from the new variables and stored with the
    # dtype and shape of the lookup returned by get_unique.
    initial_indices = tf.dtypes.cast(
        clustering_impl.get_pulling_indices(original_weights_var),
        lookup_indices.dtype)
    pulling_indices_var = layer.add_weight(
        PULLING_INDICES,
        shape=lookup_indices.shape,
        initializer=constant_initializer(initial_indices),
        dtype=lookup_indices.dtype,
        trainable=False)

    cluster_vars.update({
        CLUSTER_CENTROIDS: centroids_var,
        PULLING_INDICES: pulling_indices_var,
        ORIGINAL_WEIGHTS: original_weights_var,
        WEIGHT_NAME: name,
        CLUSTERING_IMPL: clustering_impl,
        CENTROIDS_MASK: zero_centroid_mask,
    })
    return cluster_vars