def import_graph()

in tensorwatch/model_graph/hiddenlayer/pytorch_builder.py [0:0]


def import_graph(hl_graph, model, args, input_names=None, verbose=False):
    """Trace a PyTorch model and add its nodes and edges to ``hl_graph``.

    The model is traced with ``torch.jit.trace`` inside
    ``torch.onnx.set_training(model, False)``. Every node of the resulting
    TorchScript graph becomes an HL ``Node``. An edge (labelled with the
    producing node's output shape) is added from each node to every node that
    reads one of its outputs.

    Args:
        hl_graph: Graph to populate; must provide ``add_node`` and
            ``add_edge_by_id``. It is modified in place and also returned.
        model: The PyTorch model to trace.
        args: Example input for tracing. One of:
            - a ``torch.Tensor``, used as-is;
            - a sequence of ints, interpreted as a shape and replaced by
              ``torch.ones(args)``;
            - a sequence containing tensors (e.g. a tuple of inputs for a
              multi-input model), passed to the tracer unchanged;
            - ``None``, which defaults to the shape ``[1, 3, 224, 224]``.
        input_names: Currently unused (see TODO below).
        verbose: If true, dump the traced graph via ``dump_pytorch_graph``.

    Returns:
        ``hl_graph``, populated with the traced nodes and edges.

    Raises:
        RuntimeError: If tracing fails; the error is printed and re-raised.
    """
    # TODO: add input names to graph
    args = _to_trace_args(args)

    # Run the PyTorch graph to get a trace and generate a graph from it.
    # NOTE(review): set_training and _optimize_trace are torch.onnx internals
    # (the latter is underscore-private); their availability depends on the
    # installed PyTorch version -- confirm against the version this project pins.
    with torch.onnx.set_training(model, False):
        try:
            trace = torch.jit.trace(model, args)
            torch.onnx._optimize_trace(trace)
            torch_graph = trace.graph
        except RuntimeError as e:
            print(e)
            print('Error occured when creating jit trace for model.')
            raise

    # Dump list of nodes (DEBUG only)
    if verbose:
        dump_pytorch_graph(torch_graph)

    nodes = list(torch_graph.nodes())

    # Index once: value id -> positions of the nodes that read that value.
    # This replaces an all-pairs scan over nodes when building edges.
    consumers = {}
    for idx, node in enumerate(nodes):
        for value_id in {inp.unique() for inp in node.inputs()}:
            consumers.setdefault(value_id, []).append(idx)

    # Loop through nodes and build HL graph
    for torch_node in nodes:
        op = torch_node.kind()
        params = {k: torch_node[k] for k in torch_node.attributeNames()}
        # TODO: inputs = [i.unique() for i in node.inputs()]
        outputs = [o.unique() for o in torch_node.outputs()]
        shape = get_shape(torch_node)

        hl_graph.add_node(Node(uid=pytorch_id(torch_node), name=None, op=op,
                               output_shape=shape, params=params))

        # One edge per consuming node, even if it reads several outputs.
        # Sorting by position keeps the edge order of a scan in node order.
        targets = sorted({j for o in outputs for j in consumers.get(o, ())})
        for j in targets:
            hl_graph.add_edge_by_id(pytorch_id(torch_node),
                                    pytorch_id(nodes[j]), shape)
    return hl_graph


def _to_trace_args(args):
    """Normalize ``import_graph``'s ``args`` into an input for torch.jit.trace.

    ``None`` becomes ``torch.ones([1, 3, 224, 224])``. A non-string sequence
    with no tensors in it is treated as a shape and replaced by
    ``torch.ones(args)``. Anything else (a tensor, or a sequence holding
    tensors) is returned unchanged.
    """
    if args is None:
        args = [1, 3, 224, 224]  # assume ImageNet default

    if isinstance(args, torch.Tensor):
        return args

    # Plain builtins instead of collections.abc.ByteString, which is
    # deprecated (3.12) and removed (3.14).
    is_sequence = (hasattr(args, '__len__') and hasattr(args, '__getitem__')
                   and not isinstance(args, (str, bytes, bytearray)))
    if is_sequence and not any(isinstance(a, torch.Tensor) for a in args):
        return torch.ones(args)
    return args