def add_forward_node()

in tinynn/graph/tracer.py


def add_forward_node(node: TraceNode, input_tensors, output_tensors):
    """Adds a forward node to the current computation graph"""
    assert node is not None
    if not isinstance(input_tensors, (list, tuple)):
        input_tensors = [input_tensors]

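    # Record per-output indices only when the op returns a list/tuple of outputs;
    # a single bare output needs no positional index downstream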
    need_idx = True
    if not isinstance(output_tensors, (list, tuple)):
        output_tensors = [output_tensors]
        need_idx = False

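    # Flatten one level of nested inputs so tensors inside lists/tuples
    # are tracked individually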
    flatten_inputs = []
    for t in input_tensors:
        if isinstance(t, (list, tuple)):
            for rt in t:
                flatten_inputs.append(rt)
        else:
            flatten_inputs.append(t)

    node.prev_tensors.extend(flatten_inputs)
    node.next_tensors.extend(output_tensors)

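    # Link each input tensor back to its producer node; constant_handler first
    # registers constant tensors that have no producer entry yet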
    for i, t in enumerate(flatten_inputs):
        assert type(t) in (
            torch.dtype,
            torch.device,
            torch.Size,
            torch.Tensor,
            torch.nn.Parameter,
            torch.nn.quantized.FloatFunctional,
        ) or isinstance(t, torch.nn.Module), (
            f'Input #{i} of {node.unique_name}({node.type()}) should be one of the following types:'
            ' [torch.dtype, torch.device, torch.Size, torch.Tensor, torch.nn.Parameter,'
            f' torch.nn.quantized.FloatFunctional, torch.nn.Module], but got {type(t)}'
        )
        constant_handler(t, node.unique_name, node.full_name())
        pre_node_name = current_graph().tensor_pre_node_dict[id(t)]
        node.prev_nodes.append(current_graph().nodes_map[pre_node_name])
        if id(t) in current_graph().tensor_pre_index_dict:
            pre_node_index = current_graph().tensor_pre_index_dict[id(t)]
            log.debug(f'propagate pre_index tensor {pre_node_name} {pre_node_index}')
            node.prev_indices.append(pre_node_index)
        else:
            node.prev_indices.append(None)
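        # Replace Parameter inputs with their underlying data tensor, caching a
        # weak reference so repeated uses of one Parameter resolve to the same object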
        if isinstance(t, torch.nn.Parameter):
            if id(t) in current_graph().tensor_parameter_dict:
                node.prev_tensors[i] = current_graph().tensor_parameter_dict[id(t)]()
            else:
                node.prev_tensors[i] = node.prev_tensors[i].data
                current_graph().tensor_parameter_dict[id(t)] = weakref.ref(node.prev_tensors[i])

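    # Register this node as the producer of each output tensor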
    for i, t in enumerate(output_tensors):
        if isinstance(t, (list, tuple)):
            for j, rt in enumerate(t):
                assert type(rt) in (torch.dtype, torch.device, torch.Size, torch.Tensor, torch.nn.Parameter), (
                    f'Output [{i}][{j}] of {node.unique_name}({node.type()}) should be one of the following types:'
                    f' [torch.dtype, torch.device, torch.Size, torch.Tensor, torch.nn.Parameter], but got {type(rt)}'
                )
                current_graph().tensor_pre_node_dict[id(rt)] = node.unique_name
                if need_idx:
                    log.debug(f'set pre_index tensor {i}, {j}')
                    current_graph().tensor_pre_index_dict[id(rt)] = [i, j]
        else:
            assert type(t) in (torch.dtype, torch.device, torch.Size, torch.Tensor, torch.nn.Parameter), (
                f'Output #{i} of {node.unique_name}({node.type()}) should be one of the following types:'
                f' [torch.dtype, torch.device, torch.Size, torch.Tensor, torch.nn.Parameter], but got {type(t)}'
            )
            current_graph().tensor_pre_node_dict[id(t)] = node.unique_name
            if need_idx:
                log.debug(f'set pre_index tensor {i}')
                current_graph().tensor_pre_index_dict[id(t)] = i
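            # Same Parameter-to-data-tensor substitution as on the input side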
            if isinstance(t, torch.nn.Parameter):
                if id(t) in current_graph().tensor_parameter_dict:
                    node.next_tensors[i] = current_graph().tensor_parameter_dict[id(t)]()
                else:
                    node.next_tensors[i] = node.next_tensors[i].data
                    current_graph().tensor_parameter_dict[id(t)] = weakref.ref(node.next_tensors[i])

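    # Record the node in execution order and make it addressable by its unique name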
    current_graph().forward_nodes.append(node)
    current_graph().nodes_map[node.unique_name] = node
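
For orientation, here is a minimal sketch of how a tracing wrapper might feed this function. The TraceNode/TraceFunction constructor calls are assumptions for illustration only, not the exact tinynn API:

import torch
from tinynn.graph.tracer import TraceNode, TraceFunction, add_forward_node

def traced_relu(input_tensor):
    # Run the real op, then register it on the current graph; add_forward_node
    # wires up prev/next tensors and producer-node links for us.
    output_tensor = torch.relu(input_tensor)
    node = TraceNode(TraceFunction('torch.relu'))  # hypothetical constructor shape
    add_forward_node(node, input_tensor, output_tensor)
    return output_tensor

Because output_tensor is a bare tensor here, need_idx stays False and no entry is written to tensor_pre_index_dict; consumers of the output will resolve it by node name alone.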