def predict()

in python/mxnet/model.py [0:0]


    def predict(self, X, num_batch=None, return_data=False, reset=True):
        """Run the prediction, always only use one device.

        Parameters
        ----------
        X : mxnet.DataIter
            The data to predict on. It is first passed through
            ``self._init_iter(X, None, is_train=False)``.
        num_batch : int or None
            The number of batches to run. Go through all batches if ``None``.
        return_data : bool
            If ``True``, also return the input data and labels that were fed to
            the network, trimmed of padding in the same way as the outputs.
        reset : bool
            If ``True``, call ``X.reset()`` before running the prediction.

        Returns
        -------
        y : numpy.ndarray or a list of numpy.ndarray if the network has multiple outputs.
            The predicted value of the output.
        data, label : numpy.ndarray or list of numpy.ndarray
            Only when ``return_data`` is ``True``, in which case the tuple
            ``(y, data, label)`` is returned. Each is a single array when there
            is exactly one data / label entry, otherwise a list.

        Raises
        ------
        ValueError
            If the iterator yields no batches, so there is nothing to return.
        """
        X = self._init_iter(X, None, is_train=False)

        if reset:
            X.reset()
        data_shapes = X.provide_data
        data_names = [x[0] for x in data_shapes]
        # Parameters keep their own dtype; each input takes the dtype declared by
        # the iterator (DataDesc), or the default real type for plain tuples.
        type_dict = {key: value.dtype for key, value in self.arg_params.items()}
        for x in data_shapes:
            if isinstance(x, DataDesc):
                type_dict[x.name] = x.dtype
            else:
                type_dict[x[0]] = mx_real_t

        self._init_predictor(data_shapes, type_dict)
        batch_size = X.batch_size
        data_arrays = [self._pred_exec.arg_dict[name] for name in data_names]
        output_list = [[] for _ in self._pred_exec.outputs]
        if return_data:
            data_list = [[] for _ in data_shapes]
            label_list = [[] for _ in X.provide_label]

        for i, batch in enumerate(X, 1):
            _load_data(batch, data_arrays)
            self._pred_exec.forward(is_train=False)
            # The trailing ``pad`` rows of a batch are filler; drop them.
            # ``pad`` may be None for batches built without padding info.
            real_size = batch_size - (batch.pad or 0)

            for o_list, o_nd in zip(output_list, self._pred_exec.outputs):
                o_list.append(o_nd[0:real_size].asnumpy())

            if return_data:
                for j, x in enumerate(batch.data):
                    data_list[j].append(x[0:real_size].asnumpy())
                # Unlabeled batches may carry ``label=None``.
                for j, x in enumerate(batch.label or []):
                    label_list[j].append(x[0:real_size].asnumpy())
            if num_batch is not None and i == num_batch:
                break

        # Every output list grows once per batch, so an empty first list means
        # the iterator produced nothing; fail with a clear message instead of
        # numpy's generic "need at least one array to concatenate".
        if output_list and not output_list[0]:
            raise ValueError("predict: the data iterator produced no batches")

        outputs = [np.concatenate(x) for x in output_list]
        if len(outputs) == 1:
            outputs = outputs[0]

        if return_data:
            data = [np.concatenate(x) for x in data_list]
            label = [np.concatenate(x) for x in label_list]
            if len(data) == 1:
                data = data[0]
            if len(label) == 1:
                label = label[0]
            return outputs, data, label
        else:
            return outputs