in sockeye/train_pt.py
def train(args: argparse.Namespace, custom_metrics_logger: Optional[Callable] = None,
checkpoint_callback: Optional[Callable] = None) -> training_pt.TrainState:
"""
:param args: Parsed command-line arguments (an argparse.Namespace).
:param custom_metrics_logger: Optional custom metrics logging function. If supplied, takes care of metrics produced
during training in a custom way. It should accept a list or a dictionary of
(metric name, metric value) pairs, and an optional global_step/checkpoint parameter.
:param checkpoint_callback: An optional callback function (int -> None). The function will be called
                            each time a checkpoint has been reached.
:return: The training state at the end of training.
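A minimal sketch of a compatible metrics logger, assuming the call signature described
above (the function name is hypothetical)::

    def log_metrics(metrics, global_step=None):
        # Accept either a dict or a list of (metric name, metric value) pairs.
        pairs = metrics.items() if isinstance(metrics, dict) else metrics
        for name, value in pairs:
            print(f'{name} = {value} (step {global_step})')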
"""
if args.dist:
torch.distributed.init_process_group(torch.distributed.Backend.GLOO if args.use_cpu
else torch.distributed.Backend.NCCL)
if args.dry_run:
# Modify arguments so that we write to a temporary directory and
# perform 0 training iterations
temp_dir = tempfile.TemporaryDirectory() # Will be automatically removed
args.output = temp_dir.name
args.max_updates = 0
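# Validate the argument combination, resolve the output folder, and check whether an
# existing run in that folder should be resumed.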
check_arg_compatibility(args)
output_folder = os.path.abspath(args.output)
resume_training = check_resume(args, output_folder)
# In distributed mode, multiple workers (instances of sockeye.train) are
# launched via torchrun. Each worker has a unique rank. Worker 0 is the
# primary worker that writes files and makes authoritative training
# decisions (ex: whether a checkpoint improves). Workers 1+ are secondary
# workers that run parallel training steps and send gradients to the primary
# worker (but don't output anything other than log files).
logfile = os.path.join(output_folder, C.LOG_NAME)
console_level = None
if not utils.is_primary_worker():
logfile = os.path.join(output_folder, C.DIST_SECONDARY_WORKERS_LOGDIR,
f'{torch.distributed.get_rank()}.{C.LOG_NAME}')
# If requested, suppress console output for secondary workers
if args.quiet_secondary_workers:
args.quiet = True
console_level = args.loglevel_secondary_workers
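# Configure the main logger with the file path and console settings chosen above.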
setup_main_logger(file_logging=not args.no_logfile,
console=not args.quiet,
path=logfile,
level=args.loglevel,
console_level=console_level)
utils.log_basic_info(args)
if utils.is_primary_worker():
arguments.save_args(args, os.path.join(output_folder, C.ARGS_STATE_NAME))
max_seq_len_source, max_seq_len_target = args.max_seq_len
# The maximum length given by the user is the length before we add the BOS/EOS symbols
max_seq_len_source = max_seq_len_source + C.SPACE_FOR_XOS
max_seq_len_target = max_seq_len_target + C.SPACE_FOR_XOS
logger.info("Adjusting maximum length to reserve space for a BOS/EOS marker. New maximum length: (%d, %d)",
max_seq_len_source, max_seq_len_target)
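# Pick the training device: CPU, the GPU matching this worker's local rank in
# distributed mode, or the GPU given by args.device_id.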
device = torch.device('cpu') if args.use_cpu \
else torch.device('cuda', utils.get_local_rank()) if utils.is_distributed() \
else torch.device('cuda', args.device_id)
if not args.use_cpu:
# Ensure that GPU operations use the correct device by default
torch.cuda.set_device(device)
logger.info(f'Training Device: {device}')
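# Seed the random number generators for reproducibility.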
utils.seed_rngs(args.seed)
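# Build the training and validation data iterators and load or create the
# source/target vocabularies.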
train_iter, eval_iter, config_data, source_vocabs, target_vocabs = create_data_iters_and_vocabs(
args=args,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
shared_vocab=use_shared_vocab(args),
resume_training=resume_training,
output_folder=output_folder)
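# Prepared data may have been built with shorter maximum sequence lengths; if so,
# those values take precedence.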
if max_seq_len_source != config_data.max_seq_len_source:
logger.info("Maximum source length determined by prepared data. Using %d instead of %d",
config_data.max_seq_len_source, max_seq_len_source)
max_seq_len_source = config_data.max_seq_len_source
if max_seq_len_target != config_data.max_seq_len_target:
logger.info("Maximum target length determined by prepared data. Using %d instead of %d",
config_data.max_seq_len_target, max_seq_len_target)
max_seq_len_target = config_data.max_seq_len_target
# Dump the vocabularies if we're just starting up
if utils.is_primary_worker() and not resume_training:
vocab.save_source_vocabs(source_vocabs, output_folder)
vocab.save_target_vocabs(target_vocabs, output_folder)
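# One vocabulary (and size) per source/target factor.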
source_vocab_sizes = [len(v) for v in source_vocabs]
target_vocab_sizes = [len(v) for v in target_vocabs]
logger.info('Vocabulary sizes: source=[%s] target=[%s]',
'|'.join([str(size) for size in source_vocab_sizes]),
'|'.join([str(size) for size in target_vocab_sizes]))
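# Assemble the model configuration from the CLI arguments, vocabulary sizes, and
# data statistics.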
model_config = create_model_config(args=args,
source_vocab_sizes=source_vocab_sizes,
target_vocab_sizes=target_vocab_sizes,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
config_data=config_data)
# Handle options that override training settings
trainer_config = training_pt.TrainerConfig(output_dir=args.output,
early_stopping_metric=args.optimized_metric,
max_params_files_to_keep=args.keep_last_params,
keep_initializations=args.keep_initializations,
max_params_files_to_cache=args.cache_last_best_params,
cache_strategy=args.cache_strategy,
cache_metric=args.cache_metric,
checkpoint_interval=args.checkpoint_interval,
max_num_checkpoint_not_improved=args.max_num_checkpoint_not_improved,
checkpoint_improvement_threshold=args.checkpoint_improvement_threshold,
max_checkpoints=args.max_checkpoints,
min_samples=args.min_samples,
max_samples=args.max_samples,
min_updates=args.min_updates,
max_updates=args.max_updates,
min_epochs=args.min_num_epochs,
max_epochs=args.max_num_epochs,
max_seconds=args.max_seconds,
update_interval=args.update_interval,
stop_training_on_decoder_failure=args.stop_training_on_decoder_failure)
if trainer_config.min_epochs is not None and trainer_config.max_epochs is not None:
check_condition(trainer_config.min_epochs <= trainer_config.max_epochs,
                "Minimum number of epochs must be smaller than or equal to maximum number of epochs")
optimizer_config = create_optimizer_config(args)
sockeye_model = model_pt.PyTorchSockeyeModel(
model_config,
train_decoder_only=args.fixed_param_strategy == C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_DECODER)
sockeye_model.to(device)
sockeye_model.apply(model_pt.initialize_parameters)
# Load starting parameters if specified
if args.params is not None:
sockeye_model.load_parameters(filename=args.params,
device=device,
allow_missing=args.allow_missing_params or model_config.lhuc,
ignore_extra=args.ignore_extra_params)
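# Disable gradients for parameters that should remain fixed during training.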
unset_requires_grad_for_fixed_params(config=model_config,
params=dict(sockeye_model.named_parameters()),
fixed_param_names=args.fixed_param_names,
fixed_param_strategy=args.fixed_param_strategy)
utils.log_parameters_pt(sockeye_model)
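# Create the optimizer along with the keyword arguments the trainer uses when
# zeroing gradients.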
optimizer, zero_grad_kwargs = optimizers.get_optimizer(sockeye_model, optimizer_config)
# This starts as a reference to the original Sockeye model. It is
# sequentially transformed/wrapped to produce the model instance used for
# training.
training_model = sockeye_model # type: torch.nn.Module
if args.apex_amp:
try:
import apex.amp
except ImportError:
logger.error('Cannot import NVIDIA Apex AMP. Please install Apex: https://github.com/NVIDIA/apex')
sys.exit(1)
# Optimization level 2 runs the entire model in FP16 mode with FP32
# master weights and loss scaling. See:
# https://nvidia.github.io/apex/amp.html#o2-almost-fp16-mixed-precision
training_model, optimizer = apex.amp.initialize(training_model, optimizer, opt_level='O2')
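# Trace the model with TorchScript on a real batch so that training runs the traced module.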
logger.info('Tracing model on validation batch')
batch = eval_iter.next().load(device=device) # pylint: disable=not-callable
# When using AMP, turn on autocasting when tracing the model so that
# dtypes will match during AMP training. Disable the weight cache for
# compatibility with tracing. See:
# https://github.com/pytorch/pytorch/pull/63552
with torch.cuda.amp.autocast(cache_enabled=False) if args.amp else utils.no_context(): # type: ignore
training_model = torch.jit.trace(training_model, (batch.source, batch.source_length,
batch.target, batch.target_length), strict=False)
eval_iter.reset()
if utils.is_distributed():
# In distributed mode, wrap the traced model with a distributed
# data-parallel model that shares (averages) gradients with models
# in other worker processes.
training_model = torch.nn.parallel.DistributedDataParallel(training_model,
device_ids=None if args.use_cpu else [device],
output_device=None if args.use_cpu else device)
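# Create the training loss functions, sized by the target vocabularies.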
losses = create_losses(args, all_num_classes=target_vocab_sizes)
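# Assemble the early-stopping trainer that drives updates, checkpointing, and validation.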
trainer = training_pt.PyTorchEarlyStoppingTrainer(
config=trainer_config,
optimizer_config=optimizer_config,
sockeye_model=sockeye_model,
training_model=training_model,
optimizer=optimizer,
zero_grad_kwargs=zero_grad_kwargs,
loss_functions=losses,
device=device,
using_amp=args.amp,
using_apex_amp=args.apex_amp,
custom_metrics_logger=custom_metrics_logger,
checkpoint_callback=checkpoint_callback)
# Only primary worker runs checkpoint decoder
checkpoint_decoder = None
if utils.is_primary_worker():
checkpoint_decoder = create_checkpoint_decoder(args, device, sockeye_model, source_vocabs, target_vocabs)
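# Run training until a stopping criterion is met and return the final training state.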
training_state = trainer.fit(train_iter=train_iter, validation_iter=eval_iter,
checkpoint_decoder=checkpoint_decoder)
return training_state