def _match_layer_with_inputs()

in tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py [0:0]


  def _match_layer_with_inputs(self, layer, pattern, is_head_node):
    """Match `pattern` at `layer`, then recursively at `layer`'s inputs.

    Args:
      layer: Layer entry; its name is read from `layer['config']['name']`.
      pattern: Pattern to match at `layer`. Each element of `pattern.inputs`
        is matched, by position, against the corresponding inbound layer.
      is_head_node: Forwarded to `_is_match_supported` for functional models.
        Recursive calls on inbound layers always pass False.

    Returns:
      A `LayerNode` for `layer` whose inputs are the `LayerNode`s of the
      matched inbound layers, or None when `layer` or any required input
      fails to match.
    """

    def make_node(input_nodes):
      # Builds the LayerNode for `layer`, attaching already-matched inputs.
      name = layer['config']['name']
      return LayerNode(layer, self._get_layer_weights(name), input_nodes,
                       self._get_layer_metadata(name),
                       self._get_layer_names_and_weights(name))

    layer_matches = self._match_layer(layer, pattern)
    if not layer_matches:
      return None

    # Functional models additionally gate matches via `_is_match_supported`.
    if self._is_functional_model(self.model):
      if not self._is_match_supported(layer, is_head_node):
        return None

    pattern_arity = len(pattern.inputs)
    if pattern_arity == 0:
      # The pattern ends here, so nothing upstream remains to be matched.
      return make_node([])

    # NOTE: if a single layer emits several tensors and more than one of them
    # feeds `layer`, this lookup does not account for it; ignored for now.
    inbound_layers = self._get_layers(self._get_input_layer_names(layer))

    # Only an exact input-count match is accepted. For example, matching both
    # "2 Convs -> Concat" and "3 Convs -> Concat" requires two patterns.
    if len(inbound_layers) != pattern_arity:
      return None

    # Inbound layers are paired with sub-patterns purely by position, so a
    # layer whose inbound order differs from the pattern's input order will
    # not match.
    # TODO(pulkitb): Fix by checking all permutations.
    matched_inputs = []
    for inbound_layer, sub_pattern in zip(inbound_layers, pattern.inputs):
      sub_match = self._match_layer_with_inputs(
          inbound_layer, sub_pattern, is_head_node=False)
      if not sub_match:
        return None
      matched_inputs.append(sub_match)

    return make_node(matched_inputs)