in python/mxnet/model.py [0:0]
def _train_multi_device(symbol, ctx, arg_names, param_names, aux_names,
                        arg_params, aux_params,
                        begin_epoch, end_epoch, epoch_size, optimizer,
                        kvstore, update_on_kvstore,
                        train_data, eval_data=None, eval_metric=None,
                        epoch_end_callback=None, batch_end_callback=None,
                        logger=None, work_load_list=None, monitor=None,
                        eval_end_callback=None,
                        eval_batch_end_callback=None, sym_gen=None):
    """Internal training function on multiple devices.

    This function will also work for a single device.

    Parameters
    ----------
    symbol : Symbol
        The network configuration.
    ctx : list of Context
        The training devices.
    arg_names : list of str
        Name of all arguments of the network.
    param_names : list of str
        Name of all trainable parameters of the network.
    aux_names : list of str
        Name of all auxiliary states of the network.
    arg_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's weights.
    aux_params : dict of str to NDArray
        Model parameter, dict of name to NDArray of net's auxiliary states.
    begin_epoch : int
        The beginning training epoch (inclusive).
    end_epoch : int
        The end training epoch (exclusive).
    epoch_size : int or None
        Number of batches in an epoch. If None, an epoch is exactly one full
        pass over ``train_data``. Otherwise the iterator is reset and re-read
        as often as needed until ``epoch_size`` batches have been processed,
        and it is *not* reset when an epoch ends in the middle of a pass.
    optimizer : Optimizer
        The optimization algorithm.
    kvstore : KVStore
        The KVStore. May be falsy, in which case no kvstore is initialized;
        ``update_on_kvstore`` must then be False.
    update_on_kvstore : bool
        Whether or not to perform weight updating on the kvstore.
    train_data : DataIter
        Training data iterator.
    eval_data : DataIter, optional
        Validation data iterator. Evaluation is skipped when falsy.
    eval_metric : EvalMetric
        The evaluation metric. Although the default is None, it is used
        unconditionally (``reset`` is called every epoch), so callers must
        supply one.
    epoch_end_callback : callable(epoch, symbol, arg_params, aux_params)
        A callback (or callbacks, as accepted by ``_multiple_callbacks``)
        invoked at the end of each epoch.
        This can be used to checkpoint the model each epoch.
    batch_end_callback : callable(BatchEndParam)
        A callback that is invoked at the end of each training batch.
        This can be used to measure speed, get result from evaluation
        metric, etc.
    logger : logging logger
        When not specified, the ``logging`` module itself is used.
    work_load_list : list of float or int, optional
        The list of work load for different devices,
        in the same order as ``ctx``.
    monitor : Monitor, optional
        Monitor installed to executor,
        for monitoring outputs, weights, and gradients for debugging.
    eval_end_callback : callable(BatchEndParam), optional
        A callback invoked once per epoch after the whole of ``eval_data``
        has been evaluated; ``nbatch`` is the number of evaluated batches.
    eval_batch_end_callback : callable(BatchEndParam), optional
        A callback invoked after each evaluation batch; ``nbatch`` is the
        zero-based batch index.
    sym_gen : optional
        Forwarded unchanged to ``DataParallelExecutorManager``.
        NOTE(review): presumably a symbol generator for variable-shaped
        inputs -- confirm against the executor manager.

    Raises
    ------
    ValueError
        If ``epoch_size`` is set but ``train_data`` yields no batch even
        immediately after a reset, so the requested size can never be reached.

    Notes
    -----
    - This function will inplace update the NDArrays in `arg_params` and
      `aux_params`. The copy back from the devices only happens at the end
      of an epoch when `epoch_end_callback` is set, or on the final epoch.
    - ``locals()`` is passed to the batch/eval callbacks, so local variable
      names in this function are effectively part of the callback contract.
    """
    if logger is None:
        logger = logging
    executor_manager = DataParallelExecutorManager(symbol=symbol,
                                                   sym_gen=sym_gen,
                                                   ctx=ctx,
                                                   train_data=train_data,
                                                   param_names=param_names,
                                                   arg_names=arg_names,
                                                   aux_names=aux_names,
                                                   work_load_list=work_load_list,
                                                   logger=logger)
    # Use the same None-check as the tic()/toc_print() calls below so that a
    # monitor is never ticked without having been installed.
    if monitor is not None:
        executor_manager.install_monitor(monitor)

    executor_manager.set_params(arg_params, aux_params)

    # A local updater is only needed when the kvstore does not apply updates.
    if not update_on_kvstore:
        updater = get_updater(optimizer)

    if kvstore:
        _initialize_kvstore(kvstore=kvstore,
                            param_arrays=executor_manager.param_arrays,
                            arg_params=arg_params,
                            param_names=executor_manager.param_names,
                            update_on_kvstore=update_on_kvstore)

    if update_on_kvstore:
        kvstore.set_optimizer(optimizer)

    # Now start training
    train_data.reset()
    # True while no batch has been consumed since the most recent reset.
    # Used to detect an iterator that is empty even when freshly reset.
    fresh_after_reset = True
    for epoch in range(begin_epoch, end_epoch):
        # Training phase
        tic = time.time()
        eval_metric.reset()
        nbatch = 0
        # Iterate over training data; may take several passes (or a partial
        # one) when epoch_size is given.
        while True:
            do_reset = True
            for data_batch in train_data:
                fresh_after_reset = False
                executor_manager.load_data_batch(data_batch)

                if monitor is not None:
                    monitor.tic()

                executor_manager.forward(is_train=True)
                executor_manager.backward()

                if update_on_kvstore:
                    _update_params_on_kvstore(executor_manager.param_arrays,
                                              executor_manager.grad_arrays,
                                              kvstore, executor_manager.param_names)
                else:
                    _update_params(executor_manager.param_arrays,
                                   executor_manager.grad_arrays,
                                   updater=updater,
                                   num_device=len(ctx),
                                   kvstore=kvstore,
                                   param_names=executor_manager.param_names)

                if monitor is not None:
                    monitor.toc_print()

                # evaluate at end, so we can lazy copy
                executor_manager.update_metric(eval_metric, data_batch.label)

                nbatch += 1
                # batch callback (for print purpose)
                if batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=nbatch,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    _multiple_callbacks(batch_end_callback, batch_end_params)

                # this epoch is done possibly earlier; keep the iterator
                # position so the next epoch continues where this one stopped
                if epoch_size is not None and nbatch >= epoch_size:
                    do_reset = False
                    break

            if do_reset:
                # The pass produced nothing although the iterator had just
                # been reset: more passes cannot help, and without this guard
                # the surrounding loop would spin forever.
                if (fresh_after_reset and epoch_size is not None
                        and nbatch < epoch_size):
                    raise ValueError(
                        'train_data yielded no batches after reset; '
                        'cannot reach epoch_size=%d' % epoch_size)
                logger.info('Epoch[%d] Resetting Data Iterator', epoch)
                train_data.reset()
                fresh_after_reset = True

            # this epoch is done
            if epoch_size is None or nbatch >= epoch_size:
                break

        toc = time.time()
        logger.info('Epoch[%d] Time cost=%.3f', epoch, (toc - tic))

        # Copying parameters back from the devices is only needed when
        # someone will look at them: a callback, or the caller after the
        # final epoch.
        if epoch_end_callback or epoch + 1 == end_epoch:
            executor_manager.copy_to(arg_params, aux_params)

        _multiple_callbacks(epoch_end_callback, epoch, symbol, arg_params, aux_params)

        # evaluation
        if eval_data:
            eval_metric.reset()
            eval_data.reset()
            total_num_batch = 0
            for i, eval_batch in enumerate(eval_data):
                executor_manager.load_data_batch(eval_batch)
                executor_manager.forward(is_train=False)
                executor_manager.update_metric(eval_metric, eval_batch.label)
                if eval_batch_end_callback is not None:
                    batch_end_params = BatchEndParam(epoch=epoch,
                                                     nbatch=i,
                                                     eval_metric=eval_metric,
                                                     locals=locals())
                    _multiple_callbacks(eval_batch_end_callback, batch_end_params)
                total_num_batch += 1
            if eval_end_callback is not None:
                eval_end_params = BatchEndParam(epoch=epoch,
                                                nbatch=total_num_batch,
                                                eval_metric=eval_metric,
                                                locals=locals())
                _multiple_callbacks(eval_end_callback, eval_end_params)
            eval_data.reset()
    # end of all epochs