in utils/gluon/plot_network.py [0:0]
def main():
    """Build a Gluon model from command-line options and visualize it.

    Steps performed:
      1. Parse options with ``parse_args()`` and build the model through
         ``get_model(opt.model, ...)`` on CPU, without pretrained weights,
         with ``classes=1000`` and ``ratio=opt.ratio``; ``use_se=True`` is
         added only when ``opt.use_se`` is truthy.
      2. Log the network's text representation.
      3. Trace the network with a symbolic ``'data'`` input, draw the graph
         with ``mx.viz.plot_network`` as PNG (input shape 1x3x224x224), and
         open the rendered image in a viewer.
      4. Log the sorted parameter names as indented JSON.
    """
    opt = parse_args()
    kwargs = {
        'ctx': [mx.cpu()],
        'pretrained': False,
        'classes': 1000,
        'ratio': opt.ratio,
    }
    if opt.use_se:
        kwargs['use_se'] = True
    logging.info("get symbol ...")
    net = get_model(opt.model, **kwargs)

    # Option 1: log the block structure as text.
    logging.info("option 1: print network ...")
    logging.info(net)

    # Option 2: draw the computation graph. To plot the whole graph, the net
    # must be a HybridSequential so it can be traced with a symbolic input.
    logging.info("option 2: draw network ...")
    net.hybridize()
    net.collect_params().initialize()
    x = mx.sym.var('data')
    sym = net(x)  # calling the net on a Symbol yields a symbolic graph
    digraph = mx.viz.plot_network(sym, shape={'data': (1, 3, 224, 224)},
                                  save_format='png')
    # view() saves and renders the graph to file before opening it, so a
    # separate render() call would only re-render the same output.
    digraph.view()

    keys = sorted(net.collect_params().keys())
    logging.info(json.dumps(keys, indent=4))