# make_dot()
#
# Extracted from tensorwatch/model_graph/hiddenlayer/pytorch_builder_grad.py [0:0]


def make_dot(var, params, dot):
    """Produce a Graphviz-style representation of a PyTorch autograd graph.

    Walks the autograd graph backwards from ``var.grad_fn``, emitting one
    node per grad_fn / saved tensor / leaf Variable and one edge per
    dependency.

    Blue nodes are trainable Variables (weights, bias).
    Orange nodes are saved tensors for the backward pass.

    Args:
        var: output Variable; its ``grad_fn`` is the traversal root.
        params: list of (name, Parameter) pairs used to label leaf Variables.
        dot: graph builder exposing ``add_edge_by_id(src_id, dst_id, label)``;
            nodes are added to it via ``add_node2dot`` (defined elsewhere in
            this module).

    Returns:
        The same ``dot`` object, populated with nodes and edges.
    """
    # Map parameter identity -> name so leaf Variables can be labelled.
    # (Keyed by id() because tensors are compared by identity here.)
    param_map = {id(v): k for k, v in params}

    seen = set()  # grad_fn / tensor objects already emitted

    def add_nodes(dot, var):
        # Depth-first walk; each object is emitted at most once.
        if var in seen:
            return

        node_id = str(id(var))

        if torch.is_tensor(var):
            # Tensor saved for the backward pass.
            node_label = "saved tensor\n{}".format(tuple(var.size()))
            add_node2dot(dot, var, node_id, node_label, op=None)
        elif hasattr(var, 'variable'):
            # Accumulator node wrapping a leaf Variable (weight/bias);
            # label it with the parameter's name when known.
            variable_name = param_map.get(id(var.variable))
            variable_size = tuple(var.variable.size())
            node_name = "{}\n{}".format(variable_name, variable_size)
            add_node2dot(dot, var, node_id, node_name, op=None)
        else:
            # Intermediate grad_fn; drop the 'Backward' suffix for brevity.
            node_label = type(var).__name__.replace('Backward', '')
            add_node2dot(dot, var, node_id, node_label, op=None)

        seen.add(var)

        # Recurse into the functions this grad_fn depends on.
        if hasattr(var, 'next_functions'):
            for u in var.next_functions:
                if u[0] is not None:
                    dot.add_edge_by_id(str(id(u[0])), str(id(var)), None)
                    add_nodes(dot, u[0])

        # And into any tensors it saved for backward.
        if hasattr(var, 'saved_tensors'):
            for t in var.saved_tensors:
                dot.add_edge_by_id(str(id(t)), str(id(var)), None)
                add_nodes(dot, t)

    add_nodes(dot, var.grad_fn)

    return dot