def transform()

in tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py [0:0]


  def transform(self):
    """Applies all configured transforms and rebuilds the Keras model.

    Main entry point for running the transformations. The model is handled
    through its serialized config: transforms are applied repeatedly until
    none of them produces a new match, a fresh model is then built from the
    resulting config, and weights recorded from the original layers are
    handed back to layers of the new model that share the same name.

    Creates and manipulates internal state on `self`, so it is not suitable
    for multi-threaded use.

    Returns:
      (Keras model after transformation, Updated layer metadata map)
    """

    # Serialized dict representation of the model: its layers, how they are
    # connected and how each one is configured. This is the structure used
    # both to understand the model and to manipulate it.
    #
    # config = {
    #   'input_layers': [ ... ],
    #   'layers': [{
    #       'inbound_nodes': [INPUT CONFIG OF LAYER],
    #       'name': 'LAYER_NAME',
    #       'config': { LAYER_CONFIG }
    #     }, {
    #     ...
    #   }],
    #   'output_layers': [ ... ],
    #   'name': 'MODEL_NAME',
    # }
    self._config = self.model.get_config()

    # Transform -> list of layer names that transform has already matched.
    # A transform should not match+replace the same layer more than once,
    # otherwise we could loop forever.
    self._transform_matched_layers_map = {}
    self._layer_weights_map = {}
    self._layer_names_and_weights_map = {}

    # Record the weights of every original layer, keyed by layer name.
    for source_layer in self.model.layers:
      plain_weights = self._get_keras_layer_weights(source_layer)
      self._layer_weights_map[source_layer.name] = plain_weights
      named_weights = self._get_keras_layer_names_and_weights(source_layer)
      self._layer_names_and_weights_map[source_layer.name] = named_weights

    # Mutable working copy of the metadata, kept current during the run.
    self._layer_metadata_map = copy.deepcopy(self.layer_metadata)

    def next_match(candidate):
      # The pattern and the already-matched layers are re-queried on every
      # call, since earlier replacements may have changed what can match.
      pattern = candidate.pattern()
      already_matched = self._get_matched_layers(candidate)
      return self._find_pattern(pattern, already_matched)

    # Keep sweeping over all transforms for as long as the previous sweep
    # changed something. This allows recursive pattern matching, where the
    # result of one transform may be matched by another.
    #
    # TODO(pulkitb): This leads to infinite loops with poor patterns which may
    # match their replacement. Add counters with limits to fix it.
    changed_in_last_sweep = True
    while changed_in_last_sweep:
      changed_in_last_sweep = False
      for candidate in self.transforms:
        # One transform may hit several places in the model; exhaust them
        # all before moving on to the next transform.
        matched_node = next_match(candidate)
        while matched_node:
          self._store_successful_match(candidate, matched_node)

          # The replacement code gets its own copy of the match, so it is
          # free to modify it.
          scratch_node = copy.deepcopy(matched_node)
          new_node = candidate.replacement(scratch_node)

          # A replacement equal to the match means the same layers with the
          # same config came back. Applying it would change nothing and
          # could make us loop forever, so only real changes are applied.
          is_noop = matched_node == new_node
          if not is_noop:
            changed_in_last_sweep = True
            self._replace(matched_node, new_node)

          matched_node = next_match(candidate)

    # Collect the custom objects of every transform for deserialization.
    custom_objects = {}
    for candidate in self.transforms:
      extra_objects = candidate.custom_objects()
      custom_objects.update(extra_objects)

    # Reconstruct the model from the (possibly modified) config, using the
    # model class that corresponds to the original model's kind.
    model_cls = (
        keras.Model
        if self._is_functional_model(self.model) else keras.Sequential)
    rebuilt_model = model_cls.from_config(self._config, custom_objects)

    # Hand recorded weights to the rebuilt layers, looked up by layer name.
    # The names-and-weights map is consulted only when the plain weights
    # map has nothing (missing or empty entry) for that name.
    for rebuilt_layer in rebuilt_model.layers:
      saved_weights = self._layer_weights_map.get(rebuilt_layer.name)
      if saved_weights:
        self._set_layer_weights(rebuilt_layer, saved_weights)
        continue

      saved_named_weights = self._layer_names_and_weights_map.get(
          rebuilt_layer.name)
      if saved_named_weights:
        self._set_layer_names_and_weights(rebuilt_layer, saved_named_weights)

    return rebuilt_model, copy.deepcopy(self._layer_metadata_map)