in source/cf/defaults/lambdas/sputnik-gg-ml-inference-squeezenet-demo-python/load_model.py [0:0]
def __init__(self, synset_path, network_prefix, params_url=None,
             symbol_url=None, synset_url=None, context=mx.cpu(),
             label_names=('prob_label',),
             input_shapes=(('data', (1, 3, 224, 224)),)):
    """Optionally download model artifacts, then load an MXNet module for inference.

    Args:
        synset_path: Local path of the synset file. Each line becomes one
            entry of ``self.synsets`` (trailing whitespace stripped).
        network_prefix: Checkpoint prefix. The model is loaded from
            ``<prefix>-symbol.json`` and ``<prefix>-0000.params``.
        params_url: If not None, downloaded to ``<prefix>-0000.params``.
        symbol_url: If not None, downloaded to ``<prefix>-symbol.json``.
        synset_url: If not None, downloaded to ``synset_path``.
        context: MXNet context the module is created with. The default
            ``mx.cpu()`` is evaluated once, when this function is defined.
        label_names: Label names passed (as a fresh list) to ``mx.mod.Module``.
        input_shapes: ``(name, shape)`` pairs passed (as a fresh list) to
            ``Module.bind`` as ``data_shapes``. The default presumably
            describes one 3x224x224 image in NCHW layout -- confirm against
            the model.
    """
    import contextlib
    import shutil

    # Single source of truth for the checkpoint epoch: it determines both the
    # downloaded params filename and the epoch passed to load_checkpoint.
    epoch = 0
    downloads = (
        (params_url, '%s-%04d.params' % (network_prefix, epoch)),
        (symbol_url, network_prefix + '-symbol.json'),
        (synset_url, synset_path),
    )
    for url, destination in downloads:
        if url is None:
            continue
        # closing() guarantees the HTTP response is released; Python 2's
        # urllib2 responses cannot be used directly in a `with` statement.
        # copyfileobj streams in chunks instead of reading everything at once.
        with contextlib.closing(urllib2.urlopen(url)) as response:
            with open(destination, 'wb') as output:
                shutil.copyfileobj(response, output)

    # Load the class labels, one per line.
    with open(synset_path, 'r') as synset_file:
        self.synsets = [line.rstrip() for line in synset_file]

    sym, arg_params, aux_params = mx.model.load_checkpoint(network_prefix, epoch)

    # Build an inference-only module. Fresh lists are passed so neither the
    # default tuples nor caller-owned sequences are shared with MXNet.
    self.mod = mx.mod.Module(symbol=sym, label_names=list(label_names),
                             context=context)
    self.mod.bind(for_training=False, data_shapes=list(input_shapes))
    self.mod.set_params(arg_params, aux_params)
    # Not set to anything else here; presumably assigned later by other code.
    self.camera = None