in tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py [0:0]
def _replace_functional(self, match_layer_node, replacement_layer_node):
  """Functional model: swap the matched layer tree for the replacement tree.

  The layer config dicts reachable from `self._config` are edited in place,
  in five steps:
    1. Connections that name the matched head layer (in its consuming layers
       and in the entries returned by `self._get_output_consumers`) are
       renamed to the replacement head layer.
    2. Inbound nodes are created to wire the replacement layers together.
    3. The leaf layers of the replacement tree take over the inbound nodes of
       the leaf layers of the matched tree.
    4. The matched layers are removed.
    5. The replacement layers are appended, and their weights, named weights
       and metadata are recorded under the layer name.

  Args:
    match_layer_node: Head node of the matched sub-tree. `.layer` is read as
      a layer config dict (with `['config']['name']`) and `.input_layers` as
      the child nodes.
    replacement_layer_node: Head node of the replacement sub-tree. In
      addition to `.layer` and `.input_layers`, its `.weights`,
      `.names_and_weights` and `.metadata` are read for every node.

  Raises:
    RuntimeError: If the matched and replacement trees do not have the same
      number of leaf layers.
  """
  # 1. Point all consumers of the head of the matching sub-tree to the head
  # replacement layer.
  #
  # There are some assumptions baked in. The head layer only has 1 inbound and
  # outbound node. The resulting number and shape of tensors from the
  # replaced layer should equal the original layer.
  consuming_layers = self._get_consuming_layers(match_layer_node.layer)
  match_name = match_layer_node.layer['config']['name']
  replacement_name = replacement_layer_node.layer['config']['name']

  def _retarget_connection_info(connection_info):
    """Renames references to the matched head layer inside one connection."""
    if connection_info[0] == match_name:
      connection_info[0] = replacement_name
    # The fourth entry is a mapping; list values whose first element is the
    # matched name are renamed too (presumably tensors passed as call
    # kwargs -- verify against the Keras config format). Empty lists carry
    # no name and are skipped rather than indexed.
    for value in connection_info[3].values():
      if isinstance(value, list) and value and value[0] == match_name:
        value[0] = replacement_name

  for consumer in consuming_layers:
    for inbound_node in self._inbound_node_generator(consumer):
      # An inbound node may be a dict of connections; only its values matter.
      if isinstance(inbound_node, dict):
        connections = inbound_node.values()
      else:
        connections = inbound_node
      for connection_info in connections:
        _retarget_connection_info(connection_info)

  for output_consumer in self._get_output_consumers(match_layer_node.layer):
    output_consumer[0] = replacement_name

  # 2. Create inbound nodes for the replacement layers. This connects all
  # the replacement layers.
  def _assign_inbounds_for_replacement(layer_node):
    """Recursively wires each replacement layer to its input layers."""
    if not layer_node.input_layers:
      # Leaves receive their inbound nodes in step 3.
      return

    layer_node.layer['inbound_nodes'] = [[]]
    for input_layer in layer_node.input_layers:
      # inbound_nodes can be specific tensors from multiple inbound
      # connections. We make the following assumptions.
      # - Only 1 inbound node for each replacement layer.
      # - Only 1 tensor output from the previous layer which is connected.
      # - call() method during construction does not have any args.
      # These are reasonable assumptions for almost all cases we are
      # interested in.
      layer_node.layer['inbound_nodes'][0].append(
          [input_layer.layer['config']['name'], 0, 0, {}])
      _assign_inbounds_for_replacement(input_layer)

  _assign_inbounds_for_replacement(replacement_layer_node)

  # 3. Connect the leaves of the replacement_layers to the inbound_nodes of
  # the leaves in the original layer.
  original_leaf_layers = self._get_leaf_layers(match_layer_node)
  # Snapshot the inbound nodes before any replacement leaf is reassigned.
  original_inbound_nodes = [
      layer['inbound_nodes'] for layer in original_leaf_layers
  ]
  replacement_leaf_layers = self._get_leaf_layers(replacement_layer_node)

  # The original pattern and the replacement pattern can potentially have
  # different number of leaf nodes and differences in how they consume these
  # input layers. Matching them will require sophisticated hackery to recreate
  # the new layers with the original input structure.
  # Given our existing transforms, we can assume they match.
  #
  # NOTE(review): this check runs after steps 1 and 2 have already edited the
  # config in place, so a failure here leaves the config partially rewired.
  if len(original_leaf_layers) != len(replacement_leaf_layers):
    raise RuntimeError('Different size of leaf layers not supported yet.')

  for leaf_inbound_nodes, replacement_leaf_layer in zip(
      original_inbound_nodes, replacement_leaf_layers):
    replacement_leaf_layer['inbound_nodes'] = leaf_inbound_nodes

  # 4. Remove the original matched layers
  layers_to_remove_names = self._get_layer_names(match_layer_node)
  layers_to_remove = self._get_layers(layers_to_remove_names)
  self._remove_layers(layers_to_remove, layers_to_remove_names)

  # 5. Add in the new layers.
  def _add_replacement_layer(layer_node):
    """Recursively add new layers."""
    self._config['layers'].append(layer_node.layer)
    layer_name = layer_node.layer['config']['name']
    # TODO(b/184603494): Remove weight map structure from model_transformer.
    if layer_node.weights:
      self._layer_weights_map[layer_name] = layer_node.weights
    if layer_node.names_and_weights:
      self._layer_names_and_weights_map[
          layer_name] = layer_node.names_and_weights
    if layer_node.metadata:
      self._layer_metadata_map[layer_name] = layer_node.metadata
    # Only an existing, non-empty candidate set is extended; a falsy value
    # (None or empty) is left untouched.
    if self.candidate_layers:
      self.candidate_layers.add(layer_name)

    for input_layer in layer_node.input_layers:
      _add_replacement_layer(input_layer)

  _add_replacement_layer(replacement_layer_node)