in keras/backend/mxnet_backend.py [0:0]
def get_model():
    """Prepares the `Model` class used to train Keras models with the MXNet backend.

    The returned class inherits from and extends `keras.engine.training.Model`.
    Because the class statement lives inside this function, a new class object
    is created on every call.

    # Returns
        The MXNet-aware `Model` class itself (not an instance).
    """
    import importlib
    # NOTE(review): keras.engine.training is resolved at call time rather than at
    # module import, presumably to avoid a circular import with this backend — confirm.
    engine = importlib.import_module('keras.engine.training')

    class Model(engine.Model):
        """The `Model` class adds training & evaluation routines to a `Network`.

        This class extends `keras.engine.Model` with MXNet Modules to perform
        training and inference with the MXNet backend. A compiled model executes
        through an `mx.mod.BucketingModule` with 'train', 'test' and 'pred'
        buckets. On an MXNet EIA context a plain `mx.mod.Module` built from the
        prediction symbol is used instead, and only prediction is supported.
        """
def __init__(self, *args, **kwargs):
    """Creates the model and initialises the MXNet-specific bookkeeping.

    Accepts the same arguments as `keras.engine.Model`. `kwargs` may also
    carry `context` (defaults to `None`) and `kvstore` (defaults to
    `'device'`). These defaults are filled in only after the parent
    constructor has run, so values supplied by the caller are forwarded to
    the parent as well. When the model is already built, a prediction-only
    module is created immediately.
    """
    if 'name' not in kwargs:
        prefix = self.__class__.__name__.lower()
        kwargs['name'] = '%s_%s' % (prefix, get_uid(prefix))
    self.name = kwargs['name']
    super(Model, self).__init__(*args, **kwargs)

    kwargs.setdefault('context', None)
    kwargs.setdefault('kvstore', 'device')
    self._context = _get_mxnet_context(kwargs['context'])
    self._kvstore = kwargs['kvstore']

    # Per-phase symbols and bookkeeping; filled in by compile().
    self._data_names = self._label_names = None
    self._ntrain = self._train_mxnet_symbol = self._train_updates = None
    self._ntest = self._test_mxnet_symbol = self._test_updates = None
    self._npred = self._pred_mxnet_symbol = None
    # Parameter names and cached arrays shared with the MXNet module.
    self._arg_names = self._aux_names = self._fixed_weights = None
    self._args = self._auxs = None
    self._weights_dirty = None
    self._module = None
    self.compiled = False

    if not self.built:
        self._num_data = self._num_label = None
        return
    self._num_data = len(self.inputs)
    # NOTE(review): len(outputs) + len(output_names) label slots, presumably
    # one target plus one sample weight per output — confirm.
    self._num_label = len(self.outputs) + len(self.output_names)
    # Inference is possible without compile() through this module.
    self._create_predict_module()
