def make_clusterable()

in tensorflow_model_optimization/python/core/clustering/keras/clustering_registry.py [0:0]


  def make_clusterable(cls, layer):
    """Modifies a built-in layer object to support clustering.

    Attaches a zero-argument `get_clusterable_weights` callable to `layer`.
    That callable returns a list of `(weight_name, weight)` tuples. Which
    weights it reports depends on the exact class of `layer`:

    * RNN layers (class in `cls._SUPPORTED_RNN_LAYERS`): the `kernel` and
      `recurrent_kernel` of every cell returned by `cls._get_rnn_cells`,
      named `'kernel/<i>'` and `'recurrent_kernel/<i>'`. Here `<i>` is the
      cell's position in that list; a cell object listed more than once uses
      the position of its first occurrence.
    * Multi-head attention layers (class in `cls._SUPPORTED_MHA_LAYERS`): the
      kernels of the query, key, value and output dense sub-layers.
    * Any other supported layer: the attributes named by
      `cls._weight_names(layer)`.

    Weights are looked up each time `get_clusterable_weights` is called, not
    when this method runs.

    Args:
      layer: layer to modify for support.

    Returns:
      The modified layer object (the same instance that was passed in).

    Raises:
      ValueError: if `cls.supports(layer)` is False. For RNN layers, the
        attached `get_clusterable_weights` also raises ValueError when it is
        called and meets a cell whose class is not in
        `cls._SUPPORTED_RNN_CELLS`.
    """

    if not cls.supports(layer):
      raise ValueError('Layer ' + str(layer.__class__) + ' is not supported.')

    def get_clusterable_weights():
      """Returns (name, weight) pairs for the names in `cls._weight_names`."""
      return [(weight_name, getattr(layer, weight_name))
              for weight_name in cls._weight_names(layer)]

    def get_clusterable_weights_rnn():
      """Returns the clusterable kernels of every cell of the RNN `layer`."""

      def get_clusterable_weights_rnn_cell(cell, i):
        """Returns the (name, weight) pairs of `cell`, suffixed with `/i`."""
        # Cell weights will be a list of tuples in RNN or when the cells are
        # wrapped by the StackedRNNCell layer. The weight names have indices
        # attached only for the registry.
        if cell.__class__ in cls._SUPPORTED_RNN_CELLS:
          return [('kernel/' + str(i), cell.kernel),
                  ('recurrent_kernel/' + str(i), cell.recurrent_kernel)]

        if isinstance(cell, clusterable_layer.ClusterableLayer):
          raise ValueError(
              'ClusterableLayer is not yet supported for RNNs based layer.')

        raise ValueError('Layer cell ' + str(cell.__class__) +
                         ' is not supported.')

      # The cell list does not change while iterating, so fetch it once
      # instead of once per loop step.
      cells = cls._get_rnn_cells(layer)
      clusterable_weights = []
      for cell in cells:
        # `index` returns the first occurrence, so a cell object listed more
        # than once always gets the same suffix. With a single cell it is 0.
        clusterable_weights.extend(
            get_clusterable_weights_rnn_cell(cell, cells.index(cell)))
      return clusterable_weights

    def get_clusterable_weights_mha():
      """Returns the kernels of the attention layer's four dense sub-layers."""
      # pylint: disable=protected-access
      return [('_query_dense.kernel', layer._query_dense.kernel),
              ('_key_dense.kernel', layer._key_dense.kernel),
              ('_value_dense.kernel', layer._value_dense.kernel),
              ('_output_dense.kernel', layer._output_dense.kernel)]

    # RNN and attention layers hold their kernels in nested objects (cells,
    # dense sub-layers), so they get dedicated accessors. Class matching is
    # exact: subclasses fall through to the generic accessor.
    if layer.__class__ in cls._SUPPORTED_RNN_LAYERS:
      layer.get_clusterable_weights = get_clusterable_weights_rnn
    elif layer.__class__ in cls._SUPPORTED_MHA_LAYERS:
      layer.get_clusterable_weights = get_clusterable_weights_mha
    else:
      layer.get_clusterable_weights = get_clusterable_weights

    return layer