def _replace_functional()

in tensorflow_model_optimization/python/core/quantization/keras/graph_transformations/model_transformer.py [0:0]


  def _replace_functional(self, match_layer_node, replacement_layer_node):
    """Replaces the matched sub-tree of a functional model config.

    The rewrite runs in five ordered steps:
      1. Re-point consumers of the matched head layer (and model outputs that
         reference it) at the replacement head layer.
      2. Build `inbound_nodes` for the non-leaf replacement layers.
      3. Hand the `inbound_nodes` of the matched leaf layers to the replacement
         leaf layers, pairing them by position.
      4. Remove the matched layers.
      5. Append the replacement layers to `self._config['layers']` and record
         their weights / metadata.

    Args:
      match_layer_node: Head node of the matched sub-tree. It is read through
        its `layer` config dict here and is also passed to helper methods.
      replacement_layer_node: Head node of the replacement sub-tree. Each node
        in it is read through `layer`, `input_layers`, `weights`,
        `names_and_weights` and `metadata`.

    Raises:
      RuntimeError: If the matched and the replacement sub-trees do not have
        the same number of leaf layers.
    """

    # Step 1: redirect everything that consumed the matched head layer.
    #
    # Baked-in assumptions: the head layer has exactly one inbound and one
    # outbound node, and the replacement yields the same number and shape of
    # tensors as the layer it stands in for.
    consumers = self._get_consuming_layers(match_layer_node.layer)
    match_name = match_layer_node.layer['config']['name']
    replacement_name = replacement_layer_node.layer['config']['name']

    def _rename_in_connection(connection):
      """Swaps the matched head name for the replacement name, in place."""
      if connection[0] == match_name:
        connection[0] = replacement_name
      # Element 3 is keyed storage (presumably the call kwargs -- confirm).
      # List-valued entries are handled as nested connection entries whose
      # first element is a layer name.
      extra_args = connection[3]
      for arg_key in extra_args:
        arg_value = extra_args[arg_key]
        if isinstance(arg_value, list) and arg_value[0] == match_name:
          arg_value[0] = replacement_name

    for consumer in consumers:
      for inbound_node in self._inbound_node_generator(consumer):
        # An inbound node may arrive as a dict; only its values are
        # connection entries.
        connections = (
            inbound_node.values()
            if isinstance(inbound_node, dict) else inbound_node)
        for connection in connections:
          _rename_in_connection(connection)

    for output_ref in self._get_output_consumers(match_layer_node.layer):
      output_ref[0] = replacement_name

    # Step 2: wire the replacement layers to one another by creating their
    # inbound nodes.

    def _wire_replacement_inputs(node):
      """Recursively builds `inbound_nodes` for every non-leaf node."""
      if node.input_layers:
        node.layer['inbound_nodes'] = [[]]
        for child in node.input_layers:
          # An inbound node can reference specific tensors from several
          # inbound connections. The entry built here assumes:
          # - each replacement layer has exactly 1 inbound node;
          # - the connected previous layer outputs exactly 1 tensor;
          # - call() receives no extra args during construction.
          # That holds for almost every case of interest.
          child_name = child.layer['config']['name']
          node.layer['inbound_nodes'][0].append([child_name, 0, 0, {}])

          _wire_replacement_inputs(child)

    _wire_replacement_inputs(replacement_layer_node)

    # Step 3: the replacement leaves take over the inbound nodes of the
    # matched leaves.

    matched_leaves = self._get_leaf_layers(match_layer_node)
    # Collected up front, before any replacement leaf is modified.
    matched_leaf_inbounds = [leaf['inbound_nodes'] for leaf in matched_leaves]

    replacement_leaves = self._get_leaf_layers(replacement_layer_node)

    # In general the two patterns could differ in their number of leaves and
    # in how those leaves consume their inputs; reconciling that would mean
    # recreating the new layers with the original input structure. The
    # existing transforms are assumed to line up one-to-one, so only the leaf
    # count is verified.
    if len(matched_leaves) != len(replacement_leaves):
      raise RuntimeError('Different size of leaf layers not supported yet.')

    for leaf_inbound_nodes, replacement_leaf in zip(matched_leaf_inbounds,
                                                    replacement_leaves):
      replacement_leaf['inbound_nodes'] = leaf_inbound_nodes

    # Step 4: drop the matched layers.
    removed_names = self._get_layer_names(match_layer_node)
    removed_layers = self._get_layers(removed_names)

    self._remove_layers(removed_layers, removed_names)

    # Step 5: register the replacement layers.
    def _iter_replacement_nodes(node):
      """Yields `node`, then the nodes of each input sub-tree, depth-first."""
      yield node
      for child in node.input_layers:
        for descendant in _iter_replacement_nodes(child):
          yield descendant

    for node in _iter_replacement_nodes(replacement_layer_node):
      self._config['layers'].append(node.layer)
      name = node.layer['config']['name']
      # TODO(b/184603494): Remove weight map structure from model_transformer.
      if node.weights:
        self._layer_weights_map[name] = node.weights
      if node.names_and_weights:
        self._layer_names_and_weights_map[name] = node.names_and_weights
      if node.metadata:
        self._layer_metadata_map[name] = node.metadata
      # A missing or empty candidate collection is left untouched.
      if self.candidate_layers:
        self.candidate_layers.add(name)