def compile(self, optimizer, loss=None, metrics=None, loss_weights=None,
            sample_weight_mode=None, **kwargs):
    """Configures the model for training and builds its MXNet module.

    Calls `keras.engine.Model.compile` first. An unbuilt model returns right
    after that call. Otherwise this method:

    - builds one MXNet symbol per phase: 'train' under learning phase 1,
      'test' and 'pred' under learning phase 0;
    - collects the parameter arrays to bind;
    - creates the module that executes those symbols;
    - calls `set_model(self)` and marks the model as compiled.

    # Arguments
        optimizer, loss, metrics, loss_weights, sample_weight_mode: forwarded
            to the parent `compile`.
        **kwargs: forwarded to the parent `compile`. A `context` entry also
            replaces this model's MXNet context.
    """
    super(Model, self).compile(
        optimizer, loss, metrics, loss_weights,
        sample_weight_mode, **kwargs)
    if not self.built:
        # Model is not compilable because it does not know its number of
        # inputs and outputs, nor their shapes and names. We will compile
        # after the first time the model gets called on training data.
        return
    # If context is passed in kwargs
    if 'context' in kwargs:
        self._context = _get_mxnet_context(kwargs['context'])
    self._num_data = len(self.inputs)
    # NOTE(review): len(outputs) + len(output_names) label slots, presumably
    # one target plus one sample weight per output — confirm.
    self._num_label = len(self.outputs) + len(self.output_names)
    # set the data and label
    self._data_names = [x.name for x in self.inputs if x]
    self._label_names = [x.name for x in self.targets + self.sample_weights if x]

    # Build the per-phase graphs. The learning phase is global state, so it is
    # restored in `finally` even if building a symbol raises.
    old = learning_phase()
    try:
        # set for training
        set_learning_phase(1)
        self._ntrain = len(self.metrics_tensors) + 1
        train_updates = [stop_gradient(x[1]) for x in self.updates]
        train_keras_symbol = group(
            [make_loss(self.total_loss)] +
            [stop_gradient(x) for x in self.metrics_tensors] +
            train_updates
        )
        bind_values = dfs_get_bind_values(train_keras_symbol)
        self._train_mxnet_symbol = train_keras_symbol.symbol
        # Map every update target to the name of the gradient-stopped output
        # that carries its new value (consumed by `_update`).
        symbol_name_map = {i.name: j.name
                           for (_, i), j in zip(self.updates, train_updates)}
        self._train_updates = {dst.name: symbol_name_map[src.name]
                               for dst, src in self.updates}

        # set for testing
        set_learning_phase(0)
        self._ntest = len(self.metrics_tensors) + 1
        state_updates = [x[1] for x in self.state_updates]
        test_keras_symbol = group(
            [self.total_loss] +
            [stop_gradient(x) for x in self.metrics_tensors] +
            state_updates
        )
        bind_values.update(dfs_get_bind_values(test_keras_symbol))
        self._test_mxnet_symbol = test_keras_symbol.symbol

        # set for prediction
        self._npred = len(self.outputs)
        pred_keras_symbol = group(
            self.outputs +
            [symbol for symbol in state_updates if symbol not in self.outputs]
        )
        bind_values.update(dfs_get_bind_values(pred_keras_symbol))
        self._pred_mxnet_symbol = pred_keras_symbol.symbol
        self._test_updates = {dst.name: src.name for dst, src in self.state_updates}
    finally:
        set_learning_phase(old)

    # set the args and auxs: every symbol argument that is not a data/label input
    inputs_name_set = set(self._data_names + self._label_names)
    self._arg_names = {x for x in self._train_mxnet_symbol.list_arguments()
                       if x not in inputs_name_set}
    self._aux_names = set(self._train_mxnet_symbol.list_auxiliary_states())
    trainable_weights = {x.name for x in self.trainable_weights}
    self._fixed_weights = [x for x in self._arg_names if x not in trainable_weights]
    self._args = {}
    for x in self._arg_names:
        if x not in bind_values:
            continue
        # NOTE(review): `is_sparse` receives the argument *name* here, not
        # `bind_values[x]` — confirm this is the intended input.
        if is_sparse(x):
            self._args[x] = mx.nd.array(bind_values[x]).tostype('row_sparse')
        else:
            self._args[x] = bind_values[x]
    self._auxs = {x: bind_values[x] for x in self._aux_names if x in bind_values}
    self._weights_dirty = False

    if self._context and hasattr(self._context[0], 'device_type') and self._context[0].device_type == 'eia':
        # Only Prediction is Supported with EIA Context
        self._module = mx.mod.Module(self._pred_mxnet_symbol, data_names=self._data_names,
                                     label_names=self._label_names, context=self._context[0],
                                     fixed_param_names=self._fixed_weights)
    else:
        # One bucket per phase; the symbols are looked up lazily on each switch.
        def sym_gen(phase):
            if phase == 'train':
                return self._train_mxnet_symbol, self._data_names, self._label_names
            elif phase == 'test':
                return self._test_mxnet_symbol, self._data_names, self._label_names
            else:
                return self._pred_mxnet_symbol, self._data_names, None

        self._module = mx.mod.BucketingModule(
            sym_gen=sym_gen,
            default_bucket_key='pred',
            context=self._context,
            fixed_param_names=self._fixed_weights)
    set_model(self)
    self.compiled = True
