in tinynn/graph/tracer.py [0:0]
def _flatten_node_inputs(input_tensors):
    """Return ``input_tensors`` with one level of list/tuple nesting flattened."""
    flat = []
    for t in input_tensors:
        if isinstance(t, (list, tuple)):
            flat.extend(t)
        else:
            flat.append(t)
    return flat


def _assert_node_output_type(node, value, position):
    """Assert that the traced output ``value`` of ``node`` has a supported type.

    ``position`` is the location shown in the error message, e.g. ``'#0'`` for a
    top-level output or ``'[1][2]'`` for an element of a nested output.
    """
    assert type(value) in (torch.dtype, torch.device, torch.Size, torch.Tensor, torch.nn.Parameter), (
        f'Output {position} of {node.unique_name}({node.type()}) should be one of the following type '
        f' [torch.dtype, torch.device, torch.Size, torch.Tensor, torch.nn.Parameter], but got {type(value)}'
    )


def _share_traced_parameter_data(graph, param, tensors, index):
    """Replace ``tensors[index]`` (which holds ``param``) with a graph-wide ``.data`` alias.

    The first time a given Parameter is seen, ``param.data`` is stored and registered in
    ``graph.tensor_parameter_dict`` through a weak reference; later sightings of the same
    Parameter reuse that exact tensor object. If the weak reference has died (its referent
    was garbage collected), a fresh alias is created and re-registered instead of storing
    ``None``.
    """
    key = id(param)
    ref = graph.tensor_parameter_dict.get(key)
    shared = ref() if ref is not None else None
    if shared is None:
        shared = param.data
        graph.tensor_parameter_dict[key] = weakref.ref(shared)
    tensors[index] = shared


def add_forward_node(node: TraceNode, input_tensors, output_tensors):
    """Adds a forward node to the current computation graph.

    Args:
        node: The node to register. Its ``prev_tensors``, ``next_tensors``,
            ``prev_nodes`` and ``prev_indices`` lists are extended in place.
        input_tensors: A single input or a list/tuple of inputs. Nested
            lists/tuples are flattened by one level.
        output_tensors: A single output or a list/tuple of outputs. When a
            list/tuple is given, each output's position is recorded in the
            graph's ``tensor_pre_index_dict``.

    Raises:
        AssertionError: If ``node`` is None or an input/output has an
            unsupported type.
    """
    assert node is not None
    graph = current_graph()

    if not isinstance(input_tensors, (list, tuple)):
        input_tensors = [input_tensors]

    # Output positions are only meaningful when the op returns several values.
    need_idx = True
    if not isinstance(output_tensors, (list, tuple)):
        output_tensors = [output_tensors]
        need_idx = False

    flatten_inputs = _flatten_node_inputs(input_tensors)

    # Remember where this call's entries start, so the index-based replacements
    # below target the right slots even if the lists were not empty beforehand.
    prev_base = len(node.prev_tensors)
    node.prev_tensors.extend(flatten_inputs)
    next_base = len(node.next_tensors)
    node.next_tensors.extend(output_tensors)

    for i, t in enumerate(flatten_inputs):
        assert type(t) in (
            torch.dtype,
            torch.device,
            torch.Size,
            torch.Tensor,
            torch.nn.Parameter,
            torch.nn.quantized.FloatFunctional,
        ) or isinstance(t, torch.nn.Module), (
            f'Input #{i} of {node.unique_name}({node.type()}) should be one of the following type '
            ' [torch.dtype, torch.device, torch.Size, torch.Tensor,'
            f' torch.nn.Parameter,torch.nn.quantized.FloatFunctional, torch.nn.Module], but got {type(t)}'
        )
        constant_handler(t, node.unique_name, node.full_name())
        pre_node_name = graph.tensor_pre_node_dict[id(t)]
        node.prev_nodes.append(graph.nodes_map[pre_node_name])
        if id(t) in graph.tensor_pre_index_dict:
            pre_node_index = graph.tensor_pre_index_dict[id(t)]
            log.debug(f'propagate pre_index tensor {pre_node_name} {pre_node_index}')
            node.prev_indices.append(pre_node_index)
        else:
            node.prev_indices.append(None)
        if isinstance(t, torch.nn.Parameter):
            _share_traced_parameter_data(graph, t, node.prev_tensors, prev_base + i)

    for i, t in enumerate(output_tensors):
        if isinstance(t, (list, tuple)):
            for j, rt in enumerate(t):
                _assert_node_output_type(node, rt, f'[{i}][{j}]')
                graph.tensor_pre_node_dict[id(rt)] = node.unique_name
                if need_idx:
                    log.debug(f'set pre_index tensor {i}, {j}')
                    graph.tensor_pre_index_dict[id(rt)] = [i, j]
        else:
            _assert_node_output_type(node, t, f'#{i}')
            graph.tensor_pre_node_dict[id(t)] = node.unique_name
            if need_idx:
                log.debug(f'set pre_index tensor {i}')
                graph.tensor_pre_index_dict[id(t)] = i
            else:
                # This node now owns the tensor as its single output. Drop any
                # position left by an earlier multi-output producer of the same
                # object (e.g. in-place ops return their input), which would
                # otherwise be propagated to consumers as if it belonged here.
                graph.tensor_pre_index_dict.pop(id(t), None)
            if isinstance(t, torch.nn.Parameter):
                _share_traced_parameter_data(graph, t, node.next_tensors, next_base + i)

    graph.forward_nodes.append(node)
    graph.nodes_map[node.unique_name] = node