def train()

in sockeye/train_pt.py


def train(args: argparse.Namespace, custom_metrics_logger: Optional[Callable] = None,
          checkpoint_callback: Optional[Callable] = None) -> training_pt.TrainState:
    """
    :param custom_metrics_logger: Optional custom metrics logging function. If supplied, takes care of metrics produced
                                  during training in a custom way. It should accept a list or a dictionary of
                                  (metric name, metric value) pairs, and an optional global_step/checkpoint parameter.
    :param checkpoint_callback: An optional callback function (int -> None). The function will be called
                                each time a checkpoint has been reached
    """

    if args.dist:
        torch.distributed.init_process_group(torch.distributed.Backend.GLOO if args.use_cpu
                                             else torch.distributed.Backend.NCCL)

    if args.dry_run:
        # Modify arguments so that we write to a temporary directory and
        # perform 0 training iterations
        temp_dir = tempfile.TemporaryDirectory()  # Will be automatically removed
        args.output = temp_dir.name
        args.max_updates = 0

    check_arg_compatibility(args)
    output_folder = os.path.abspath(args.output)
    resume_training = check_resume(args, output_folder)

    # In distributed mode, multiple workers (instances of sockeye.train) are
    # launched via torchrun. Each worker has a unique rank. Worker 0 is the
    # primary worker that writes files and makes authoritative training
    # decisions (ex: whether a checkpoint improves). Workers 1+ are secondary
    # workers that run parallel training steps and send gradients to the primary
    # worker (but don't output anything other than log files).
    logfile = os.path.join(output_folder, C.LOG_NAME)
    console_level = None
    if not utils.is_primary_worker():
        logfile = os.path.join(output_folder, C.DIST_SECONDARY_WORKERS_LOGDIR,
                               f'{torch.distributed.get_rank()}.{C.LOG_NAME}')
        # If requested, suppress console output for secondary workers
        if args.quiet_secondary_workers:
            args.quiet = True
        console_level = args.loglevel_secondary_workers

    setup_main_logger(file_logging=not args.no_logfile,
                      console=not args.quiet,
                      path=logfile,
                      level=args.loglevel,
                      console_level=console_level)
    utils.log_basic_info(args)
    if utils.is_primary_worker():
        arguments.save_args(args, os.path.join(output_folder, C.ARGS_STATE_NAME))

    max_seq_len_source, max_seq_len_target = args.max_seq_len
    # The maximum length given by the user is the length before we add the BOS/EOS symbols
    max_seq_len_source = max_seq_len_source + C.SPACE_FOR_XOS
    max_seq_len_target = max_seq_len_target + C.SPACE_FOR_XOS
    logger.info("Adjusting maximum length to reserve space for a BOS/EOS marker. New maximum length: (%d, %d)",
                max_seq_len_source, max_seq_len_target)

    device = torch.device('cpu') if args.use_cpu \
             else torch.device('cuda', utils.get_local_rank()) if utils.is_distributed() \
             else torch.device('cuda', args.device_id)
    if not args.use_cpu:
        # Ensure that GPU operations use the correct device by default
        torch.cuda.set_device(device)
    logger.info(f'Training Device: {device}')
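    # Seed random number generators for reproducible training runs.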
    utils.seed_rngs(args.seed)

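    # Create the training and validation data iterators along with the source/target vocabularies.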
    train_iter, eval_iter, config_data, source_vocabs, target_vocabs = create_data_iters_and_vocabs(
        args=args,
        max_seq_len_source=max_seq_len_source,
        max_seq_len_target=max_seq_len_target,
        shared_vocab=use_shared_vocab(args),
        resume_training=resume_training,
        output_folder=output_folder)

    if max_seq_len_source != config_data.max_seq_len_source:
        logger.info("Maximum source length determined by prepared data. Using %d instead of %d",
                    config_data.max_seq_len_source, max_seq_len_source)
        max_seq_len_source = config_data.max_seq_len_source
    if max_seq_len_target != config_data.max_seq_len_target:
        logger.info("Maximum target length determined by prepared data. Using %d instead of %d",
                    config_data.max_seq_len_target, max_seq_len_target)
        max_seq_len_target = config_data.max_seq_len_target

    # Dump the vocabularies if we're just starting up
    if utils.is_primary_worker() and not resume_training:
        vocab.save_source_vocabs(source_vocabs, output_folder)
        vocab.save_target_vocabs(target_vocabs, output_folder)

    source_vocab_sizes = [len(v) for v in source_vocabs]
    target_vocab_sizes = [len(v) for v in target_vocabs]
    logger.info('Vocabulary sizes: source=[%s] target=[%s]',
                '|'.join([str(size) for size in source_vocab_sizes]),
                '|'.join([str(size) for size in target_vocab_sizes]))

    model_config = create_model_config(args=args,
                                       source_vocab_sizes=source_vocab_sizes,
                                       target_vocab_sizes=target_vocab_sizes,
                                       max_seq_len_source=max_seq_len_source,
                                       max_seq_len_target=max_seq_len_target,
                                       config_data=config_data)

    # Collect checkpointing and stopping options into the trainer configuration
    trainer_config = training_pt.TrainerConfig(output_dir=args.output,
                                               early_stopping_metric=args.optimized_metric,
                                               max_params_files_to_keep=args.keep_last_params,
                                               keep_initializations=args.keep_initializations,
                                               max_params_files_to_cache=args.cache_last_best_params,
                                               cache_strategy=args.cache_strategy,
                                               cache_metric=args.cache_metric,
                                               checkpoint_interval=args.checkpoint_interval,
                                               max_num_checkpoint_not_improved=args.max_num_checkpoint_not_improved,
                                               checkpoint_improvement_threshold=args.checkpoint_improvement_threshold,
                                               max_checkpoints=args.max_checkpoints,
                                               min_samples=args.min_samples,
                                               max_samples=args.max_samples,
                                               min_updates=args.min_updates,
                                               max_updates=args.max_updates,
                                               min_epochs=args.min_num_epochs,
                                               max_epochs=args.max_num_epochs,
                                               max_seconds=args.max_seconds,
                                               update_interval=args.update_interval,
                                               stop_training_on_decoder_failure=args.stop_training_on_decoder_failure)
    if trainer_config.min_epochs is not None and trainer_config.max_epochs is not None:
        check_condition(trainer_config.min_epochs <= trainer_config.max_epochs,
                        "Minimum number of epochs must be smaller than maximum number of epochs")

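    # Derive the optimizer configuration from the command-line arguments.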
    optimizer_config = create_optimizer_config(args)

    sockeye_model = model_pt.PyTorchSockeyeModel(
        model_config,
        train_decoder_only=args.fixed_param_strategy == C.FIXED_PARAM_STRATEGY_ALL_EXCEPT_DECODER)
    sockeye_model.to(device)
    sockeye_model.apply(model_pt.initialize_parameters)

    # Load starting parameters if specified
    if args.params is not None:
        sockeye_model.load_parameters(filename=args.params,
                                      device=device,
                                      allow_missing=args.allow_missing_params or model_config.lhuc,
                                      ignore_extra=args.ignore_extra_params)

    unset_requires_grad_for_fixed_params(config=model_config,
                                         params=dict(sockeye_model.named_parameters()),
                                         fixed_param_names=args.fixed_param_names,
                                         fixed_param_strategy=args.fixed_param_strategy)

    utils.log_parameters_pt(sockeye_model)

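    # Construct the optimizer for the model parameters; zero_grad_kwargs holds the keyword
    # arguments the trainer uses when zeroing gradients.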
    optimizer, zero_grad_kwargs = optimizers.get_optimizer(sockeye_model, optimizer_config)

    # This starts as a reference to the original Sockeye model. It is
    # sequentially transformed/wrapped to produce the model instance used for
    # training.
    training_model = sockeye_model  # type: torch.nn.Module

    if args.apex_amp:
        try:
            import apex.amp
        except ImportError:
            logger.error('Cannot import NVIDIA Apex AMP. Please install Apex: https://github.com/NVIDIA/apex')
            sys.exit(1)
        # Optimization level 2 runs the entire model in FP16 mode with FP32
        # master weights and loss scaling. See:
        # https://nvidia.github.io/apex/amp.html#o2-almost-fp16-mixed-precision
        training_model, optimizer = apex.amp.initialize(training_model, optimizer, opt_level='O2')

    logger.info('Tracing model on validation batch')
    batch = eval_iter.next().load(device=device)  # pylint: disable=not-callable
    # When using AMP, turn on autocasting when tracing the model so that
    # dtypes will match during AMP training. Disable the weight cache for
    # compatibility with tracing. See:
    # https://github.com/pytorch/pytorch/pull/63552
    with torch.cuda.amp.autocast(cache_enabled=False) if args.amp else utils.no_context():  # type: ignore
        training_model = torch.jit.trace(training_model, (batch.source, batch.source_length,
                                                          batch.target, batch.target_length), strict=False)
    eval_iter.reset()

    if utils.is_distributed():
        # In distributed mode, wrap the traced model with a distributed
        # data-parallel model that shares (averages) gradients with models
        # in other worker processes.
        training_model = torch.nn.parallel.DistributedDataParallel(training_model,
                                                                   device_ids=None if args.use_cpu else [device],
                                                                   output_device=None if args.use_cpu else device)

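    # Create the training loss functions, sized to the target vocabularies.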
    losses = create_losses(args, all_num_classes=target_vocab_sizes)

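    # Assemble the early-stopping trainer from the configs, model, optimizer, and losses prepared above.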
    trainer = training_pt.PyTorchEarlyStoppingTrainer(
        config=trainer_config,
        optimizer_config=optimizer_config,
        sockeye_model=sockeye_model,
        training_model=training_model,
        optimizer=optimizer,
        zero_grad_kwargs=zero_grad_kwargs,
        loss_functions=losses,
        device=device,
        using_amp=args.amp,
        using_apex_amp=args.apex_amp,
        custom_metrics_logger=custom_metrics_logger,
        checkpoint_callback=checkpoint_callback)

    # Only primary worker runs checkpoint decoder
    checkpoint_decoder = None
    if utils.is_primary_worker():
        checkpoint_decoder = create_checkpoint_decoder(args, device, sockeye_model, source_vocabs, target_vocabs)

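    # Run the training loop until one of the stopping criteria in trainer_config is met.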
    training_state = trainer.fit(train_iter=train_iter, validation_iter=eval_iter,
                                 checkpoint_decoder=checkpoint_decoder)

    return training_state
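
The two optional callbacks are plain callables whose expected signatures follow the docstring above. Below is a minimal sketch of how a caller might wire them up; the names log_metrics and on_checkpoint are illustrative (not part of Sockeye), and args is assumed to be an argparse.Namespace produced by Sockeye's training CLI.

import logging
from typing import Dict, List, Optional, Tuple, Union

def log_metrics(metrics: Union[Dict[str, float], List[Tuple[str, float]]],
                global_step: Optional[int] = None) -> None:
    # Follows the custom_metrics_logger contract: a list or dict of
    # (metric name, metric value) pairs plus an optional global_step/checkpoint.
    items = metrics.items() if isinstance(metrics, dict) else metrics
    for name, value in items:
        logging.info("step=%s %s=%s", global_step, name, value)

def on_checkpoint(checkpoint: int) -> None:
    # Follows the checkpoint_callback contract (int -> None); called once per
    # checkpoint, e.g. to trigger an external evaluation job.
    logging.info("reached checkpoint %d", checkpoint)

# Assuming `args` holds parsed Sockeye training arguments:
# training_state = train(args, custom_metrics_logger=log_metrics, checkpoint_callback=on_checkpoint)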