def _adjust_module(self, inputs, phase):
    """Converts a batch to MXNet arrays and prepares the module for `phase`.

    On first use this binds the module. A prediction on an uncompiled model
    is bound for inference only; every other case is bound for training and
    gets an optimizer. A bucketing module is then switched to the `phase`
    bucket, and the executors are reshaped when the batch size differs from
    the bound one.

    # Arguments
        inputs: list of input arrays, optionally followed by target/sample
            weight arrays, plus possibly one extra trailing element which is
            dropped. NOTE(review): that trailing element is presumably the
            learning-phase flag — confirm.
        phase: bucket key, one of 'train', 'test' or 'pred'.

    # Returns
        Tuple `(data, label, phase, data_shapes, label_shapes)`. `label` and
        `label_shapes` are None when no label arrays are supplied.

    # Raises
        RuntimeError: if the model has no module yet.
    """
    if not self._module:
        raise RuntimeError('MXNet Backend: You must compile your model before using it.')

    n_data = self._num_data
    n_with_labels = n_data + self._num_label
    if len(inputs) - 1 in (n_with_labels, n_data):
        inputs = inputs[:-1]
    assert n_data == len(inputs) or n_with_labels == len(inputs)

    data = [mx.nd.array(arr, dtype=sym.dtype)
            for sym, arr in zip(self.inputs, inputs[:n_data])]
    data_shapes = [mx.io.DataDesc(sym.name, arr.shape, dtype=sym.dtype)
                   for sym, arr in zip(self.inputs, data)]
    label = label_shapes = None
    if len(inputs) > n_data:
        label_syms = self.targets + self.sample_weights
        label = [mx.nd.array(arr, dtype=sym.dtype)
                 for sym, arr in zip(label_syms, inputs[n_data:])]
        label_shapes = [mx.io.DataDesc(sym.name, arr.shape, dtype=sym.dtype)
                        for sym, arr in zip(label_syms, label)]

    module = self._module
    if not module.binded:
        # Prediction without compile() binds an inference-only module.
        inference_only = phase == 'pred' and not self.compiled
        module.bind(data_shapes=data_shapes, label_shapes=None,
                    for_training=not inference_only)
        self._set_weights()
        if not inference_only:
            module.init_optimizer(kvstore=self._kvstore, optimizer=self.optimizer)

    if isinstance(module, mx.mod.BucketingModule):
        module.switch_bucket(phase, data_shapes, label_shapes)
        active = module._curr_module
    else:
        # EIA contexts use a plain Module, which has no buckets.
        active = module
    batch_size = inputs[0].shape[0]
    if batch_size != active._exec_group.batch_size:
        active.reshape(data_shapes, label_shapes)
        assert batch_size == active._exec_group.batch_size, 'Reshape failed'
    return data, label, phase, data_shapes, label_shapes
def _sync_weights(self):
    """Copies the module's current parameters back into `self._args` / `self._auxs`.

    Does nothing unless the weights are marked dirty, and clears the flag
    afterwards. Existing cache entries are overwritten in place (`[:] =`), so
    the cached array objects keep their identity. Parameters missing from a
    cache are stored as copies of the arrays returned by `get_params()`.
    """
    if not self._weights_dirty:
        return
    args, auxs = self._module.get_params()
    for cache, names, params in ((self._args, self._arg_names, args),
                                 (self._auxs, self._aux_names, auxs)):
        for name in names:
            if name in cache:
                cache[name][:] = params[name]
            else:
                # Store a real array. The previous bare-`except` fallback
                # slice-assigned into `[]`, which produced a Python list. Copy so
                # the cache does not alias the array handed out by get_params().
                cache[name] = params[name].copy()
    self._weights_dirty = False
def _set_weights(self, arg_params=None, auxs_params=None):
    """Pushes parameters into the module, or into the local caches if unbound.

    # Arguments
        arg_params: optional mapping name -> array for argument parameters.
            When None, a bound module receives the cached `self._args`, and an
            unbound module copies nothing.
        auxs_params: the same, for auxiliary states (`self._auxs`).
    """
    if not self._module.binded:
        # No executors yet: overwrite the cached arrays in place.
        for target, source in ((self._args, arg_params), (self._auxs, auxs_params)):
            for key in (source or ()):
                target[key][:] = source[key]
        self._weights_dirty = False
        return

    use_args = self._args if arg_params is None else arg_params
    use_auxs = self._auxs if auxs_params is None else auxs_params
    self._module.set_params(use_args, use_auxs, allow_missing=True)
    # Externally supplied values now live only in the module; flag the caches
    # so `_sync_weights` pulls them back before they are read.
    self._weights_dirty = not (arg_params is None and auxs_params is None)
