in tensorwatch/model_graph/hiddenlayer/pytorch_builder_trace.py [0:0]
def import_graph(hl_graph, model, args, input_names=None, verbose=False):
    """Trace ``model`` with example input ``args`` and return ``hl_graph``.

    Args:
        hl_graph: HiddenLayer graph object. This function currently returns it
            unmodified (building nodes/edges is not implemented, see TODO).
        model: PyTorch model, forwarded to ``graph``.
        args: Example input. Accepted forms:
            - a ``torch.Tensor``: used as-is;
            - a sequence that contains tensors (multi-input example):
              used as-is;
            - any other sequence (list, tuple, ``torch.Size``, ...): treated
              as a shape, and a tensor of ones of that shape is created;
            - ``None``: defaults to the shape ``[1, 3, 224, 224]``.
        input_names: Currently unused.
        verbose: Forwarded to ``graph``.

    Returns:
        ``hl_graph``, unchanged.
    """
    # TODO: add input names to graph
    if args is None:
        args = [1, 3, 224, 224]  # assume ImageNet default

    # Turn a shape-like sequence into a dummy input tensor. Strings/bytes are
    # sequences but never shapes. bytes/bytearray are listed explicitly instead
    # of collections.abc.ByteString, which is deprecated since Python 3.12 and
    # removed in 3.14. A sequence that already holds tensors is a multi-input
    # example and must not be passed to torch.ones (which would fail on it).
    if (
        not isinstance(args, torch.Tensor)
        and hasattr(args, "__len__")
        and hasattr(args, "__getitem__")
        and not isinstance(args, (str, bytes, bytearray))
        and not any(isinstance(item, torch.Tensor) for item in args)
    ):
        args = torch.ones(args)

    # The result of graph() is not consumed yet; the call is kept so any side
    # effects it has (e.g. verbose output) are preserved.
    # NOTE(review): confirm whether this call is still needed.
    graph(model, args, verbose)

    # TODO: convert the traced graph into HL nodes and edges. An earlier
    # implementation looped over torch graph nodes (op kind, attributes,
    # output shapes) and added nodes/edges via hl_graph.add_node /
    # hl_graph.add_edge_by_id; it was disabled and relied on outdated
    # TorchScript APIs (Value.unique()).
    return hl_graph