def main()

in utils/gluon/plot_network.py [0:0]


def main():
    """Build the requested model and visualize its structure.

    Parses command-line options and instantiates ``opt.model`` through
    ``get_model`` (not pretrained, 1000 classes, CPU context, ``opt.ratio``,
    plus ``use_se`` when requested). It then does three things:

    1. Logs the Gluon block's text representation.
    2. Traces the network symbolically and draws its graph with
       ``mx.viz.plot_network``, rendering it to PNG and opening a viewer.
    3. Logs the sorted list of parameter names as indented JSON.
    """
    opt = parse_args()

    kwargs = {'ctx': [mx.cpu()], 'pretrained': False, 'classes': 1000,
              'ratio': opt.ratio}
    if opt.use_se:
        kwargs['use_se'] = True

    logging.info("get symbol ...")
    net = get_model(opt.model, **kwargs)

    # Option 1: textual summary of the block hierarchy.
    logging.info("option 1: print network ...")
    logging.info(net)

    # Option 2: graph drawing. The block is hybridized and then called on a
    # Symbol to obtain a symbolic graph. Per the author's note, the net must
    # be a HybridSequential for the *whole* graph to be plotted.
    logging.info("option 2: draw network ...")
    net.hybridize()
    net.collect_params().initialize()

    data = mx.sym.var('data')
    sym = net(data)
    # The input shape lets plot_network annotate tensor shapes between nodes.
    # (1, 3, 224, 224) is presumably one 3-channel 224x224 image (NCHW).
    digraph = mx.viz.plot_network(sym, shape={'data': (1, 3, 224, 224)},
                                  save_format='png')
    # view() is itself render(view=True). One call writes the PNG once and
    # opens it, so a separate render() afterwards would only redo the work.
    digraph.render(view=True)

    # Parameter names, sorted so the dump is stable across runs.
    param_names = sorted(net.collect_params().keys())
    logging.info(json.dumps(param_names, indent=4))