in tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py [0:0]
def _match_layer_with_inputs(self, layer, pattern, is_head_node):
  """Match `pattern` at this layer, and continue to match at its inputs.

  Recursively matches the pattern tree against the model graph: the root of
  `pattern` is compared to `layer`, and each of `pattern.inputs` is matched
  against the corresponding inbound layer.

  Args:
    layer: Serialized layer config dict (as found in the model config) to
      match against the root of `pattern`.
    pattern: `LayerPattern` to match at this layer.
    is_head_node: Whether `layer` is the head (outermost) node of the
      overall pattern being matched.

  Returns:
    A `LayerNode` tree rooted at `layer` describing the match, or None if
    the pattern does not match here.
  """
  if not self._match_layer(layer, pattern):
    return None

  if self._is_functional_model(
      self.model) and not self._is_match_supported(layer, is_head_node):
    return None

  # The layer's name is needed for every weights/metadata lookup below;
  # fetch it once.
  layer_name = layer['config']['name']

  if not pattern.inputs:
    # Leaf layer in pattern.
    return LayerNode(
        layer, self._get_layer_weights(layer_name), [],
        self._get_layer_metadata(layer_name),
        self._get_layer_names_and_weights(layer_name))

  # There is a possible edge case where a single layer may output multiple
  # tensors and multiple tensors from that layer may be used by the
  # connection. Ignoring those for now.
  input_layer_names = self._get_input_layer_names(layer)
  input_layers = self._get_layers(input_layer_names)

  if len(input_layers) != len(pattern.inputs):
    # Number of inputs this layer takes is different from the number of
    # inputs in the pattern.
    #
    # This path currently has the limitation that it requires an exact number
    # of inputs to match a pattern. For example, if a user wants to match
    # 2 Convs -> Concat and 3 Convs -> Concat, they would need to write
    # 2 different patterns.
    return None

  # Inbound layers can have different order from the list of input patterns.
  # TODO(pulkitb): Fix by checking all permutations.
  input_match_layer_nodes = []
  for input_layer, pattern_ in zip(input_layers, pattern.inputs):
    match_layer_node = self._match_layer_with_inputs(
        input_layer, pattern_, is_head_node=False)
    if not match_layer_node:
      return None
    input_match_layer_nodes.append(match_layer_node)

  return LayerNode(layer, self._get_layer_weights(layer_name),
                   input_match_layer_nodes,
                   self._get_layer_metadata(layer_name),
                   self._get_layer_names_and_weights(layer_name))