def load_mxnet_model()

in src/neuron-gatherinfo/mx_neuron_check_model.py [0:0]


  def load_mxnet_model(self, path):
    """Load an MXNet checkpoint and record its operator nodes on this instance.

    Sets ``framework``, ``neuron_optype``, ``excl_types``, ``addl_support``,
    ``nodetypes``, ``nodenames`` and ``neuron_nodes`` on ``self``.

    Args:
      path: Checkpoint prefix handed to ``mx.model.load_checkpoint``; epoch 0
        is always loaded.

    Raises:
      ImportError: If MXNet is not version 1.5.1 and the ``mxnetneuron``
        package cannot be imported.
    """
    import mxnet as mx
    if mx.__version__ != "1.5.1":
      try:
        # The module is not referenced again in this method.
        # NOTE(review): presumably the import registers the Neuron subgraph
        # operator with MXNet -- confirm.
        import mxnetneuron  # noqa: F401
      except ImportError as err:
        # Raising a plain string is a TypeError in Python 3; raise a real
        # exception so the install hint actually reaches the user.
        raise ImportError("Please install mxnetneuron package.") from err
    self.framework = 'MXNET'
    self.neuron_optype = "_neuron_subgraph_op"
    self.excl_types = ['null']
    self.addl_support = [self.neuron_optype]
    sym, args, _ = mx.model.load_checkpoint(path, 0)
    # Keep every graph node whose op is not in the excluded list.
    nodes = [
        node for node in json.loads(sym.tojson())["nodes"]
        if node['op'] not in self.excl_types
    ]
    self.nodetypes = [node['op'] for node in nodes]
    self.nodenames = [node['name'] for node in nodes]
    # One (name, binary payload, subgraph info) tuple per Neuron subgraph
    # node; the payload is read from the "<node name>_neuronbin" entry of the
    # checkpoint arguments.
    self.neuron_nodes = [
        (
            node['name'],
            bytearray(args[node['name'] + "_neuronbin"].asnumpy()),
            self.get_mx_subgraph_types_names(node),
        )
        for node in nodes
        if node['op'] == self.neuron_optype
    ]