# tensorboardX/pytorch_graph.py
import logging
import time
from collections import OrderedDict
from .proto.attr_value_pb2 import AttrValue
from .proto.graph_pb2 import GraphDef
from .proto.node_def_pb2 import NodeDef
from .proto.step_stats_pb2 import RunMetadata, StepStats, DeviceStepStats, NodeExecStats, AllocatorMemoryUsed
from .proto.tensor_shape_pb2 import TensorShapeProto
from .proto.versions_pb2 import VersionDef
from .proto_graph import node_proto
# Accessor methods copied verbatim off torch._C operator nodes by NodePy.
methods_OP = ['attributeNames', 'hasMultipleOutputs', 'hasUses', 'inputs',
              'kind', 'outputs', 'outputsSize', 'scopeName']
# Accessor methods copied off input/output value nodes.
methods_IO = ['node', 'offset', 'debugName']  # 'unique' <int> , 'type' <Tensor<class 'torch._C.Type'>>
# Set to True (by parse()) when the graph exposes the legacy uniqueName()
# API instead of debugName() — i.e. an older torch version.
backward_mode = False
class NodeBase(object):
    """Plain record describing one graph node: name, inputs, tensor size,
    op kind and attribute string."""

    def __init__(self,
                 debugName=None,
                 inputs=None,
                 scope=None,
                 tensor_size=None,
                 op_type='UnSpecified',
                 attributes=''):
        self.debugName = debugName
        self.inputs = inputs
        self.tensor_size = tensor_size
        self.kind = op_type
        self.attributes = attributes
        # Only set ``scope`` when one was supplied; downstream code probes
        # for its presence with hasattr().
        if scope is not None:
            self.scope = scope

    def __repr__(self):
        lines = [str(type(self))]
        for attr in dir(self):
            if '__' in attr:
                continue
            value = getattr(self, attr)
            lines.append(attr + ': ' + str(value) + str(type(value)))
        return '\n'.join(lines) + '\n\n'
class NodePy(NodeBase):
def __init__(self, node_cpp, valid_methods):
super(NodePy, self).__init__(node_cpp)
valid_methods = valid_methods[:]
self.inputs = []
global backward_mode
for m in valid_methods:
if m == 'inputs' or m == 'outputs':
list_of_node = list(getattr(node_cpp, m)())
io_unique_names = []
io_tensor_sizes = []
for n in list_of_node:
if backward_mode:
io_unique_names.append(n.uniqueName())
else:
io_unique_names.append(n.debugName())
if n.type().kind() == 'CompleteTensorType':
io_tensor_sizes.append(n.type().sizes())
else:
io_tensor_sizes.append(None)
setattr(self, m, io_unique_names)
setattr(self, m + 'tensor_size', io_tensor_sizes)
else:
if m == 'debugName' and backward_mode:
setattr(self, m, getattr(node_cpp, 'uniqueName')())
else:
setattr(self, m, getattr(node_cpp, m)())
class NodePyIO(NodePy):
    """Input, output or parameter node read from a torch graph."""

    def __init__(self, node_cpp, input_or_output=None):
        super(NodePyIO, self).__init__(node_cpp, methods_IO)
        try:
            size = node_cpp.type().sizes()
        except RuntimeError:
            size = [1, ]  # fail when constant model is used.
        self.tensor_size = size
        # Kind attribute string is purely descriptive and will be shown
        # in detailed information for the node in TensorBoard's graph plugin.
        #
        # NodePyOP nodes get this from their kind() method.
        self.kind = 'Parameter'
        if input_or_output:
            self.input_or_output = input_or_output
            self.kind = 'IO Node'
class NodePyOP(NodePy):
    """Operator node read from a torch graph."""

    def __init__(self, node_cpp):
        super(NodePyOP, self).__init__(node_cpp, methods_OP)
        # Replace single quote which causes strange behavior in TensorBoard
        # TODO: See if we can remove this in the future
        attrs = {k: node_cpp[k] for k in node_cpp.attributeNames()}
        self.attributes = str(attrs).replace("'", ' ')
        self.kind = node_cpp.kind()