def _update(self, updates):
    """Writes update outputs back into the executors' argument arrays.

    # Arguments
        updates: mapping from a destination argument name to a source symbol
            name. The new value is read from the executor output named
            `<source>_output`.
    """
    module = self._module
    # A bucketing module delegates to the executor group of its active bucket.
    # The plain Module used for EIA contexts has no `_curr_module` and owns its
    # executor group directly.
    if isinstance(module, mx.mod.BucketingModule):
        module = module._curr_module
    for exe in module._exec_group.execs:
        outs = exe.output_dict
        args = exe.arg_dict
        for dst, src in updates.items():
            args[dst][:] = outs[src + '_output']
def _make_train_function(self):
    """Creates `self.train_function`, which runs one optimisation step on a batch.

    The created function runs forward/backward on the 'train' bucket and
    applies the optimizer update. It then writes the values of `self.updates`
    back into their target arrays and returns the mean of each of the first
    `self._ntrain` outputs (the total loss followed by the metric tensors).

    # Raises
        RuntimeError: if the model runs on an MXNet EIA context, which only
            supports prediction.
    """
    # If context is EIA this should not be supported
    if (self._context and hasattr(self._context[0], 'device_type') and
            self._context[0].device_type == 'eia'):
        raise RuntimeError('MXNet Backend: Model training is not supported with MXNet EIA '
                           'context. Use CPU/GPU.')

    def train_function(inputs):
        self._check_trainable_weights_consistency()
        data, label, _, data_shapes, label_shapes = self._adjust_module(inputs, 'train')
        batch = mx.io.DataBatch(data=data, label=label, bucket_key='train',
                                provide_data=data_shapes, provide_label=label_shapes)
        self._module.forward_backward(batch)
        self._module.update()
        self._update(self._train_updates)
        # Module-side weights changed; local caches must be re-synced before use.
        self._weights_dirty = True
        outs = self._module.get_outputs()[:self._ntrain]
        return [x.asnumpy().mean() for x in outs]

    self.train_function = train_function
def _make_test_function(self):
    """Creates `self.test_function`, which evaluates the model on one batch.

    The created function runs a forward pass of the 'test' bucket, i.e. the
    graph built under learning phase 0. It applies the model's state updates
    if any, and returns the mean of each of the first `self._ntest` outputs
    (the total loss followed by the metric tensors).

    # Raises
        RuntimeError: if the model runs on an MXNet EIA context, which only
            supports prediction.
    """
    # If context is EIA this should not be supported
    if (self._context and hasattr(self._context[0], 'device_type') and
            self._context[0].device_type == 'eia'):
        raise RuntimeError(
            'MXNet Backend: Model Testing is not supported with MXNet EIA context. Use CPU/GPU.')

    def test_function(inputs):
        # Evaluation uses the dedicated 'test' bucket, not the training symbol.
        data, label, _, data_shapes, label_shapes = self._adjust_module(inputs, 'test')
        batch = mx.io.DataBatch(data=data, label=label, bucket_key='test',
                                provide_data=data_shapes, provide_label=label_shapes)
        self._module.forward(batch, is_train=False)
        if self._test_updates:
            self._update(self._test_updates)
            self._weights_dirty = True
        # Slice by the test-phase output count; this previously used `_ntrain`.
        outs = self._module.get_outputs()[:self._ntest]
        return [x.asnumpy().mean() for x in outs]

    self.test_function = test_function
def _make_predict_function(self):
    """Creates `self.predict_function`, which returns the model outputs for a batch.

    If the model has not been compiled, each call first switches to the
    prediction-only module. The function returns the first `self._npred`
    module outputs as numpy arrays.
    """
    def _switch_to_predict_only_module():
        # NOTE(review): this rebuilds the prediction-only module, and so forces a
        # re-bind, on every call — confirm whether that is needed to pick up
        # weight changes or is accidental overhead.
        if self.built:
            self._num_data = len(self.inputs)
            self._num_label = len(self.outputs) + len(self.output_names)
            self._create_predict_module()
        self._module = self._predict_only_module
        set_model(self)

    def predict_function(inputs):
        if not self.compiled:
            _switch_to_predict_only_module()
        data, label, _, data_shapes, label_shapes = self._adjust_module(inputs, 'pred')
        self._module.forward(
            mx.io.DataBatch(data=data, label=label, bucket_key='pred',
                            provide_data=data_shapes, provide_label=label_shapes),
            is_train=False)
        if self._test_updates:
            self._update(self._test_updates)
            self._weights_dirty = True
        return [out.asnumpy() for out in self._module.get_outputs()[:self._npred]]

    self.predict_function = predict_function
