in tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py [0:0]
def make_clusterable(cls, layer):
  """Modifies a built-in layer object to support clustering.

  A `get_clusterable_weights` callable is attached to the given layer
  instance. Which implementation is attached depends on the layer class:
  one for layers listed in `cls._SUPPORTED_RNN_LAYERS`, one for layers listed
  in `cls._SUPPORTED_MHA_LAYERS`, and a generic one (driven by
  `cls._weight_names(layer)`) for everything else.

  The attached callable resolves the weights lazily, i.e. each time it is
  called, not when `make_clusterable` runs.

  Args:
    cls: registry class providing `supports`, `_weight_names`,
      `_get_rnn_cells` and the `_SUPPORTED_*` collections.
    layer: layer to modify for support.

  Returns:
    The modified layer object (the same instance that was passed in).

  Raises:
    ValueError: if `cls.supports(layer)` is false. Note that for RNN layers
      an unsupported cell is only reported (also as ValueError) when the
      attached `get_clusterable_weights` is invoked.
  """
  if not cls.supports(layer):
    raise ValueError('Layer ' + str(layer.__class__) + ' is not supported.')

  def get_clusterable_weights():
    # Generic case: every registered weight name is an attribute of the layer.
    return [(weight_name, getattr(layer, weight_name))
            for weight_name in cls._weight_names(layer)]

  def get_clusterable_weights_rnn():  # pylint: disable=missing-docstring
    def get_clusterable_weights_rnn_cell(cell, i):
      # Cell weights will be a list of tuples in RNN or
      # when they are wrapped by the StackedRNNCell layer.
      # The weight names will have indices attached only
      # for the registry.
      if cell.__class__ in cls._SUPPORTED_RNN_CELLS:
        return [('kernel/' + str(i), cell.kernel),
                ('recurrent_kernel/' + str(i), cell.recurrent_kernel)]

      if isinstance(cell, clusterable_layer.ClusterableLayer):
        raise ValueError(
            'ClusterableLayer is not yet supported for RNNs based layer.')

      raise ValueError('Layer cell ' + str(cell.__class__) +
                       ' is not supported.')

    # Fetch the cells once and index them by position. Using the position
    # (rather than `list.index`) keeps the names unique even if the same cell
    # object occurs more than once, and a single cell naturally gets index 0.
    clusterable_weights = []
    for cell_index, rnn_cell in enumerate(cls._get_rnn_cells(layer)):
      clusterable_weights.extend(
          get_clusterable_weights_rnn_cell(rnn_cell, cell_index))
    return clusterable_weights

  def get_clusterable_weights_mha():  # pylint: disable=missing-docstring
    # The projection sub-layers are private attributes of the attention
    # layer, hence the protected access.
    # pylint: disable=protected-access
    return [('_query_dense.kernel', layer._query_dense.kernel),
            ('_key_dense.kernel', layer._key_dense.kernel),
            ('_value_dense.kernel', layer._value_dense.kernel),
            ('_output_dense.kernel', layer._output_dense.kernel)]

  # Exact class membership (not isinstance) decides which getter is attached.
  if layer.__class__ in cls._SUPPORTED_RNN_LAYERS:
    layer.get_clusterable_weights = get_clusterable_weights_rnn
  elif layer.__class__ in cls._SUPPORTED_MHA_LAYERS:
    layer.get_clusterable_weights = get_clusterable_weights_mha
  else:
    layer.get_clusterable_weights = get_clusterable_weights

  return layer