in sockeye/train.py
def train(args: argparse.Namespace, custom_metrics_logger: Optional[Callable] = None,
          checkpoint_callback: Optional[Callable] = None) -> training.TrainState:
    """
    :param custom_metrics_logger: Optional custom metrics logging function. If supplied, takes care of metrics produced
                                  during training in a custom way. It should accept a list or a dictionary of
                                  (metric name, metric value) pairs, and an optional global_step/checkpoint parameter.
    :param checkpoint_callback: An optional callback function (int -> None). The function will be called
                                each time a checkpoint has been reached.
    """
    if args.dry_run:
        # Modify arguments so that we write to a temporary directory and
        # perform 0 training iterations
        temp_dir = tempfile.TemporaryDirectory()  # Will be automatically removed
        args.output = temp_dir.name
        args.max_updates = 0

    # Automatic Mixed Precision training
    using_amp = False
    if args.amp:
        using_amp = True
        amp.init()

    # When using Horovod, multiple workers (instances of sockeye.train) are
    # launched via MPI. Each worker has a rank (unique among all workers in the
    # training run) and a local rank (unique on the current host). For example,
    # running on 2 hosts with 4 slots each will assign ranks 0-7 and local ranks
    # 0-3.
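    # For instance (a sketch, assuming horovod.mxnet is installed), each worker
    # can query its placement after hvd.init():
    #
    #     import horovod.mxnet as hvd
    #     hvd.init()
    #     hvd.size()        # 8 in the example above
    #     hvd.rank()        # e.g. 5 (second host, second slot)
    #     hvd.local_rank()  # e.g. 1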
    console_level = None
    if args.horovod:
        if horovod_mpi.hvd is None or horovod_mpi.MPI is None:
            raise RuntimeError('Horovod training requires the following packages to be installed: horovod mpi4py')
        # Unless explicitly set otherwise, use NCCL for same-host
        # allreduce/allgather and MPI for cross-host allreduce/allgather.
        if C.HOROVOD_HIERARCHICAL_ALLREDUCE not in os.environ:
            os.environ[C.HOROVOD_HIERARCHICAL_ALLREDUCE] = '1'
        if C.HOROVOD_HIERARCHICAL_ALLGATHER not in os.environ:
            os.environ[C.HOROVOD_HIERARCHICAL_ALLGATHER] = '1'
        horovod_mpi.hvd.init()
        # Each worker uses a separate output directory. The primary worker
        # (rank 0) writes files to the root of the output directory (standard
        # behavior). Secondary workers write files to rank-named
        # sub-directories.
        if horovod_mpi.hvd.rank() > 0:
            args.output = os.path.join(args.output, C.HOROVOD_SECONDARY_WORKERS_DIRNAME, str(horovod_mpi.hvd.rank()))
            # Do not keep redundant copies of the checkpoint history
            args.keep_last_params = 1
            # If requested, suppress console output for secondary workers
            if args.quiet_secondary_workers:
                args.quiet = True
            console_level = args.loglevel_secondary_workers
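
    # Illustrative layout for a 2-worker run (the sub-directory name shown is a
    # placeholder; the actual name comes from C.HOROVOD_SECONDARY_WORKERS_DIRNAME):
    #
    #     <output>/                         <- rank 0 (primary worker)
    #     <output>/<secondary_dirname>/1/   <- rank 1
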
    check_arg_compatibility(args)
    output_folder = os.path.abspath(args.output)
    resume_training = check_resume(args, output_folder)

    setup_main_logger(file_logging=not args.no_logfile,
                      console=not args.quiet,
                      path=os.path.join(output_folder, C.LOG_NAME),
                      level=args.loglevel,
                      console_level=console_level)
    utils.log_basic_info(args)
    arguments.save_args(args, os.path.join(output_folder, C.ARGS_STATE_NAME))

    max_seq_len_source, max_seq_len_target = args.max_seq_len
    # The maximum length given by the user is the length before we add the BOS/EOS symbols
    max_seq_len_source = max_seq_len_source + C.SPACE_FOR_XOS
    max_seq_len_target = max_seq_len_target + C.SPACE_FOR_XOS
    logger.info("Adjusting maximum length to reserve space for a BOS/EOS marker. New maximum length: (%d, %d)",
                max_seq_len_source, max_seq_len_target)
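
    # Worked example, assuming C.SPACE_FOR_XOS == 1: --max-seq-len 95:99 becomes
    # internal maximums of (96, 100), so a sentence at the user-specified limit
    # still fits once the BOS/EOS marker is added.
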
    with ExitStack() as exit_stack:
        context = utils.determine_context(device_ids=args.device_ids,
                                          use_cpu=args.use_cpu,
                                          disable_device_locking=args.disable_device_locking,
                                          lock_dir=args.lock_dir,
                                          exit_stack=exit_stack)
        if args.batch_type == C.BATCH_TYPE_SENTENCE:
            check_condition(args.batch_size % len(context) == 0, "When using multiple devices the batch size must be "
                                                                 "divisible by the number of devices. Choose a batch "
                                                                 "size that is a multiple of %d." % len(context))
logger.info("Training Device(s): %s", ", ".join(str(c) for c in context))
utils.seed_rngs(args.seed, ctx=context)
train_iter, eval_iter, config_data, source_vocabs, target_vocabs = create_data_iters_and_vocabs(
args=args,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
shared_vocab=use_shared_vocab(args),
resume_training=resume_training,
output_folder=output_folder)
if max_seq_len_source != config_data.max_seq_len_source:
logger.info("Maximum source length determined by prepared data. Using %d instead of %d",
config_data.max_seq_len_source, max_seq_len_source)
max_seq_len_source = config_data.max_seq_len_source
if max_seq_len_target != config_data.max_seq_len_target:
logger.info("Maximum target length determined by prepared data. Using %d instead of %d",
config_data.max_seq_len_target, max_seq_len_target)
max_seq_len_target = config_data.max_seq_len_target
# Dump the vocabularies if we're just starting up
if not resume_training:
vocab.save_source_vocabs(source_vocabs, output_folder)
vocab.save_target_vocabs(target_vocabs, output_folder)
source_vocab_sizes = [len(v) for v in source_vocabs]
target_vocab_sizes = [len(v) for v in target_vocabs]
logger.info('Vocabulary sizes: source=[%s] target=[%s]',
'|'.join([str(size) for size in source_vocab_sizes]),
'|'.join([str(size) for size in target_vocab_sizes]))
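        # With factored inputs there is one vocabulary per factor, so this log
        # line might read, e.g.: Vocabulary sizes: source=[32302|4] target=[32302]
        # (sizes shown here are illustrative).
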
        model_config = create_model_config(args=args,
                                           source_vocab_sizes=source_vocab_sizes,
                                           target_vocab_sizes=target_vocab_sizes,
                                           max_seq_len_source=max_seq_len_source,
                                           max_seq_len_target=max_seq_len_target,
                                           config_data=config_data)

        training_model = model.SockeyeModel(
            model_config,
            train_decoder_only=args.fixed_param_strategy == C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_DECODER)

        # Handle options that override training settings
        trainer_config = training.TrainerConfig(
            output_dir=args.output,
            early_stopping_metric=args.optimized_metric,
            max_params_files_to_keep=args.keep_last_params,
            keep_initializations=args.keep_initializations,
            max_params_files_to_cache=args.cache_last_best_params,
            cache_strategy=args.cache_strategy,
            cache_metric=args.cache_metric,
            checkpoint_interval=args.checkpoint_interval,
            max_num_checkpoint_not_improved=args.max_num_checkpoint_not_improved,
            checkpoint_improvement_threshold=args.checkpoint_improvement_threshold,
            max_checkpoints=args.max_checkpoints,
            min_samples=args.min_samples,
            max_samples=args.max_samples,
            min_updates=args.min_updates,
            max_updates=args.max_updates,
            min_epochs=args.min_num_epochs,
            max_epochs=args.max_num_epochs,
            max_seconds=args.max_seconds,
            update_interval=args.update_interval,
            stop_training_on_decoder_failure=args.stop_training_on_decoder_failure
        )
        if trainer_config.min_epochs is not None and trainer_config.max_epochs is not None:
            check_condition(trainer_config.min_epochs <= trainer_config.max_epochs,
                            "Minimum number of epochs must be smaller than or equal to the maximum number of epochs")

        optimizer_config = create_optimizer_config(args)

        training_model.initialize(optimizer_config.initializer, ctx=context)
        # training_model.save_parameters(os.path.join(args.output, 'params.init'))

        if args.params is not None:  # load existing parameters if present
            training_model.load_parameters(filename=args.params,
                                           ctx=context,
                                           allow_missing=args.allow_missing_params or model_config.lhuc,
                                           ignore_extra=args.ignore_extra_params,
                                           cast_dtype=True,
                                           dtype_source='current')

        params = training_model.collect_params()
        # Set grad_req to 'null' for parameters that should stay fixed
        params = set_grad_req_for_fixed_params(config=model_config,
                                               params=params,
                                               fixed_param_names=args.fixed_param_names,
                                               fixed_param_strategy=args.fixed_param_strategy)
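        # In Gluon, a parameter with grad_req='null' receives no gradient and is
        # therefore frozen. A sketch of the idea (illustrative, not Sockeye's
        # actual helper):
        #
        #     for name, param in params.items():
        #         if name in fixed_param_names:
        #             param.grad_req = 'null'
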
        # When using Horovod, synchronize the parameter initialization point
        # across all workers by broadcasting worker 0's values. This is not
        # required when resuming training, as synchronized training states
        # already exist.
        if horovod_mpi.using_horovod() and not resume_training:
            for ctx in context:
                with mx.Context(ctx):
                    horovod_mpi.hvd.broadcast_parameters(params, root_rank=0)

        if args.dtype == C.DTYPE_FP16:
            training_model.cast(C.DTYPE_FP16)

        utils.log_parameters(params)

        # Set grad_req to 'add' for trainable parameters so that gradients
        # accumulate across batches when --update-interval > 1
        if args.update_interval > 1:
            for name, param in params.items():
                if param.grad_req != 'null':
                    param.grad_req = 'add'
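        # With grad_req='add', successive backward passes sum into param.grad()
        # instead of overwriting it, which is what enables gradient accumulation.
        # The general Gluon pattern looks like this (a sketch; names such as
        # batches, loss_fn, net, and effective_batch_size are illustrative):
        #
        #     for i, batch in enumerate(batches):
        #         with mx.autograd.record():
        #             loss = loss_fn(net(batch))
        #         loss.backward()                    # gradients accumulate
        #         if (i + 1) % update_interval == 0:
        #             gluon_trainer.step(effective_batch_size)
        #             for param in params.values():  # grads must be zeroed by hand
        #                 if param.grad_req != 'null':
        #                     param.zero_grad()
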
        kvstore = mx.kvstore.create(args.kvstore)

        if horovod_mpi.using_horovod():
            # Horovod provides a trainer that subclasses gluon.Trainer and uses
            # allreduce to collect averaged gradients across all workers for
            # each update.
            gluon_trainer = horovod_mpi.hvd.DistributedTrainer(params,
                                                               optimizer_config.name,
                                                               optimizer_config.params)
        else:
            gluon_trainer = gluon.Trainer(params,
                                          optimizer_config.name,
                                          optimizer_config.params,
                                          kvstore=kvstore,
                                          update_on_kvstore=False if using_amp else None)

        if using_amp:
            amp.init_trainer(gluon_trainer)
            # AMP does not allow passing args when creating the loss scaler, so
            # we set them immediately after calling init.
            gluon_trainer._amp_loss_scaler._scale_seq_len = args.amp_scale_interval  # pylint: disable=no-member
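            # Dynamic loss scaling multiplies the loss by a scale factor so that
            # small fp16 gradients do not underflow; roughly, the scaler shrinks
            # the factor on overflow and grows it again after a run of stable
            # updates whose length is the _scale_seq_len set above.
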
        losses = create_losses(args, all_num_classes=target_vocab_sizes)

        hybridize = not args.no_hybridization
        if hybridize:
            training_model.hybridize(static_alloc=True)
            if not using_amp:
                # Do not hybridize losses when using AMP. Dynamic loss scaling
                # requires adjusting SoftmaxOutput's grad_rescale value
                # throughout training, which is not possible when using the
                # Symbol API.
                for lf in losses:
                    lf.hybridize(static_alloc=True)
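        # (hybridize() traces the imperative Gluon graph into a symbolic one for
        # speed; static_alloc=True additionally pre-allocates and reuses memory
        # across calls.)
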
        trainer = training.GluonEarlyStoppingTrainer(
            config=trainer_config,
            optimizer_config=optimizer_config,
            sockeye_model=training_model,
            trainer=gluon_trainer,
            loss_functions=losses,
            context=context,
            dtype=args.dtype,
            using_amp=using_amp,
            custom_metrics_logger=custom_metrics_logger,
            checkpoint_callback=checkpoint_callback
        )

        cp_decoder = create_checkpoint_decoder(args, exit_stack, context,
                                               training_model, source_vocabs, target_vocabs, hybridize=hybridize)

        training_state = trainer.fit(train_iter=train_iter, validation_iter=eval_iter, checkpoint_decoder=cp_decoder)
        return training_state
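
# Usage sketch (hypothetical; CLI flags elided): the real entry point builds
# `args` via sockeye.arguments before calling train(), along the lines of:
#
#     parser = argparse.ArgumentParser()
#     arguments.add_train_cli_args(parser)
#     args = parser.parse_args()  # e.g. invoked as: python -m sockeye.train ...
#     state = train(args, custom_metrics_logger=my_metrics_logger,
#                   checkpoint_callback=my_checkpoint_callback)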