def plot_network()

in python/mxnet/visualization.py [0:0]


def plot_network(symbol, title="plot", save_format='pdf', shape=None, node_attrs={},
                 hide_weights=True):
    """Creates a visualization (Graphviz digraph object) of the given computation graph.
    Graphviz must be installed for this function to work.

    Parameters
    ----------
    symbol: Symbol
        A symbol from the computation graph. The generated digraph will visualize the part
        of the computation graph required to compute `symbol`.
    title: str, optional
        Title of the generated visualization. It is passed to the Graphviz ``Digraph``
        as its ``name``.
    save_format: str, optional
        Output format of the digraph. It is passed to the Graphviz ``Digraph`` as its
        ``format`` (default ``'pdf'``).
    shape: dict, optional
        Specifies the shape of the input tensors. If specified, the visualization will include
        the shape of the tensors between the nodes. `shape` is a dictionary mapping
        input symbol names (str) to the corresponding tensor shape (tuple).
        The first dimension of every shape is dropped from the edge labels.
    node_attrs: dict, optional
        Specifies the attributes for nodes in the generated visualization. `node_attrs` is
        a dictionary of Graphviz attribute names and values. For example,
            ``node_attrs={"shape":"oval","fixedsize":"false"}``
            will use oval shape for nodes and allow variable sized nodes in the visualization.
        The dictionary is only read (merged over the defaults), never modified.
    hide_weights: bool, optional
        If True (default), then inputs with names of form *_weight, *_bias, *_beta, *_gamma,
        *_moving_var or *_moving_mean (parameter and auxiliary-state tensors) will be hidden
        for a cleaner visualization. If False, no styled node is declared for those inputs
        but the edges that reference them are still added.

    Returns
    -------
    dot: Digraph
        A Graphviz digraph object visualizing the computation graph to compute `symbol`.

    Raises
    ------
    ImportError
        If the `graphviz` Python package cannot be imported.
    TypeError
        If `symbol` is not a `Symbol`.
    ValueError
        If `shape` is given but shape inference returns no output shapes.

    Example
    -------
    >>> net = mx.sym.Variable('data')
    >>> net = mx.sym.FullyConnected(data=net, name='fc1', num_hidden=128)
    >>> net = mx.sym.Activation(data=net, name='relu1', act_type="relu")
    >>> net = mx.sym.FullyConnected(data=net, name='fc2', num_hidden=10)
    >>> net = mx.sym.SoftmaxOutput(data=net, name='out')
    >>> digraph = mx.viz.plot_network(net, shape={'data':(100,200)},
    ... node_attrs={"fixedsize":"false"})
    >>> digraph.view()
    """
    # Only a failed import means "graphviz is missing"; anything else
    # (KeyboardInterrupt, SystemExit, ...) must propagate untouched.
    try:
        from graphviz import Digraph
    except ImportError:
        raise ImportError("Draw network requires graphviz library")
    if not isinstance(symbol, Symbol):
        raise TypeError("symbol must be a Symbol")

    # When input shapes are supplied, infer the shape of every internal output
    # so that edges can be labelled with the tensor shape flowing through them.
    draw_shape = shape is not None
    shape_dict = {}
    if draw_shape:
        internals = symbol.get_internals()
        _, out_shapes, _ = internals.infer_shape(**shape)
        if out_shapes is None:
            raise ValueError("Input shape is incomplete")
        shape_dict = dict(zip(internals.list_outputs(), out_shapes))

    conf = json.loads(symbol.tojson())
    nodes = conf["nodes"]
    # default attributes of node
    node_attr = {"shape": "box", "fixedsize": "true",
                 "width": "1.3", "height": "0.8034", "style": "filled"}
    # merge the dict provided by user and the default one
    node_attr.update(node_attrs)
    dot = Digraph(name=title, format=save_format)
    # color map
    cm = ("#8dd3c7", "#fb8072", "#ffffb3", "#bebada", "#80b1d3",
          "#fdb462", "#b3de69", "#fccde5")
    # name suffixes of inputs that `hide_weights` applies to
    weight_suffixes = ("_weight", "_bias", "_beta", "_gamma",
                       "_moving_var", "_moving_mean")

    def looks_like_weight(name):
        """Internal helper to figure out if node should be hidden with `hide_weights`.
        """
        return name.endswith(weight_suffixes)

    # make nodes
    hidden_nodes = set()
    for node in nodes:
        op = node["op"]
        name = node["name"]
        # every node starts from its own copy of the merged attributes
        attr = copy.deepcopy(node_attr)
        label = name

        if op == "null":
            if looks_like_weight(name):
                if hide_weights:
                    hidden_nodes.add(name)
                # Otherwise no styled node is declared, but the name is not
                # added to hidden_nodes either, so the edges referencing it
                # are still drawn.
                continue
            attr["shape"] = "oval"  # inputs get their own shape
            attr["fillcolor"] = cm[0]
        elif op == "Convolution":
            params = node["attr"]
            kernel = "x".join(_str2tuple(params["kernel"]))
            stride = "x".join(_str2tuple(params["stride"])) if "stride" in params else "1"
            label = r"Convolution\n%s/%s, %s" % (kernel, stride, params["num_filter"])
            attr["fillcolor"] = cm[1]
        elif op == "FullyConnected":
            label = r"FullyConnected\n%s" % node["attr"]["num_hidden"]
            attr["fillcolor"] = cm[1]
        elif op == "BatchNorm":
            attr["fillcolor"] = cm[3]
        elif op in ("Activation", "LeakyReLU"):
            label = r"%s\n%s" % (op, node["attr"]["act_type"])
            attr["fillcolor"] = cm[2]
        elif op == "Pooling":
            params = node["attr"]
            pool_type = params["pool_type"]
            kernel = "x".join(_str2tuple(params["kernel"]))
            stride = "x".join(_str2tuple(params["stride"])) if "stride" in params else "1"
            label = r"Pooling\n%s, %s/%s" % (pool_type, kernel, stride)
            attr["fillcolor"] = cm[4]
        elif op in ("Concat", "Flatten", "Reshape"):
            attr["fillcolor"] = cm[5]
        elif op == "Softmax":
            attr["fillcolor"] = cm[6]
        else:
            attr["fillcolor"] = cm[7]
            if op == "Custom":
                label = node["attr"]["op_type"]

        dot.node(name=name, label=label, **attr)

    # add edges
    for node in nodes:
        if node["op"] == "null":
            # inputs have no incoming edges
            continue
        name = node["name"]
        for item in node["inputs"]:
            # item[0] is the index of the producing node in `nodes`
            input_node = nodes[item[0]]
            input_name = input_node["name"]
            if input_name in hidden_nodes:
                continue
            # The edge is declared consumer -> producer and drawn reversed
            # ("dir": "back") so the arrow follows the data flow.
            attr = {"dir": "back", "arrowtail": "open"}
            # add shapes
            if draw_shape:
                if input_node["op"] != "null":
                    key = input_name + "_output"
                    params = input_node.get("attr", {})
                    if "num_outputs" in params:
                        # Multi-output producer: each edge taken from it uses
                        # the next lower output index. The counter is kept in
                        # the parsed (call-local) JSON node itself.
                        remaining = int(params["num_outputs"]) - 1
                        key += str(remaining)
                        params["num_outputs"] = remaining
                else:
                    key = input_name
                # drop the first dimension from the label
                edge_shape = shape_dict[key][1:]
                attr["label"] = "x".join([str(x) for x in edge_shape])
            dot.edge(tail_name=name, head_name=input_name, **attr)

    return dot