in tinynn/graph/tracer.py [0:0]
def insert_after(self, node: TraceNode, module, next_tensors: typing.Optional[typing.List[torch.Tensor]] = None):
    """Insert a module or an existing node after a node in the computation graph

    The new node takes over every consumer of `node`. It receives `node`'s output
    tensors as its inputs and produces one replacement output tensor per input.
    Every node that previously consumed `node` is rewired to consume the new
    node's outputs instead. For `TraceFunction` consumers with a pre-rendered
    argument string, the tensor names in that string are rewritten as well.

    Args:
        node (TraceNode): The node after which the new node is inserted.
        module: Either an existing `TraceNode`, which is wired in as-is without
            being added to `forward_nodes`/`nodes_map`, or any other object,
            which is wrapped in a new `TraceNode` and registered in the graph.
        next_tensors (typing.Optional[typing.List[torch.Tensor]]): Output tensors
            of the new node, matched positionally with `node.next_tensors`.
            A `None` entry, or omitting the list entirely, means "use a clone
            of the corresponding input tensor".
            NOTE(review): pairing uses `zip`, so a list shorter than
            `node.next_tensors` silently drops the trailing outputs.

    Raises:
        AssertionError: If `module` is not a `TraceNode` and `node` is one of
            the graph's output nodes.
    """
    # Create a new node and connects it to the next node/tensors
    if type(module) is not TraceNode:
        new_node = TraceNode(module, cur_graph=self)
        if node in self.input_nodes or node in self.constant_nodes:
            # A node following an input/constant node goes to the front of the forward order
            self.forward_nodes.insert(0, new_node)
        elif node in self.output_nodes:
            log.error('You cannot insert a node after output nodes')
            # Raise explicitly: a bare `assert False` is stripped under `python -O`,
            # which would let the insertion silently continue.
            raise AssertionError('You cannot insert a node after output nodes')
        else:
            idx = self.forward_nodes.index(node)
            self.forward_nodes.insert(idx + 1, new_node)
        self.nodes_map[new_node.unique_name] = new_node
    else:
        new_node = module

    # Flags passed to `tensor_name_from_parts` when rendering the names of tensors
    # produced by `node` (old names) and by `new_node` (replacement names).
    # Both are derived from the producing node's module, so they are computed once.
    constant_module_types = (ConstantNode, torch.nn.quantized.FloatFunctional)
    is_constant_node = type(node.module) in constant_module_types
    is_new_constant_node = type(new_node.module) in constant_module_types

    # The new node consumes the outputs of `node` and inherits all of its consumers
    new_node.prev_nodes.append(node)
    new_node.next_nodes.extend(node.next_nodes)
    if next_tensors is None:
        next_tensors = [None] * len(node.next_tensors)
    for t, new_t in zip(node.next_tensors, next_tensors):
        if new_t is None:
            new_t = t.clone()
        # Record the new node as the producer of the replacement tensor
        self.tensor_pre_node_dict[id(new_t)] = new_node.unique_name
        new_node.prev_tensors.append(t)
        new_node.next_tensors.append(new_t)
        new_node.prev_indices.append(None)

    # Make `node` feed only the new node
    node.next_nodes.clear()
    node.next_nodes.append(new_node)

    # Connect the former consumers of `node` to the new node
    tensor_replace_dict = dict(zip(new_node.prev_tensors, new_node.next_tensors))
    for next_node in new_node.next_nodes:
        for i, n in enumerate(next_node.prev_nodes):
            if n == node:
                next_node.prev_nodes[i] = new_node

        # Make sure the data is writable
        if isinstance(next_node.prev_tensors, tuple):
            next_node.prev_tensors = list(next_node.prev_tensors)

        # Swap old tensors for their replacements and remember which output
        # indices of `node` were referenced, so the rendered names can be updated
        updated_indices = []
        for i, t in enumerate(next_node.prev_tensors):
            if t in tensor_replace_dict:
                next_node.prev_tensors[i] = tensor_replace_dict[t]
                updated_indices.append(next_node.prev_indices[i])

        # Since the function calls are rendered beforehand,
        # we need to change them as well.
        if type(next_node.module) is TraceFunction and next_node.module.args_string is not None:
            for idx in updated_indices:
                old_unique_name = tensor_name_from_parts(node.unique_name, idx, is_constant_node)
                new_unique_name = tensor_name_from_parts(new_node.unique_name, idx, is_new_constant_node)
                next_node.module.replace_tensor_name(old_unique_name, new_unique_name)
            next_node.module.update_args_string()