def get_model()

in keras/backend/mxnet_backend.py [0:0]


def get_model():
    """Prepares Model class that can be used for training a Keras model with MXNet backend.
    Inherits and extends keras.engine.Model class.

    The class is defined inside this function and `keras.engine.training` is imported
    lazily via `importlib` (presumably to avoid a circular import between the backend
    module and `keras.engine` -- TODO confirm).

    # Returns
        MXNet Model reference
    """
    import importlib
    engine = importlib.import_module('keras.engine.training')

    def _is_eia_context(context):
        """Returns True when the first device of `context` has `device_type == 'eia'`.

        # Arguments
            context: Sequence of MXNet contexts (as produced by `_get_mxnet_context`),
                or a falsy value.

        # Returns
            False for an empty/None context or when the first device has no
            `device_type` attribute; otherwise whether that `device_type` is 'eia'.
        """
        return bool(context) and getattr(context[0], 'device_type', None) == 'eia'

    def _copy_params(dst, src, names):
        """Copies the arrays `src[name]` into `dst` for every name in `names`.

        Entries already present in `dst` are overwritten in place (`dst[name][:] = ...`)
        so existing references to those arrays observe the new values; missing entries
        are filled with a copy of the source array, so `dst` always holds arrays.

        # Raises
            KeyError: if a name is missing from `src`.
        """
        for name in names:
            if name in dst:
                dst[name][:] = src[name]
            else:
                dst[name] = src[name].copy()

    class Model(engine.Model):
        """The `Model` class adds training & evaluation routines to a `Network`. This class extends
        keras.engine.Model to add MXNet Module to perform training and inference with MXNet backend.
        """

        def __init__(self, *args, **kwargs):
            """Creates the model.

            Accepts the usual Keras `Model` arguments plus two optional keyword
            arguments: `context` (MXNet context spec, default None) and
            `kvstore` (default 'device'). If the model is already built, a
            prediction-only module is created immediately.
            """
            if 'name' not in kwargs:
                # Unique default name derived from the class name, e.g. `model_<uid>`.
                prefix = self.__class__.__name__.lower()
                name = prefix + '_' + str(get_uid(prefix))
                kwargs['name'] = name

            self.name = kwargs['name']

            super(Model, self).__init__(*args, **kwargs)

            if 'context' not in kwargs:
                kwargs['context'] = None

            if 'kvstore' not in kwargs:
                kwargs['kvstore'] = 'device'

            self._context = _get_mxnet_context(kwargs['context'])
            self._kvstore = kwargs['kvstore']

            # Populated by `compile()` / `_create_predict_module()`.
            self._data_names = None
            self._label_names = None
            self._ntrain = None
            self._train_mxnet_symbol = None
            self._train_updates = None
            self._ntest = None
            self._test_mxnet_symbol = None
            self._test_updates = None
            self._npred = None
            self._pred_mxnet_symbol = None
            self._arg_names = None
            self._aux_names = None
            self._fixed_weights = None
            self._args = None
            self._auxs = None
            self._weights_dirty = None
            self._module = None

            self.compiled = False

            if self.built:
                self._num_data = len(self.inputs)
                self._num_label = len(self.outputs) + len(self.output_names)
                # Create Module for Inference
                self._create_predict_module()
            else:
                self._num_data = None
                self._num_label = None

        def compile(self, optimizer, loss=None, metrics=None, loss_weights=None,
                    sample_weight_mode=None, **kwargs):
            """Compiles the Keras model and builds the MXNet symbols and module.

            Besides the standard Keras `compile` arguments, `kwargs` may contain
            `context` to override the MXNet context chosen at construction time.
            Train, test and prediction symbols are built under learning phase
            1, 0 and 0 respectively; the caller's learning phase is restored
            afterwards, even if symbol construction fails.
            """
            super(Model, self).compile(
                optimizer, loss, metrics, loss_weights,
                sample_weight_mode, **kwargs)

            if not self.built:
                # Model is not compilable because
                # it does not know its number of inputs
                # and outputs, nor their shapes and names.
                # We will compile after the first
                # time the model gets called on training data.
                return

            # If context is passed in kwargs
            if 'context' in kwargs:
                self._context = _get_mxnet_context(kwargs['context'])

            self._num_data = len(self.inputs)
            self._num_label = len(self.outputs) + len(self.output_names)

            # set the data and label
            self._data_names = [x.name for x in self.inputs if x]
            self._label_names = [x.name for x in self.targets + self.sample_weights if x]

            old = learning_phase()
            try:
                # set for training
                set_learning_phase(1)
                self._ntrain = len(self.metrics_tensors) + 1
                train_updates = [stop_gradient(x[1]) for x in self.updates]
                train_keras_symbol = group(
                    [make_loss(self.total_loss)] + [stop_gradient(x)
                                                    for x in self.metrics_tensors] + train_updates
                )
                bind_values = dfs_get_bind_values(train_keras_symbol)
                self._train_mxnet_symbol = train_keras_symbol.symbol
                # Map each update's target variable to the name of the symbol producing its value.
                symbol_name_map = {i.name: j.name for (_, i), j in zip(self.updates, train_updates)}
                self._train_updates = {dst.name: symbol_name_map[src.name] for dst, src in self.updates}

                # set for testing
                set_learning_phase(0)
                self._ntest = len(self.metrics_tensors) + 1
                state_updates = [x[1] for x in self.state_updates]
                test_keras_symbol = group(
                    [self.total_loss] +
                    [stop_gradient(x) for x in self.metrics_tensors] +
                    state_updates
                )
                bind_values.update(dfs_get_bind_values(test_keras_symbol))
                self._test_mxnet_symbol = test_keras_symbol.symbol

                # set for prediction
                self._npred = len(self.outputs)
                pred_keras_symbol = group(
                    self.outputs +
                    [symbol for symbol in state_updates if symbol not in self.outputs]
                )
                bind_values.update(dfs_get_bind_values(pred_keras_symbol))
                self._pred_mxnet_symbol = pred_keras_symbol.symbol
                self._test_updates = {dst.name: src.name for dst, src in self.state_updates}
            finally:
                # Never leak the temporarily-forced learning phase to the caller.
                set_learning_phase(old)

            # set the args and auxs
            inputs_name_set = set(self._data_names + self._label_names)
            self._arg_names = set([x for x in self._train_mxnet_symbol.list_arguments()
                                   if x not in inputs_name_set])
            self._aux_names = set(self._train_mxnet_symbol.list_auxiliary_states())

            trainable_weights = set([x.name for x in self.trainable_weights])
            # Arguments that are not trainable weights are kept fixed by the MXNet module.
            self._fixed_weights = [x for x in self._arg_names if x not in trainable_weights]
            self._args = {}
            for x in self._arg_names:
                if x in bind_values:
                    if is_sparse(x):
                        self._args[x] = mx.nd.array(bind_values[x]).tostype('row_sparse')
                    else:
                        self._args[x] = bind_values[x]
            self._auxs = {x: bind_values[x] for x in self._aux_names if x in bind_values}
            self._weights_dirty = False

            if _is_eia_context(self._context):
                # Only Prediction is Supported with EIA Context
                self._module = mx.mod.Module(self._pred_mxnet_symbol, data_names=self._data_names,
                                             label_names=self._label_names, context=self._context[0],
                                             fixed_param_names=self._fixed_weights)
            else:
                # set the module: one bucket per phase ('train', 'test', anything else -> pred)
                def sym_gen(phase):
                    if phase == 'train':
                        return self._train_mxnet_symbol, self._data_names, self._label_names
                    elif phase == 'test':
                        return self._test_mxnet_symbol, self._data_names, self._label_names
                    else:
                        return self._pred_mxnet_symbol, self._data_names, None

                self._module = mx.mod.BucketingModule(
                    sym_gen=sym_gen,
                    default_bucket_key='pred',
                    context=self._context,
                    fixed_param_names=self._fixed_weights)

            set_model(self)
            self.compiled = True

        def _adjust_module(self, inputs, phase):
            """Converts a batch to MXNet arrays and prepares the module for `phase`.

            `inputs` holds the data arrays, optionally followed by label and
            sample-weight arrays; one trailing extra element is dropped when the
            count is one more than expected (presumably the learning-phase value
            Keras appends -- TODO confirm). Binds the module on first use,
            switches to the `phase` bucket, and reshapes it when the batch size
            changed.

            # Returns
                Tuple `(data, label, phase, data_shapes, label_shapes)`; `label`
                and `label_shapes` are None when no labels were supplied.

            # Raises
                RuntimeError: if no module exists yet.
            """
            if not self._module:
                raise RuntimeError('MXNet Backend: You must compile your model before using it.')
            if self._num_data + self._num_label == len(inputs) - 1:
                inputs = inputs[:-1]
            elif self._num_data == len(inputs) - 1:
                inputs = inputs[:-1]
            assert self._num_data == len(inputs) or self._num_data + self._num_label == len(inputs)
            data = [mx.nd.array(x, dtype=s.dtype)
                    for (s, x) in zip(self.inputs, inputs[:self._num_data])]
            data_shapes = [mx.io.DataDesc(s.name, arr.shape, dtype=s.dtype)
                           for (s, arr) in zip(self.inputs, data)]
            if self._num_data < len(inputs):
                label = [mx.nd.array(x, dtype=s.dtype)
                         for (s, x) in zip(self.targets + self.sample_weights,
                                           inputs[self._num_data:])]
                label_shapes = [mx.io.DataDesc(s.name, arr.shape, dtype=s.dtype)
                                for (s, arr) in zip(self.targets + self.sample_weights, label)]
            else:
                label = None
                label_shapes = None

            if not self._module.binded:
                # allow prediction without compiling the model using different binding
                if phase == 'pred' and not self.compiled:
                    self._module.bind(data_shapes=data_shapes, label_shapes=None,
                                      for_training=False)
                    self._set_weights()
                else:
                    self._module.bind(data_shapes=data_shapes, label_shapes=None, for_training=True)
                    self._set_weights()
                    self._module.init_optimizer(kvstore=self._kvstore, optimizer=self.optimizer)

            # If context is EIA, we will be directly using Module rather than Bucketing Module.
            # Hence, below specialization.
            if isinstance(self._module, mx.mod.BucketingModule):
                self._module.switch_bucket(phase, data_shapes, label_shapes)

                # adjust module data shape
                if inputs[0].shape[0] != self._module._curr_module._exec_group.batch_size:
                    self._module._curr_module.reshape(data_shapes, label_shapes)
                    assert inputs[0].shape[0] == self._module._curr_module._exec_group.batch_size, \
                        'Reshape failed'
            else:
                # adjust module data shape
                if inputs[0].shape[0] != self._module._exec_group.batch_size:
                    self._module.reshape(data_shapes, label_shapes)
                    assert inputs[0].shape[0] == self._module._exec_group.batch_size, \
                        'Reshape failed'

            return data, label, phase, data_shapes, label_shapes

        def _sync_weights(self):
            """Copies the module's current parameters back into `_args` / `_auxs`.

            Does nothing unless `_weights_dirty` is set. Existing arrays are
            updated in place; parameters missing from the local dicts are added
            as copies of the module's arrays.
            """
            if self._weights_dirty:
                args, auxs = self._module.get_params()
                _copy_params(self._args, args, self._arg_names)
                _copy_params(self._auxs, auxs, self._aux_names)
                self._weights_dirty = False

        def _set_weights(self, arg_params=None, auxs_params=None):
            """Pushes parameters to the module, or stores them locally if unbound.

            When the module is bound, `arg_params` / `auxs_params` (falling back
            to `_args` / `_auxs`) are passed to `set_params(..., allow_missing=True)`,
            and the local copies are marked dirty if explicit params were given.
            Otherwise the given params are copied in place into `_args` / `_auxs`.
            """
            if self._module.binded:
                self._module.set_params(self._args if arg_params is None else arg_params,
                                        self._auxs if auxs_params is None else auxs_params,
                                        allow_missing=True)
                self._weights_dirty = arg_params is not None or auxs_params is not None
            else:
                if arg_params:
                    for k in arg_params:
                        self._args[k][:] = arg_params[k]
                if auxs_params:
                    for k in auxs_params:
                        self._auxs[k][:] = auxs_params[k]
                self._weights_dirty = False

        def _update(self, updates):
            """Applies state updates on every executor of the current bucket.

            For each `dst -> src` pair, the output named `src + '_output'` is
            written in place into the argument array `dst`.
            """
            for exe in self._module._curr_module._exec_group.execs:
                outs = exe.output_dict
                args = exe.arg_dict
                for dst, src in updates.items():
                    args[dst][:] = outs[src + '_output']

        def _make_train_function(self):
            """Installs `self.train_function`, which runs one training step.

            # Raises
                RuntimeError: when the model uses an 'eia' context (training unsupported).
            """
            # If context is EIA this should not be supported
            if _is_eia_context(self._context):
                raise RuntimeError('MXNet Backend: Model training is not supported with MXNet EIA context. '
                                   'Use CPU/GPU.')

            def train_function(inputs):
                self._check_trainable_weights_consistency()
                data, label, _, data_shapes, label_shapes = self._adjust_module(inputs, 'train')

                batch = mx.io.DataBatch(data=data, label=label, bucket_key='train',
                                        provide_data=data_shapes, provide_label=label_shapes)

                self._module.forward_backward(batch)
                self._module.update()
                self._update(self._train_updates)
                self._weights_dirty = True
                # Loss followed by metrics, each reduced to its mean.
                outs = self._module.get_outputs()[:self._ntrain]
                return [x.asnumpy().mean() for x in outs]

            self.train_function = train_function

        def _make_test_function(self):
            """Installs `self.test_function`, which evaluates loss and metrics on a batch.

            # Raises
                RuntimeError: when the model uses an 'eia' context (testing unsupported).
            """
            # If context is EIA this should not be supported
            if _is_eia_context(self._context):
                raise RuntimeError(
                    'MXNet Backend: Model Testing is not supported with MXNet EIA context. Use CPU/GPU.')

            def test_function(inputs):
                data, label, _, data_shapes, label_shapes = self._adjust_module(inputs, 'test')

                batch = mx.io.DataBatch(data=data, label=label, bucket_key='test',
                                        provide_data=data_shapes, provide_label=label_shapes)
                self._module.forward(batch, is_train=False)
                if self._test_updates:
                    self._update(self._test_updates)
                    self._weights_dirty = True
                # Loss followed by metrics of the *test* symbol.
                outs = self._module.get_outputs()[:self._ntest]
                return [x.asnumpy().mean() for x in outs]

            self.test_function = test_function

        def _make_predict_function(self):
            """Installs `self.predict_function`, which returns model outputs for a batch.

            If the model was never compiled, a prediction-only module is
            (re)created and used instead.
            """
            def predict_function(inputs):
                # used predict only module if predict is called without compile
                if not self.compiled:
                    if self.built:
                        self._num_data = len(self.inputs)
                        self._num_label = len(self.outputs) + len(self.output_names)
                        # Create Module for Inference
                        self._create_predict_module()
                    self._module = self._predict_only_module
                    set_model(self)

                data, label, _, data_shapes, label_shapes = self._adjust_module(inputs, 'pred')
                batch = mx.io.DataBatch(data=data, label=label, bucket_key='pred',
                                        provide_data=data_shapes, provide_label=label_shapes)
                self._module.forward(batch, is_train=False)
                if self._test_updates:
                    self._update(self._test_updates)
                    self._weights_dirty = True
                outs = self._module.get_outputs()[:self._npred]
                return [x.asnumpy() for x in outs]

            self.predict_function = predict_function

        def _create_predict_module(self):
            """Builds the prediction symbol and a prediction-only module.

            Sets `_data_names`, `_npred`, `_pred_mxnet_symbol`, `_arg_names`,
            `_aux_names`, `_fixed_weights`, `_args`, `_auxs` and stores the
            module in `_predict_only_module` (a plain `Module` for an 'eia'
            context, otherwise a single-bucket `BucketingModule`).
            """
            # set the data and label
            self._data_names = [x.name for x in self.inputs if x]

            state_updates = [x[1] for x in self.state_updates]
            # set for prediction
            self._npred = len(self.outputs)
            pred_keras_symbol = group(
                self.outputs +
                [symbol for symbol in state_updates if symbol not in self.outputs]
            )
            bind_values = dfs_get_bind_values(pred_keras_symbol)
            self._pred_mxnet_symbol = pred_keras_symbol.symbol

            # set the args and auxs
            inputs_name_set = set(self._data_names)
            self._arg_names = set([x for x in self._pred_mxnet_symbol.list_arguments()
                                   if x not in inputs_name_set])
            self._aux_names = set(self._pred_mxnet_symbol.list_auxiliary_states())

            trainable_weights = set([x.name for x in self.trainable_weights])
            self._fixed_weights = [x for x in self._arg_names if x not in trainable_weights]
            self._args = {x: bind_values[x] for x in self._arg_names if x in bind_values}
            self._auxs = {x: bind_values[x] for x in self._aux_names if x in bind_values}
            self._weights_dirty = False

            # set module for prediction only
            if _is_eia_context(self._context):
                # Only Prediction is Supported with EI Context
                self._predict_only_module = mx.mod.Module(self._pred_mxnet_symbol, data_names=self._data_names,
                                                          label_names=self._label_names, context=self._context[0],
                                                          fixed_param_names=self._fixed_weights)
            else:
                def sym_gen(phase):
                    return self._pred_mxnet_symbol, self._data_names, None

                # separate module for using predict without compiling model
                self._predict_only_module = mx.mod.BucketingModule(
                    sym_gen=sym_gen,
                    default_bucket_key='pred',
                    context=self._context,
                    fixed_param_names=self._fixed_weights)

        def set_mxnet_context(self, context):
            """Sets the mxnet context for the current Model.

            # Arguments
                context: Integer >= 2 or list/tuple of at least two integers
                      (number of GPUs or list of GPU IDs on which to create model
                      replicas), or a string starting with 'eia', 'cpu' or 'gpu'
                      (case-insensitive).

            # Raises
                ValueError: for a list/tuple with fewer than two entries, an
                    integer <= 1, or a string with an unrecognised prefix.
            """
            if isinstance(context, (list, tuple)):
                if len(context) <= 1:
                    raise ValueError('MXNet Backend: For multi-gpu usage to be effective, '
                                     'call `multi_gpu_model` with `len(gpus) >= 2`. '
                                     'Received: `gpus=%s`' % context)
            elif isinstance(context, str):
                if not context.lower().startswith(('eia', 'cpu', 'gpu')):
                    raise ValueError('MXNet Backend: Invalid context provided - %s' % context)
            else:
                if context <= 1:
                    raise ValueError('MXNet Backend: For multi-gpu usage to be effective, '
                                     'call `multi_gpu_model` with `gpus >= 2`. '
                                     'Received: `gpus=%d`' % context)

            self._context = _get_mxnet_context(context)

    return Model