in smdebug/mxnet/graph.py [0:0]
def _get_nodes_from_symbol(sym):
    """Convert an MXNet symbol into a list of `NodeDef`s for TensorBoard.

    The symbol is serialized with ``sym.tojson()`` and every entry of the
    resulting ``"nodes"`` list becomes one `NodeDef`, in the same order.

    Naming rules applied to the emitted nodes:

    * An operator node (``op != "null"``) named ``conv`` is emitted as
      ``_scoped_name("conv", "conv")``.
    * A data node (``op == "null"``) that is referenced by exactly one
      operator input is emitted under that operator's name, i.e.
      ``_scoped_name(<operator name>, <data node name>)``.
    * Any other data node keeps its plain name. Node index 0 is never
      recorded as belonging to an operator, so it always stays unscoped.

    Raises:
        TypeError: if ``sym`` is not an instance of ``Symbol``.
    """
    if not isinstance(sym, Symbol):
        message = "sym must be an `mxnet.symbol.Symbol`, received type {}"
        raise TypeError(message.format(str(type(sym))))

    graph_nodes = json.loads(sym.tojson())["nodes"]

    # Pass 1: for every node index that appears as an operator input, collect
    # the indices of the operators consuming it (one entry per reference).
    consumers = {}
    for op_index, entry in enumerate(graph_nodes):
        if entry["op"] == "null":  # data nodes consume nothing
            continue
        for ref in entry["inputs"]:
            source = ref[0]
            if source != 0:  # keep the 'data' node out of any op scope
                consumers.setdefault(source, []).append(op_index)

    # Pass 2: emit one NodeDef per node. Data nodes are grouped with the
    # operator they belong to by using the operator name as the scope. For
    # example, a convolution op named 'conv' with parameters 'conv_weight' and
    # 'conv_bias' yields 'conv/conv', 'conv/conv_weight' and 'conv/conv_bias'.
    result = []
    for index, entry in enumerate(graph_nodes):
        name = entry["name"]
        op_type = entry["op"]
        fields = {"op": op_type, "name": name}

        if op_type != "null":
            input_names = []
            for ref in entry["inputs"]:
                source = ref[0]
                producer = graph_nodes[source]
                producer_name = producer["name"]
                if producer["op"] != "null":
                    # Input comes from another operator: use its scoped name.
                    resolved = _scoped_name(producer_name, producer_name)
                elif consumers.get(source) == [index]:
                    # Data node referenced solely (and once) by this operator.
                    resolved = _scoped_name(name, producer_name)
                else:
                    # Shared or unscoped data node, e.g. the 'data' input.
                    resolved = producer_name
                input_names.append(resolved)
            fields["input"] = input_names
            fields["name"] = _scoped_name(name, name)
        else:
            owners = consumers.get(index, ())
            if len(owners) == 1:
                # Data node with a single recorded consumer: scope it there.
                owner_name = graph_nodes[owners[0]]["name"]
                fields["name"] = _scoped_name(owner_name, name)

        if "attrs" in entry:
            # TensorBoard would escape quotation marks, so swap them for spaces.
            flattened = json.dumps(entry["attrs"], sort_keys=True).replace('"', " ")
            fields["attr"] = {"param": AttrValue(s=flattened.encode(encoding="utf-8"))}

        result.append(NodeDef(**fields))

    return result