class GraphPy(object):
    """Helper class to convert torch.nn.Module to GraphDef proto and visualization
    with TensorBoard.

    GraphDef generation operates in two passes:

    In the first pass, all nodes are read and saved to two lists.
    One list is for input/output nodes (nodes_io), which only have inbound
    or outbound connections, but not both. Another list is for internal
    operator nodes (nodes_op). The first pass also saves all scope name
    appeared in the nodes in scope_name_appeared list for later processing.

    In the second pass, scope names are fully applied to all nodes.
    debugNameToScopedName is a mapping from a node's ID to its fully qualified
    scope name. e.g. Net1/Linear[0]/1. Unfortunately torch.jit doesn't have
    totally correct scope output, so this is nontrivial. The function
    populate_namespace_from_OP_to_IO and find_common_root are used to
    assign scope name to a node based on the connection between nodes
    in a heuristic kind of way. Bookkeeping is done with shallowest_scope_name
    and scope_name_appeared.
    """

    def __init__(self):
        self.nodes_op = []               # NodePyOP nodes, in insertion order
        self.nodes_io = OrderedDict()    # node id -> NodeBase / NodePyIO
        self.unique_name_to_scoped_name = {}
        self.shallowest_scope_name = 'default'
        self.scope_name_appeared = []

    def append(self, x):
        """Record one parsed node; OP nodes also fan their outputs out into
        nodes_io as plain NodeBase records."""
        if isinstance(x, NodePyIO):
            self.nodes_io[x.debugName] = x
        if isinstance(x, NodePyOP):
            self.nodes_op.append(x)
            for node_output, outputSize in zip(x.outputs, x.outputstensor_size):
                self.scope_name_appeared.append(x.scopeName)
                self.nodes_io[node_output] = NodeBase(node_output,
                                                      x.inputs,
                                                      x.scopeName,
                                                      outputSize,
                                                      op_type=x.kind,
                                                      attributes=x.attributes)

    def printall(self):
        """Debug helper: dump every recorded node to stdout."""
        print('all nodes')
        for node in self.nodes_op:
            print(node)
        for key in self.nodes_io:
            print(self.nodes_io[key])

    def find_common_root(self):
        """Take the root module name from the recorded scope names.

        The first path component of the last non-empty scope seen wins.
        """
        for fullscope in self.scope_name_appeared:
            if fullscope:
                self.shallowest_scope_name = fullscope.split('/')[0]

    def populate_namespace_from_OP_to_IO(self):
        """Heuristically assign fully qualified scope names to all nodes."""
        for node in self.nodes_op:
            for input_node_id in node.inputs:
                self.unique_name_to_scoped_name[input_node_id] = node.scopeName + '/' + input_node_id

        for key, node in self.nodes_io.items():
            # The exact type() comparison is deliberate: NodePyIO subclasses
            # NodeBase and must not match this branch.
            if type(node) == NodeBase:
                self.unique_name_to_scoped_name[key] = node.scope + '/' + node.debugName
            if hasattr(node, 'input_or_output'):
                self.unique_name_to_scoped_name[key] = node.input_or_output + '/' + node.debugName
            if hasattr(node, 'scope'):
                if node.scope == '' and self.shallowest_scope_name:
                    self.unique_name_to_scoped_name[node.debugName] = \
                        self.shallowest_scope_name + '/' + node.debugName

        # replace name; fall back to the raw id when no scoped name was
        # recorded for an input (previously this raised KeyError).
        for key, node in self.nodes_io.items():
            self.nodes_io[key].inputs = \
                [self.unique_name_to_scoped_name.get(node_input_id, node_input_id)
                 for node_input_id in node.inputs]
            if node.debugName in self.unique_name_to_scoped_name:
                self.nodes_io[key].debugName = self.unique_name_to_scoped_name[node.debugName]

    def to_proto(self):
        """
        Converts graph representation of GraphPy object to TensorBoard
        required format.

        Returns:
          (nodes, node_stats): a list of NodeDef protos and a matching list
          of NodeExecStats with (placeholder) memory/time figures.
        """
        # TODO: compute correct memory usage and CPU time once
        # PyTorch supports it
        import numpy as np
        nodes = []
        node_stats = []
        for v in self.nodes_io.values():
            nodes.append(node_proto(v.debugName,
                                    input=v.inputs,
                                    outputsize=v.tensor_size,
                                    op=v.kind,
                                    attributes=v.attributes))
            if v.tensor_size and len(v.tensor_size) > 0:  # assume data is float32, only parameter is counted
                # NOTE(review): 1e7 looks like it should be 1e6 for true
                # microseconds; kept as-is because the value is cosmetic
                # in TensorBoard. TODO confirm.
                node_stats.append(
                    NodeExecStats(node_name=v.debugName,
                                  all_start_micros=int(time.time() * 1e7),
                                  all_end_rel_micros=42,
                                  memory=[AllocatorMemoryUsed(allocator_name="cpu",
                                                              total_bytes=int(np.prod(v.tensor_size)) * 4)]))
        return nodes, node_stats
