in python/mxnet/model.py [0:0]
def predict(self, X, num_batch=None, return_data=False, reset=True):
    """Run inference over ``X`` and gather the results as numpy arrays.

    Prediction always uses a single device.

    Parameters
    ----------
    X : mxnet.DataIter
        Input data. It is first passed through ``self._init_iter`` with
        ``is_train=False``; the object returned by that call is what gets
        iterated.
    num_batch : int or None
        Stop after this many batches. With ``None`` every batch is
        processed. The batch counter starts at 1 and is compared for
        equality after each batch, so a value it never reaches (such as 0)
        does not stop the iteration.
    return_data : bool
        When true, the data and label arrays of every processed batch are
        collected and returned next to the predictions.
    reset : bool
        When true, ``X.reset()`` is called before iterating.

    Returns
    -------
    y : numpy.ndarray or a list of numpy.ndarray if the network has multiple outputs.
        The predicted values, concatenated over all processed batches.
    data, label : numpy.ndarray or list of numpy.ndarray
        Only returned (as ``(y, data, label)``) when ``return_data`` is
        true. Each follows the same rule as ``y``: a bare array when there
        is exactly one entry, otherwise a list.
    """

    def _merge(chunks_per_slot):
        # Join the per-batch pieces of every slot; unwrap a lone slot.
        merged = [np.concatenate(chunks) for chunks in chunks_per_slot]
        return merged[0] if len(merged) == 1 else merged

    X = self._init_iter(X, None, is_train=False)
    if reset:
        X.reset()

    data_shapes = X.provide_data
    data_names = [desc[0] for desc in data_shapes]

    # Start from the parameter dtypes, then add one entry per data input.
    type_dict = {name: param.dtype for name, param in self.arg_params.items()}
    for desc in X.provide_data:
        if isinstance(desc, DataDesc):
            type_dict[desc.name] = desc.dtype
        else:
            # Plain (name, shape) entries carry no dtype; use the default.
            type_dict[desc[0]] = mx_real_t

    self._init_predictor(data_shapes, type_dict)
    executor = self._pred_exec
    batch_size = X.batch_size
    input_arrays = [executor.arg_dict[name] for name in data_names]

    # One accumulator list per network output / data input / label input.
    output_chunks = [[] for _ in executor.outputs]
    if return_data:
        data_chunks = [[] for _ in X.provide_data]
        label_chunks = [[] for _ in X.provide_label]

    for batch_count, batch in enumerate(X, start=1):
        _load_data(batch, input_arrays)
        executor.forward(is_train=False)

        # Keep only the first ``batch_size - pad`` rows of every array;
        # presumably the remaining rows are padding added by the iterator.
        real_size = batch_size - batch.pad

        for chunks, out_nd in zip(output_chunks, executor.outputs):
            chunks.append(out_nd[0:real_size].asnumpy())

        if return_data:
            for slot, arr in enumerate(batch.data):
                data_chunks[slot].append(arr[0:real_size].asnumpy())
            for slot, arr in enumerate(batch.label):
                label_chunks[slot].append(arr[0:real_size].asnumpy())

        if num_batch is not None and batch_count == num_batch:
            break

    outputs = _merge(output_chunks)
    if not return_data:
        return outputs
    return outputs, _merge(data_chunks), _merge(label_chunks)