in src/neuron-gatherinfo/tf_neuron_check_model.py [0:0]
def load_mxnet_model(self, path):
    """Load an MXNet checkpoint and record its operator inventory on ``self``.

    Loads the checkpoint with prefix ``path`` at epoch 0 and sets:

    * ``self.framework`` -- the string ``'MXNET'``.
    * ``self.neuron_optype`` -- op name of Neuron-compiled subgraph nodes.
    * ``self.excl_types`` -- op types dropped from the node list (``'null'``).
    * ``self.addl_support`` -- extra op types treated as supported
      (the Neuron subgraph op).
    * ``self.nodetypes`` / ``self.nodenames`` -- op type and name of every
      non-excluded graph node, in the order they appear in the symbol JSON.
    * ``self.neuron_nodes`` -- one ``(name, bytearray, info)`` tuple per
      Neuron subgraph node. The bytearray holds the contents of the
      ``<name>_neuronbin`` entry of the checkpoint's arg params, and ``info``
      is whatever ``self.get_mx_subgraph_types_names(node)`` returns.

    Args:
        path: checkpoint prefix passed to ``mx.model.load_checkpoint``.

    Raises:
        ImportError: if mxnet is not version 1.5.1 and the ``mxnetneuron``
            package cannot be imported.
    """
    import mxnet as mx

    if mx.__version__ != "1.5.1":
        # Outside mxnet 1.5.1 the Neuron integration comes from a separate
        # package. It is imported only for its side effects, presumably
        # registering the neuron subgraph op (TODO confirm).
        try:
            import mxnetneuron  # noqa: F401
        except ImportError as err:
            # The original raised a bare string, which Python 3 turns into a
            # TypeError and hides the actionable message.
            raise ImportError("Please install mxnetneuron package.") from err

    self.framework = 'MXNET'
    self.neuron_optype = "_neuron_subgraph_op"
    # In MXNet symbol JSON, 'null' nodes are variables (inputs/params) rather
    # than operators, so they are excluded from the op inventory.
    self.excl_types = ['null']
    self.addl_support = [self.neuron_optype]

    # Epoch is hard-coded to 0; aux params are not needed here.
    sym, args, _auxs = mx.model.load_checkpoint(path, 0)
    nodes = json.loads(sym.tojson())["nodes"]
    nodes = [node for node in nodes if node['op'] not in self.excl_types]
    self.nodetypes = [node['op'] for node in nodes]
    self.nodenames = [node['name'] for node in nodes]

    neuron_nodes_tmp = [node for node in nodes if node['op'] == self.neuron_optype]
    # Each Neuron subgraph node stores its compiled binary in the arg params
    # under "<node name>_neuronbin"; copy it out as raw bytes.
    self.neuron_nodes = [
        (
            node['name'],
            bytearray(args[node['name'] + "_neuronbin"].asnumpy()),
            self.get_mx_subgraph_types_names(node),
        )
        for node in neuron_nodes_tmp
    ]