in ppuda/deepnets1m/graph.py [0:0]
def visualize(self, node_size=50, figname=None, figsize=None, with_labels=False, vis_legend=False, label_offset=0.001, font_size=10):
    r"""
    Draw the graph (or only its colour legend) with matplotlib, in the style used in the paper.

    :param node_size: base marker size of a node; the 'max_pool' (x1.75), 'glob_avg' (x2) and 'input' (x2) groups scale it up
    :param figname: if not None, the figure is saved to ``figname + '.pdf'`` and ``figname + '.png'``
    :param figsize: (width, height) of the figure; defaults to (20, 3) for the legend and (10, 10) for the graph
    :param with_labels: if True, draw the operation name of every node
    :param vis_legend: if True, draw only the legend (a 1x1 conv node followed by one node per primitive); the graph is ignored
    :param label_offset: vertical shift applied to node labels when vis_legend=True
    :param font_size: font size of node labels; only used when with_labels=True
    :return: None; the figure is displayed with plt.show()
    """
    self._nx_graph_from_adj()

    # Colour rank of each primitive: convolution-like ops come first so that they get similar colours.
    primitives_order = [2, 3, 4, 10, 5, 6, 11, 12, 13, 0, 1, 14, 7, 8, 9]
    assert len(PRIMITIVES_DEEPNETS1M) == len(primitives_order), 'make sure the lists correspond to each other'
    n_primitives = len(primitives_order)

    # Sample the jet colormap evenly by colour rank.
    primitive_colors = {}
    for rank, prim_idx in enumerate(primitives_order):
        primitive_colors[PRIMITIVES_DEEPNETS1M[prim_idx]] = cm.jet(int(np.round(255 * rank / n_primitives)))
    # Hand-picked overrides for better visualization.
    primitive_colors.update({
        'bias': '#%02x%02x%02x' % (255, 0, 255),
        'msa': '#%02x%02x%02x' % (10, 10, 10),
        'ln': '#%02x%02x%02x' % (255, 255, 0),
    })

    # Marker style of each node group; the insertion order is the order in which the groups are drawn below.
    group_styles = {
        'bn': dict(edgecolors='k', linewidths=1, node_shape='s', node_size=node_size),
        'conv1': dict(edgecolors='k', linewidths=1, node_shape='^', node_size=node_size),
        'bias': dict(edgecolors='gray', linewidths=0.5, node_shape='d', node_size=node_size),
        'pos_enc': dict(edgecolors='gray', linewidths=0.5, node_shape='s', node_size=node_size),
        'ln': dict(edgecolors='gray', linewidths=0.5, node_shape='s', node_size=node_size),
        'max_pool': dict(edgecolors='k', linewidths=1, node_shape='o', node_size=1.75 * node_size),
        'glob_avg': dict(edgecolors='gray', linewidths=0.5, node_shape='o', node_size=2 * node_size),
        'concat': dict(edgecolors='gray', linewidths=0.5, node_shape='^', node_size=node_size),
        'input': dict(edgecolors='k', linewidths=1.5, node_shape='s', node_size=2 * node_size),
        'other': dict(edgecolors='gray', linewidths=0.5, node_shape='o', node_size=node_size),
    }
    group_members = {group: [] for group in group_styles}

    if vis_legend:
        # Legend nodes: index n_primitives (rendered as a 1x1 conv) followed by every primitive in colour order.
        node_feat = torch.tensor([n_primitives] + primitives_order).view(-1, 1)
        param_shapes = [(3, 3, 1, 1)] + [None] * n_primitives
    else:
        node_feat, param_shapes = self.node_feat, self._param_shapes

    # Assign a label, a colour and a drawing group to every node.
    labels, node_colors = {}, []
    for idx, (op, shape) in enumerate(zip(node_feat.view(-1), param_shapes)):
        if op < n_primitives:
            name = PRIMITIVES_DEEPNETS1M[op]
            labels[idx] = name[:20]
        else:
            # Indices beyond the primitive list are coloured as 'conv' and labelled 'conv_1x1'.
            name = 'conv'
            labels[idx] = 'conv_1x1'
        node_colors.append(primitive_colors[name])

        # A conv whose parameter is 2D, or 4D with a 1x1 trailing kernel, goes to the 'conv1' group.
        is_pointwise_conv = 'conv' in name and shape is not None and \
            (len(shape) == 2 or (len(shape) == 4 and np.prod(shape[2:]) == 1))
        if is_pointwise_conv:
            group = 'conv1'
        elif name in group_styles:
            group = name
        else:
            group = 'other'
        group_members[group].append(idx)

    fig = plt.figure(figsize=figsize if figsize is not None else ((20, 3) if vis_legend else (10, 10)))
    if vis_legend:
        # Chain graph 0 -> 1 -> ... -> n_primitives laid out on a horizontal line.
        G = nx.DiGraph(np.diag(np.ones(n_primitives), 1))
        pos = {node: (3 * node * node_size, 0) for node in labels}
        pos_labels = {node: (px, py - label_offset) for node, (px, py) in pos.items()}
    else:
        G = self.nx_graph
        pos = nx.drawing.nx_pydot.graphviz_layout(G)
        pos_labels = pos

    for group, style in group_styles.items():
        members = group_members[group]
        nx.draw_networkx_nodes(G, pos,
                               node_color=[node_colors[m] for m in members],
                               nodelist=members,
                               **style)
    if with_labels:
        nx.draw_networkx_labels(G, pos_labels, labels, font_size=font_size)

    # In legend mode the edges are still drawn, but with zero width and full transparency.
    if vis_legend:
        edge_style = dict(width=0, alpha=0, edge_color='white')
    else:
        edge_style = dict(width=1, alpha=1, edge_color='k')
    nx.draw_networkx_edges(G, pos, node_size=node_size, arrowsize=10, arrowstyle='-|>', **edge_style)

    plt.grid(False)
    plt.axis('off')
    if figname is not None:
        plt.savefig(figname + '.pdf', dpi=fig.dpi)
        plt.savefig(figname + '.png', dpi=fig.dpi, transparent=True)
    plt.show()