def insert_after()

in tinynn/graph/tracer.py [0:0]


    def insert_after(self, node: TraceNode, module, next_tensors: typing.Optional[typing.List[torch.Tensor]] = None):
        """Insert a module or an existing node after a node in the computation graph

        The inserted node takes over every consumer of ``node``. It receives all of
        ``node``'s output tensors as inputs and emits one new tensor per output. Every
        downstream node is then rewired to read from it instead of ``node``. This
        includes the pre-rendered argument strings of ``TraceFunction`` nodes.

        Args:
            node (TraceNode): The node after which the new node is inserted
            module: Either an existing ``TraceNode``, which is used as-is and is NOT
                added to ``forward_nodes``/``nodes_map`` here, or any other object,
                which is wrapped in a fresh ``TraceNode`` registered in this graph
            next_tensors (typing.Optional[typing.List[torch.Tensor]], optional): The
                output tensors of the inserted node, paired positionally with
                ``node.next_tensors``. A ``None`` entry (or ``None`` for the whole
                argument) means "use a clone of the corresponding output of ``node``".
                Defaults to None.

        Raises:
            AssertionError: If a new node would be inserted after one of the output nodes
        """
        # Create (or reuse) the node to insert and schedule it in the forward pass
        if type(module) is not TraceNode:
            new_node = TraceNode(module, cur_graph=self)
            if node in self.input_nodes or node in self.constant_nodes:
                # Inputs/constants are presumably not part of ``forward_nodes``,
                # so the new node is scheduled as the very first step
                self.forward_nodes.insert(0, new_node)
            elif node in self.output_nodes:
                log.error('You cannot insert a node after output nodes')
                # Raise explicitly: a bare ``assert`` is stripped under ``python -O``.
                # Execution would then continue with a half-registered node.
                raise AssertionError('You cannot insert a node after output nodes')
            else:
                idx = self.forward_nodes.index(node)
                self.forward_nodes.insert(idx + 1, new_node)
            self.nodes_map[new_node.unique_name] = new_node
        else:
            new_node = module

        # Each flag describes the *producer* of a tensor and is passed to
        # ``tensor_name_from_parts`` together with that producer's name when the
        # rendered calls are patched below.
        # The old tensors come from ``node``; the replacement tensors come from ``new_node``.
        # Use ``new_node.module``, not a downstream ``TraceNode`` wrapper (which is never
        # a ConstantNode).
        constant_types = (ConstantNode, torch.nn.quantized.FloatFunctional)
        is_constant_node = type(node.module) in constant_types
        is_new_constant_node = type(new_node.module) in constant_types

        new_node.prev_nodes.append(node)
        new_node.next_nodes.extend(node.next_nodes)
        if next_tensors is None:
            next_tensors = [None] * len(node.next_tensors)
        # NOTE(review): ``zip`` silently truncates if ``next_tensors`` is shorter than
        # ``node.next_tensors``; callers are presumably expected to pass matching lengths
        for t, new_t in zip(node.next_tensors, next_tensors):
            if new_t is None:
                new_t = t.clone()
            self.tensor_pre_node_dict[id(new_t)] = new_node.unique_name
            new_node.prev_tensors.append(t)
            new_node.next_tensors.append(new_t)
            new_node.prev_indices.append(None)

        # ``node`` now feeds only the inserted node
        node.next_nodes.clear()
        node.next_nodes.append(new_node)

        # Connect the former consumers of ``node`` to the new node
        tensor_replace_dict = dict(zip(new_node.prev_tensors, new_node.next_tensors))
        for next_node in new_node.next_nodes:
            for i, n in enumerate(next_node.prev_nodes):
                if n == node:
                    next_node.prev_nodes[i] = new_node

            # Make sure the data is writable
            if isinstance(next_node.prev_tensors, tuple):
                next_node.prev_tensors = list(next_node.prev_tensors)

            # Swap every input tensor that came from ``node`` for its counterpart produced
            # by ``new_node``, and record its ``prev_indices`` entry (presumably the index
            # into the producer's outputs) so the rendered call can be patched
            updated_indices = []
            for i, t in enumerate(next_node.prev_tensors):
                if t in tensor_replace_dict:
                    next_node.prev_tensors[i] = tensor_replace_dict[t]
                    updated_indices.append(next_node.prev_indices[i])

            # Since the function calls are rendered beforehand,
            # we need to change them as well.
            if type(next_node.module) is TraceFunction and next_node.module.args_string is not None:
                for idx in updated_indices:
                    old_unique_name = tensor_name_from_parts(node.unique_name, idx, is_constant_node)
                    new_unique_name = tensor_name_from_parts(new_node.unique_name, idx, is_new_constant_node)
                    next_node.module.replace_tensor_name(old_unique_name, new_unique_name)
                    next_node.module.update_args_string()