in tensorboardX/tensorboardX/onnx_graph.py [0:0]
def parse(graph):
    """Convert an ONNX graph into a TensorBoard ``GraphDef``.

    Every entry of ``graph.input`` and ``graph.output`` becomes a node with op
    ``'Variable'``, no inputs, and ``dtype``/``shape`` attributes read from its
    ``type.tensor_type``.  Every entry of ``graph.node`` becomes a node named
    after its first output, with op ``node.op_type``, inputs copied from
    ``node.input``, and all of its ONNX attributes flattened into a single
    ``'parameters'`` string attribute.

    Args:
        graph: presumably an ONNX ``GraphProto`` (e.g. ``model.graph``); it
            must expose ``input``, ``output`` and ``node`` as used below.

    Returns:
        GraphDef: the converted nodes, stamped with ``producer=22``.

    Raises:
        IndexError: if a computation node has no outputs (its name is taken
            from ``node.output[0]``).
    """
    import itertools

    nodes = []

    # Graph-level inputs and outputs -> placeholder 'Variable' nodes.
    for value_info in itertools.chain(graph.input, graph.output):
        tensor_type = value_info.type.tensor_type
        shapeproto = TensorShapeProto(
            dim=[TensorShapeProto.Dim(size=d.dim_value) for d in tensor_type.shape.dim])
        nodes.append(NodeDef(
            name=value_info.name.encode(encoding='utf_8'),
            op='Variable',
            input=[],
            attr={
                # NOTE(review): elem_type is the ONNX tensor element-type code,
                # stored unchanged in a TensorFlow DataType field; the two enums
                # are not numbered identically, so the shown dtype may be off.
                'dtype': AttrValue(type=tensor_type.elem_type),
                'shape': AttrValue(shape=shapeproto),
            },
        ))

    # Computation nodes. Each is named after its first output tensor; its
    # ``input`` list is copied verbatim, so edges only resolve where those
    # names match a node name emitted here.
    for node in graph.node:
        # Each attribute renders as its set field values joined by ' = ';
        # attributes are then joined by ', ' into one string.
        params = ', '.join(
            ' = '.join(str(value) for _, value in attribute.ListFields())
            for attribute in node.attribute)
        nodes.append(NodeDef(
            name=node.output[0].encode(encoding='utf_8'),
            op=node.op_type,
            input=node.input,
            attr={'parameters': AttrValue(s=params.encode(encoding='utf_8'))},
        ))

    return GraphDef(node=nodes, versions=VersionDef(producer=22))