# Node methods that take one argument (not copied by NodePy): 'hasAttribute', 'hasAttributes'.
def parse(graph, args=None, omit_useless_nodes=True):
    """This method parses an optimized PyTorch model graph and produces
    a list of nodes and node stats for eventual conversion to TensorBoard
    protobuf format.

    Args:
      graph (PyTorch module): The model graph to be parsed.
      args (tuple): input tensor[s] for the model; graph inputs beyond
        ``len(args)`` are treated as parameters. May be None (no inputs).
      omit_useless_nodes (boolean): Whether to remove nodes from the graph.

    Returns:
      (nodes, node_stats) as produced by GraphPy.to_proto().
    """
    global backward_mode
    # The default args=None previously crashed len(); treat it as "no inputs".
    n_inputs = len(args) if args is not None else 0
    nodes_py = GraphPy()
    for i, node in enumerate(graph.inputs()):
        if not backward_mode:
            # Probe once whether this torch version exposes debugName();
            # older versions only have uniqueName().
            try:
                node.debugName()
            except Exception:  # narrowed from bare except; any failure means legacy API
                backward_mode = True
        if omit_useless_nodes:
            if len(node.uses()) == 0:  # number of user of the node (= number of outputs/ fanout)
                continue
        if i < n_inputs:
            nodes_py.append(NodePyIO(node, 'input'))
        else:
            nodes_py.append(NodePyIO(node))  # parameter
    for node in graph.nodes():
        nodes_py.append(NodePyOP(node))
    for node in graph.outputs():  # must place last.
        # NOTE(review): this constructed node is discarded rather than
        # appended, so output nodes never reach nodes_py. Looks like a bug,
        # but appending would change the emitted graph -- kept as-is
        # pending confirmation.
        NodePyIO(node, 'output')
    nodes_py.find_common_root()
    nodes_py.populate_namespace_from_OP_to_IO()
    return nodes_py.to_proto()
def graph(model, args, verbose=False, **kwargs):
    """
    This method processes a PyTorch model and produces a `GraphDef` proto
    that can be logged to TensorBoard.

    Args:
      model (PyTorch module): The model to be parsed.
      args (tuple): input tensor[s] for the model.
      verbose (bool): Whether to print out verbose information while
        processing.
      **kwargs: accepted for backward compatibility; currently unused.

    Returns:
      (GraphDef, RunMetadata): the graph proto plus placeholder step stats.

    Raises:
      RuntimeError: re-raised from torch.jit.trace when tracing fails.
    """
    import torch
    with torch.onnx.set_training(model, False):  # TODO: move outside of torch.onnx
        try:
            trace = torch.jit.trace(model, args)
            graph = trace.graph
        except RuntimeError as e:
            print(e)
            print('Error occurs, No graph saved')
            raise  # bare raise keeps the original traceback intact
    # Create an object matching
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/graph.proto
    # The producer version has been reverse engineered from standard
    # TensorBoard logged data.
    if verbose:
        print(graph)
    list_of_nodes, node_stats = parse(graph, args)
    # We are hardcoding that this was run on CPU even though it might have actually
    # run on GPU. Note this is what is shown in TensorBoard and has no bearing
    # on actual execution.
    # TODO: See if we can extract GPU vs CPU information from the PyTorch model
    # and pass it correctly to TensorBoard.
    #
    # Definition of StepStats and DeviceStepStats can be found at
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/plugins/graph/tf_graph_common/test/graph-test.ts
    # and
    # https://github.com/tensorflow/tensorboard/blob/master/tensorboard/compat/proto/step_stats.proto
    stepstats = RunMetadata(step_stats=StepStats(dev_stats=[DeviceStepStats(device="/device:CPU:0",
                                                                            node_stats=node_stats)]))
    return GraphDef(node=list_of_nodes, versions=VersionDef(producer=22)), stepstats