in tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py [0:0]
def transform(self):
  """Applies all configured transforms to the Keras model until none match.

  This is the main entry point of the transformer. Each transform is applied
  to every match of its pattern, and the full list of transforms is re-run
  for as long as any of them changed the model, so a replacement made by one
  transform can expose a new match for another.

  Creates and mutates internal state on `self`; not suitable for
  multi-threaded use.

  Returns:
    A tuple `(transformed_model, layer_metadata_map)`: the Keras model rebuilt
    from the transformed config, and a deep copy of the updated layer
    metadata map.
  """
  # Serialized dict form of the model: every layer, its configuration and its
  # connections. All pattern matching and rewriting operates on this dict.
  #
  # config = {
  #   'input_layers': [ ... ],
  #   'layers': [{
  #     'inbound_nodes': [INPUT CONFIG OF LAYER],
  #     'name': 'LAYER_NAME',
  #     'config': { LAYER_CONFIG }
  #   }, {
  #     ...
  #   }],
  #   'output_layers': [ ... ],
  #   'name': 'MODEL_NAME',
  # }
  self._config = self.model.get_config()

  # Transform -> names of the layers it has already matched. A transform must
  # not match+replace the same layer twice, otherwise it could loop forever.
  self._transform_matched_layers_map = {}
  self._layer_weights_map = {}
  self._layer_names_and_weights_map = {}

  # Snapshot the weights of the original layers, keyed by layer name, so they
  # can be loaded back into the rebuilt model further down.
  for source_layer in self.model.layers:
    layer_name = source_layer.name
    self._layer_weights_map[layer_name] = self._get_keras_layer_weights(
        source_layer)
    self._layer_names_and_weights_map[layer_name] = (
        self._get_keras_layer_names_and_weights(source_layer))

  # Mutable working copy of the metadata, updated as transforms are applied.
  self._layer_metadata_map = copy.deepcopy(self.layer_metadata)

  def _apply_until_exhausted(xform):
    """Replaces every match of `xform`; returns True if anything changed."""
    changed_model = False
    while True:
      matched_node = self._find_pattern(
          xform.pattern(), self._get_matched_layers(xform))
      if not matched_node:
        return changed_model

      # Record the match before the no-op check below; presumably this is
      # what keeps `_get_matched_layers` from offering these layers to the
      # same transform again.
      self._store_successful_match(xform, matched_node)

      # The replacement gets its own copy so it may mutate it freely.
      new_node = xform.replacement(copy.deepcopy(matched_node))

      # A replacement identical to its match is not a real change; counting
      # it as one could keep the outer loop spinning forever.
      if matched_node == new_node:
        continue

      changed_model = True
      self._replace(matched_node, new_node)

  # Re-run all transforms until a full pass makes no change, so that one
  # transform's output can be matched by another (or by itself).
  #
  # TODO(pulkitb): This leads to infinite loops with poor patterns which may
  # match their replacement. Add counters with limits to fix it.
  keep_going = True
  while keep_going:
    keep_going = False
    for xform in self.transforms:
      # Called unconditionally: every transform runs on every pass.
      if _apply_until_exhausted(xform):
        keep_going = True

  merged_custom_objects = {}
  for xform in self.transforms:
    merged_custom_objects.update(xform.custom_objects())

  # Rebuild as a functional Model or a Sequential, matching the original.
  model_class = (
      keras.Model if self._is_functional_model(self.model)
      else keras.Sequential)
  transformed_model = model_class.from_config(self._config,
                                              merged_custom_objects)

  # Restore weights by layer name. Fall back to the names-and-weights snapshot
  # when nothing (or only a falsy value) was stored in the plain weights map.
  for target_layer in transformed_model.layers:
    plain_weights = self._layer_weights_map.get(target_layer.name)
    if plain_weights:
      self._set_layer_weights(target_layer, plain_weights)
      continue
    named_weights = self._layer_names_and_weights_map.get(target_layer.name)
    if named_weights:
      self._set_layer_names_and_weights(target_layer, named_weights)

  return transformed_model, copy.deepcopy(self._layer_metadata_map)