in tensorwatch/model_graph/hiddenlayer/pytorch_builder.py [0:0]
def import_graph(hl_graph, model, args, input_names=None, verbose=False):
    """Trace a PyTorch model and add its nodes and edges to ``hl_graph``.

    Args:
        hl_graph: Graph object to populate; it must provide ``add_node`` and
            ``add_edge_by_id`` (both are called below).
        model: The PyTorch module to trace with ``torch.jit.trace``.
        args: Example input for the trace. May be a ``torch.Tensor``; ``None``
            (replaced by the shape ``[1, 3, 224, 224]``); or any sequence that
            is not a string/bytes object, which is treated as a shape and
            turned into ``torch.ones(args)``.
        input_names: Currently unused (see TODO below); kept for API
            compatibility.
        verbose: If true, print the traced graph via ``dump_pytorch_graph``.

    Returns:
        The same ``hl_graph`` object, after nodes and edges have been added.

    Raises:
        RuntimeError: Re-raised if tracing fails; the error and a short
            message are printed first.
    """
    # TODO: add input names to graph
    if args is None:
        args = [1, 3, 224, 224]  # assume ImageNet default

    # If args is array-like (but not a Tensor or a string/bytes object),
    # treat it as a shape and build a dummy input tensor of ones.
    # (str, bytes, bytearray) covers exactly what collections.abc.ByteString
    # matched, without that ABC (deprecated in Python 3.12, removed in 3.14).
    if (not isinstance(args, torch.Tensor)
            and hasattr(args, "__len__") and hasattr(args, "__getitem__")
            and not isinstance(args, (str, bytes, bytearray))):
        args = torch.ones(args)

    # Run the PyTorch graph to get a trace and generate a graph from it.
    # NOTE(review): torch.onnx.set_training and torch.onnx._optimize_trace are
    # legacy/private APIs that newer PyTorch releases may not provide —
    # confirm against the torch version this project pins.
    with torch.onnx.set_training(model, False):
        try:
            trace = torch.jit.trace(model, args)
            torch.onnx._optimize_trace(trace)
            torch_graph = trace.graph
        except RuntimeError as e:
            print(e)
            print('Error occured when creating jit trace for model.')
            raise  # bare raise re-raises with the original traceback

    # Dump list of nodes (DEBUG only)
    if verbose:
        dump_pytorch_graph(torch_graph)

    nodes = list(torch_graph.nodes())

    # Index consumers once: value id -> positions (in node order) of the
    # nodes that read that value as an input. This replaces rebuilding two
    # sets for every (producer, consumer) pair, which was O(n^2).
    consumers = {}
    for position, node in enumerate(nodes):
        for value in node.inputs():
            consumers.setdefault(value.unique(), []).append(position)

    # Loop through nodes and build HL graph
    for torch_node in nodes:
        op = torch_node.kind()
        params = {k: torch_node[k] for k in torch_node.attributeNames()}
        # TODO: inputs = [i.unique() for i in node.inputs()]
        outputs = [o.unique() for o in torch_node.outputs()]
        shape = get_shape(torch_node)
        source_id = pytorch_id(torch_node)

        hl_graph.add_node(Node(uid=source_id, name=None, op=op,
                               output_shape=shape, params=params))

        # One edge per consumer node, no matter how many of this node's
        # outputs it reads; sorted so edges are added in node order, as before.
        targets = sorted({pos for out in outputs
                          for pos in consumers.get(out, ())})
        for pos in targets:
            hl_graph.add_edge_by_id(source_id, pytorch_id(nodes[pos]), shape)
    return hl_graph