def _create_predict_module(self):
    """Builds the inference symbol and a prediction-only module for it.

    Sets the prediction symbol, the argument/auxiliary name sets, the fixed
    (non-trainable) parameter names and the cached parameter arrays. The new
    module is stored in `self._predict_only_module`; `self._module` is left
    untouched. On an EIA context the module is a plain `mx.mod.Module`;
    otherwise it is a `BucketingModule` whose only symbol is the prediction
    graph.
    """
    self._data_names = [inp.name for inp in self.inputs if inp]
    outputs = self.outputs
    self._npred = len(outputs)
    # State-update tensors that are not already model outputs are appended.
    updated = [upd[1] for upd in self.state_updates]
    extra = [sym for sym in updated if sym not in outputs]
    pred_keras_symbol = group(outputs + extra)
    bind_values = dfs_get_bind_values(pred_keras_symbol)
    pred_symbol = pred_keras_symbol.symbol
    self._pred_mxnet_symbol = pred_symbol

    # Parameters are all symbol arguments that are not data inputs.
    data_names = set(self._data_names)
    self._arg_names = {a for a in pred_symbol.list_arguments() if a not in data_names}
    self._aux_names = set(pred_symbol.list_auxiliary_states())
    trainable = {w.name for w in self.trainable_weights}
    self._fixed_weights = [a for a in self._arg_names if a not in trainable]
    self._args = {a: bind_values[a] for a in self._arg_names if a in bind_values}
    self._auxs = {a: bind_values[a] for a in self._aux_names if a in bind_values}
    self._weights_dirty = False

    ctx = self._context
    if ctx and hasattr(ctx[0], 'device_type') and ctx[0].device_type == 'eia':
        # Only prediction is supported with an EIA context.
        self._predict_only_module = mx.mod.Module(
            pred_symbol, data_names=self._data_names,
            label_names=self._label_names, context=ctx[0],
            fixed_param_names=self._fixed_weights)
        return

    def sym_gen(bucket_key):
        # Every bucket resolves to the prediction graph, looked up lazily.
        return self._pred_mxnet_symbol, self._data_names, None

    # Separate module so predict() works without compiling the model.
    self._predict_only_module = mx.mod.BucketingModule(
        sym_gen=sym_gen,
        default_bucket_key='pred',
        context=ctx,
        fixed_param_names=self._fixed_weights)
def set_mxnet_context(self, context):
    """Sets the MXNet context for the current Model.

    # Arguments
        context: one of
            - an integer >= 2: the number of GPUs on which to create model
              replicas;
            - a list/tuple of at least two GPU IDs;
            - a string starting with 'eia', 'cpu' or 'gpu' (case-insensitive).

    # Raises
        ValueError: if fewer than two GPUs are requested, or the string has
            an unrecognised prefix.
    """
    if isinstance(context, (list, tuple)):
        if len(context) <= 1:
            # Wrap in a 1-tuple: with a bare tuple `%` would use its items as
            # the format arguments (and raise TypeError for an empty tuple).
            raise ValueError('MXNet Backend: For multi-gpu usage to be effective, '
                             'call `multi_gpu_model` with `len(gpus) >= 2`. '
                             'Received: `gpus=%s`' % (context,))
    elif isinstance(context, str):
        if not context.lower().startswith(('eia', 'cpu', 'gpu')):
            raise ValueError('MXNet Backend: Invalid context provided - %s' % context)
    elif context <= 1:
        raise ValueError('MXNet Backend: For multi-gpu usage to be effective, '
                         'call `multi_gpu_model` with `gpus >= 2`. '
                         'Received: `gpus=%d`' % context)
    self._context = _get_mxnet_context(context)
return Model  # the class object itself, not an instance