def visualize()

in ppuda/deepnets1m/graph.py [0:0]


    def visualize(self, node_size=50, figname=None, figsize=None, with_labels=False, vis_legend=False, label_offset=0.001, font_size=10):
        r"""
        Shows the graphs/legend as in the paper using matplotlib.

        Nodes are colored by primitive (operation) type and drawn with a shape/edge style
        that depends on their group (bn, conv1, bias, ..., other). Edge arrows are shrunk
        using each node's actual drawn size, so arrows pointing into enlarged nodes
        (max_pool, glob_avg, input) end at the node border instead of under the marker.

        :param node_size: base node size (max_pool, glob_avg and input nodes are drawn larger)
        :param figname: file name (without extension) to save the figure in the .pdf and .png formats;
            nothing is saved when None
        :param figsize: (width, height) for a figure; defaults to (20, 3) for the legend
            and (10, 10) for the graph
        :param with_labels: show node labels (operations)
        :param vis_legend: True to only visualize the legend (graph will be ignored)
        :param label_offset: vertical offset of node labels below the nodes when vis_legend=True
        :param font_size: font size for node labels, used only when with_labels=True
        :return: None; the figure is displayed with plt.show() (and optionally saved first)
        """

        self._nx_graph_from_adj()

        # first are conv layers, so that they have a similar color
        primitives_order = [2, 3, 4, 10, 5, 6, 11, 12, 13, 0, 1, 14, 7, 8, 9]
        assert len(PRIMITIVES_DEEPNETS1M) == len(primitives_order), 'make sure the lists correspond to each other'

        n_primitives = len(primitives_order)
        # Map a position in primitives_order to an evenly spaced color of the jet colormap.
        color = lambda i: cm.jet(int(np.round(255 * i / n_primitives)))
        primitive_colors = {PRIMITIVES_DEEPNETS1M[ind_org]: color(ind_new) for ind_new, ind_org in enumerate(primitives_order)}
        # manually adjust some colors for better visualization
        primitive_colors['bias'] = '#%02x%02x%02x' % (255, 0, 255)
        primitive_colors['msa'] = '#%02x%02x%02x' % (10, 10, 10)
        primitive_colors['ln'] = '#%02x%02x%02x' % (255, 255, 0)

        node_groups = {'bn':        {'style': {'edgecolors': 'k',       'linewidths': 1,    'node_shape': 's'}},
                       'conv1':     {'style': {'edgecolors': 'k',       'linewidths': 1,    'node_shape': '^'}},
                       'bias':      {'style': {'edgecolors': 'gray',    'linewidths': 0.5,  'node_shape': 'd'}},
                       'pos_enc':   {'style': {'edgecolors': 'gray',    'linewidths': 0.5,  'node_shape': 's'}},
                       'ln':        {'style': {'edgecolors': 'gray',    'linewidths': 0.5,  'node_shape': 's'}},
                       'max_pool':  {'style': {'edgecolors': 'k',       'linewidths': 1,    'node_shape': 'o', 'node_size': 1.75 * node_size}},
                       'glob_avg':  {'style': {'edgecolors': 'gray',    'linewidths': 0.5,  'node_shape': 'o', 'node_size': 2 * node_size}},
                       'concat':    {'style': {'edgecolors': 'gray',    'linewidths': 0.5,  'node_shape': '^'}},
                       'input':     {'style': {'edgecolors': 'k',       'linewidths': 1.5,  'node_shape': 's', 'node_size': 2 * node_size}},
                       'other':     {'style': {'edgecolors': 'gray',    'linewidths': 0.5,  'node_shape': 'o'}}}

        for group in node_groups.values():
            group['node_lst'] = []
            # groups without an explicit size use the base node size
            group['style'].setdefault('node_size', node_size)

        labels, node_colors = {}, []

        if vis_legend:
            # Legend: one leading node with index n_primitives (rendered as 'conv' with a 1x1
            # kernel shape, hence placed in the 'conv1' group) followed by every primitive in
            # primitives_order. Assumes 'conv' is one of PRIMITIVES_DEEPNETS1M.
            node_feat = torch.cat((torch.tensor([n_primitives]).view(-1, 1),
                                   torch.tensor(primitives_order)[:, None]))
            param_shapes = [(3, 3, 1, 1)] + [None] * n_primitives
        else:
            node_feat = self.node_feat
            param_shapes = self._param_shapes

        for i, (x, sz) in enumerate(zip(node_feat.view(-1), param_shapes)):
            is_primitive = x < n_primitives
            name = PRIMITIVES_DEEPNETS1M[x] if is_primitive else 'conv'
            labels[i] = name[:20] if is_primitive else 'conv_1x1'
            node_colors.append(primitive_colors[name])

            # conv ops whose weights are 2D, or 4D with a 1x1 spatial kernel, are drawn as 'conv1'
            is_conv1 = 'conv' in name and sz is not None and \
                ((len(sz) == 4 and np.prod(sz[2:]) == 1) or len(sz) == 2)
            if is_conv1:
                node_groups['conv1']['node_lst'].append(i)
            elif name in node_groups:
                node_groups[name]['node_lst'].append(i)
            else:
                node_groups['other']['node_lst'].append(i)

        if vis_legend:
            fig = plt.figure(figsize=(20, 3) if figsize is None else figsize)
            # A chain graph 0 -> 1 -> ... -> n_primitives, laid out on a horizontal line.
            G = nx.DiGraph(np.diag(np.ones(n_primitives), 1))
            pos = {i: (3 * i * node_size, 0) for i in labels}
            pos_labels = {i: (x, y - label_offset) for i, (x, y) in pos.items()}
        else:
            fig = plt.figure(figsize=(10, 10) if figsize is None else figsize)
            G = self.nx_graph
            pos = nx.drawing.nx_pydot.graphviz_layout(G)
            pos_labels = pos

        for node_group in node_groups.values():
            nx.draw_networkx_nodes(G, pos,
                                   node_color=[node_colors[i] for i in node_group['node_lst']],
                                   nodelist=node_group['node_lst'],
                                   **node_group['style'])
        if with_labels:
            nx.draw_networkx_labels(G, pos_labels, labels, font_size=font_size)

        # Per-node drawn sizes, ordered like G's nodes (networkx's default edge nodelist), so
        # arrows are shrunk to the real marker size of enlarged nodes. Node ids are the integer
        # indices used for `nodelist` above; unknown nodes fall back to the base size.
        size_by_node = {i: group['style']['node_size']
                        for group in node_groups.values() for i in group['node_lst']}
        edge_node_sizes = [size_by_node.get(n, node_size) for n in G]

        nx.draw_networkx_edges(G, pos, node_size=edge_node_sizes,
                               width=0 if vis_legend else 1,
                               arrowsize=10,
                               alpha=0 if vis_legend else 1,  # legend: edges only keep the layout, stay invisible
                               edge_color='white' if vis_legend else 'k',
                               arrowstyle='-|>')

        plt.grid(False)
        plt.axis('off')
        if figname is not None:
            plt.savefig(figname + '.pdf', dpi=fig.dpi)
            plt.savefig(figname + '.png', dpi=fig.dpi, transparent=True)
        plt.show()