def train()

in sockeye/train.py [0:0]


def train(args: argparse.Namespace, custom_metrics_logger: Optional[Callable] = None,
          checkpoint_callback: Optional[Callable] = None) -> training.TrainState:
    """
    Train a Sockeye model as configured by the parsed command-line arguments.

    Sets up optional Automatic Mixed Precision and Horovod, logging, training devices, data iterators and
    vocabularies, the model and its parameters, the Gluon trainer and loss functions, and finally runs training
    via ``training.GluonEarlyStoppingTrainer.fit``.

    :param args: Parsed training arguments. Note that some fields (e.g. ``output``, ``max_updates``,
                 ``keep_last_params``, ``quiet``) may be modified in place for dry runs or Horovod secondary workers.
    :param custom_metrics_logger: Optional custom metrics logging function. If supplied, takes care of metrics produced
                                  during training in a custom way. It should accept a list or a dictionary of
                                  (metric name, metric value) pairs, and an optional global_step/checkpoint parameter.
    :param checkpoint_callback: An optional callback function (int -> None). The function will be called
                                each time a checkpoint has been reached.
    :return: The training state returned by the trainer's ``fit`` call once training has finished.
    """

    if args.dry_run:
        # Modify arguments so that we write to a temporary directory and
        # perform 0 training iterations.
        # The directory is removed when `temp_dir` is destroyed, i.e. after this function returns.
        temp_dir = tempfile.TemporaryDirectory()
        args.output = temp_dir.name
        args.max_updates = 0

    # Automatic Mixed Precision training
    using_amp = bool(args.amp)
    if using_amp:
        amp.init()

    # When using Horovod, multiple workers (instances of sockeye.train) are
    # launched via MPI.  Each worker has a rank (unique among all workers in the
    # training run) and a local rank (unique on the current host).  For example,
    # running on 2 hosts with 4 slots each will assign ranks 0-7 and local ranks
    # 0-3.
    console_level = None
    if args.horovod:
        if horovod_mpi.hvd is None or horovod_mpi.MPI is None:
            raise RuntimeError('Horovod training requires the following packages to be installed: horovod mpi4py')
        # Unless explicitly set otherwise, use NCCL for same-host
        # allreduce/allgather and MPI for cross-host allreduce/allgather.
        if C.HOROVOD_HIERARCHICAL_ALLREDUCE not in os.environ:
            os.environ[C.HOROVOD_HIERARCHICAL_ALLREDUCE] = '1'
        if C.HOROVOD_HIERARCHICAL_ALLGATHER not in os.environ:
            os.environ[C.HOROVOD_HIERARCHICAL_ALLGATHER] = '1'
        horovod_mpi.hvd.init()
        # Each worker uses a separate output directory.  The primary worker
        # (rank 0) writes files to the root of the output directory (standard
        # behavior).  Secondary workers write files to rank-named
        # sub-directories.
        if horovod_mpi.hvd.rank() > 0:
            args.output = os.path.join(args.output, C.HOROVOD_SECONDARY_WORKERS_DIRNAME, str(horovod_mpi.hvd.rank()))
            # Do not keep redundant copies of the checkpoint history
            args.keep_last_params = 1
            # If requested, suppress console output for secondary workers
            if args.quiet_secondary_workers:
                args.quiet = True
            console_level = args.loglevel_secondary_workers

    check_arg_compatibility(args)
    output_folder = os.path.abspath(args.output)
    resume_training = check_resume(args, output_folder)

    setup_main_logger(file_logging=not args.no_logfile,
                      console=not args.quiet,
                      path=os.path.join(output_folder, C.LOG_NAME),
                      level=args.loglevel,
                      console_level=console_level)
    utils.log_basic_info(args)
    arguments.save_args(args, os.path.join(output_folder, C.ARGS_STATE_NAME))

    max_seq_len_source, max_seq_len_target = args.max_seq_len
    # The maximum length given by the user is the length before we add the BOS/EOS symbols
    max_seq_len_source = max_seq_len_source + C.SPACE_FOR_XOS
    max_seq_len_target = max_seq_len_target + C.SPACE_FOR_XOS
    logger.info("Adjusting maximum length to reserve space for a BOS/EOS marker. New maximum length: (%d, %d)",
                max_seq_len_source, max_seq_len_target)

    with ExitStack() as exit_stack:
        # The exit stack is handed to context selection and the checkpoint decoder so that
        # resources they acquire are released when training ends.
        context = utils.determine_context(device_ids=args.device_ids,
                                          use_cpu=args.use_cpu,
                                          disable_device_locking=args.disable_device_locking,
                                          lock_dir=args.lock_dir,
                                          exit_stack=exit_stack)
        if args.batch_type == C.BATCH_TYPE_SENTENCE:
            check_condition(args.batch_size % len(context) == 0, "When using multiple devices the batch size must be "
                                                                 "divisible by the number of devices. Choose a batch "
                                                                 "size that is a multiple of %d." % len(context))
        logger.info("Training Device(s): %s", ", ".join(str(c) for c in context))

        utils.seed_rngs(args.seed, ctx=context)

        train_iter, eval_iter, config_data, source_vocabs, target_vocabs = create_data_iters_and_vocabs(
            args=args,
            max_seq_len_source=max_seq_len_source,
            max_seq_len_target=max_seq_len_target,
            shared_vocab=use_shared_vocab(args),
            resume_training=resume_training,
            output_folder=output_folder)

        # Lengths stored with the (possibly prepared) data take precedence over the command-line values.
        if max_seq_len_source != config_data.max_seq_len_source:
            logger.info("Maximum source length determined by prepared data. Using %d instead of %d",
                        config_data.max_seq_len_source, max_seq_len_source)
            max_seq_len_source = config_data.max_seq_len_source
        if max_seq_len_target != config_data.max_seq_len_target:
            logger.info("Maximum target length determined by prepared data. Using %d instead of %d",
                        config_data.max_seq_len_target, max_seq_len_target)
            max_seq_len_target = config_data.max_seq_len_target

        # Dump the vocabularies if we're just starting up
        if not resume_training:
            vocab.save_source_vocabs(source_vocabs, output_folder)
            vocab.save_target_vocabs(target_vocabs, output_folder)

        source_vocab_sizes = [len(v) for v in source_vocabs]
        target_vocab_sizes = [len(v) for v in target_vocabs]
        logger.info('Vocabulary sizes: source=[%s] target=[%s]',
                    '|'.join([str(size) for size in source_vocab_sizes]),
                    '|'.join([str(size) for size in target_vocab_sizes]))

        model_config = create_model_config(args=args,
                                           source_vocab_sizes=source_vocab_sizes,
                                           target_vocab_sizes=target_vocab_sizes,
                                           max_seq_len_source=max_seq_len_source,
                                           max_seq_len_target=max_seq_len_target,
                                           config_data=config_data)

        training_model = model.SockeyeModel(
            model_config,
            train_decoder_only=args.fixed_param_strategy == C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_DECODER)

        # Handle options that override training settings
        trainer_config = training.TrainerConfig(
            output_dir=args.output,
            early_stopping_metric=args.optimized_metric,
            max_params_files_to_keep=args.keep_last_params,
            keep_initializations=args.keep_initializations,
            max_params_files_to_cache=args.cache_last_best_params,
            cache_strategy=args.cache_strategy,
            cache_metric=args.cache_metric,
            checkpoint_interval=args.checkpoint_interval,
            max_num_checkpoint_not_improved=args.max_num_checkpoint_not_improved,
            checkpoint_improvement_threshold=args.checkpoint_improvement_threshold,
            max_checkpoints=args.max_checkpoints,
            min_samples=args.min_samples,
            max_samples=args.max_samples,
            min_updates=args.min_updates,
            max_updates=args.max_updates,
            min_epochs=args.min_num_epochs,
            max_epochs=args.max_num_epochs,
            max_seconds=args.max_seconds,
            update_interval=args.update_interval,
            stop_training_on_decoder_failure=args.stop_training_on_decoder_failure
        )
        if trainer_config.min_epochs is not None and trainer_config.max_epochs is not None:
            check_condition(trainer_config.min_epochs <= trainer_config.max_epochs,
                            "Minimum number of epochs must not be larger than maximum number of epochs")

        optimizer_config = create_optimizer_config(args)
        training_model.initialize(optimizer_config.initializer, ctx=context)

        if args.params is not None:  # load existing parameters if present
            training_model.load_parameters(filename=args.params,
                                           ctx=context,
                                           allow_missing=args.allow_missing_params or model_config.lhuc,
                                           ignore_extra=args.ignore_extra_params,
                                           cast_dtype=True,
                                           dtype_source='current')
        params = training_model.collect_params()
        # set grad_req for fixed params
        params = set_grad_req_for_fixed_params(config=model_config,
                                               params=params,
                                               fixed_param_names=args.fixed_param_names,
                                               fixed_param_strategy=args.fixed_param_strategy)

        # When using Horovod, synchronize the parameter initialization point
        # across all workers by broadcasting worker 0's values.  This is not
        # required when resuming training as synchronized training states
        # already exist.
        if horovod_mpi.using_horovod() and not resume_training:
            # NOTE(review): the broadcast is repeated under each device context, presumably so that
            # every per-device parameter copy is synchronized — confirm against Horovod's
            # broadcast_parameters implementation.
            for ctx in context:
                with mx.Context(ctx):
                    horovod_mpi.hvd.broadcast_parameters(params, root_rank=0)

        if args.dtype == C.DTYPE_FP16:
            training_model.cast(C.DTYPE_FP16)
        utils.log_parameters(params)

        # With update_interval > 1, trainable parameters (grad_req != 'null') must accumulate gradients
        # ('add') across batches instead of overwriting them; fixed parameters keep 'null'.
        if args.update_interval > 1:
            for param in params.values():
                if param.grad_req != 'null':
                    param.grad_req = 'add'

        kvstore = mx.kvstore.create(args.kvstore)

        if horovod_mpi.using_horovod():
            # Horovod provides a trainer that subclasses gluon.Trainer and uses
            # allreduce to collect averaged gradients across all workers for
            # each update.
            gluon_trainer = horovod_mpi.hvd.DistributedTrainer(params,
                                                               optimizer_config.name,
                                                               optimizer_config.params)
        else:
            # AMP requires updates to happen on the workers rather than on the kvstore;
            # otherwise leave the choice (None) to Gluon.
            gluon_trainer = gluon.Trainer(params,
                                          optimizer_config.name,
                                          optimizer_config.params,
                                          kvstore=kvstore,
                                          update_on_kvstore=False if using_amp else None)

        if using_amp:
            amp.init_trainer(gluon_trainer)
            # AMP does not allow passing args when creating the loss scaler, so
            # we set them immediately after calling init.
            gluon_trainer._amp_loss_scaler._scale_seq_len = args.amp_scale_interval  # pylint: disable=no-member

        losses = create_losses(args, all_num_classes=target_vocab_sizes)

        hybridize = not args.no_hybridization
        if hybridize:
            training_model.hybridize(static_alloc=True)
            if not using_amp:
                # Do not hybridize losses when using AMP.  Dynamic loss scaling
                # requires adjusting SoftmaxOutput's grad_rescale value
                # throughout training, which is not possible when using the
                # Symbol API.
                for lf in losses:
                    lf.hybridize(static_alloc=True)

        trainer = training.GluonEarlyStoppingTrainer(
            config=trainer_config,
            optimizer_config=optimizer_config,
            sockeye_model=training_model,
            trainer=gluon_trainer,
            loss_functions=losses,
            context=context,
            dtype=args.dtype,
            using_amp=using_amp,
            custom_metrics_logger=custom_metrics_logger,
            checkpoint_callback=checkpoint_callback
        )

        cp_decoder = create_checkpoint_decoder(args, exit_stack, context,
                                               training_model, source_vocabs, target_vocabs, hybridize=hybridize)

        training_state = trainer.fit(train_iter=train_iter, validation_iter=eval_iter, checkpoint_decoder=cp_decoder)
        return training_state