def train()

in self-training-text-classification/finetuning.py [0:0]
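
Trains for up to args.num_train_epochs epochs or args.max_steps optimization steps with gradient accumulation. When an eval_dataloader is given, the model is evaluated either every args.eval_steps steps or once per epoch (depending on args.eval_strategy); early stopping is driven by args.early_stopping_patience and args.early_stopping_threshold, at most args.keep_checkpoint_max checkpoints are kept (sorted by the evaluation metric), and the best checkpoint, or the final model state if no evaluation was run, is saved to output_dir/best-checkpoint. Returns the number of completed optimization steps and the average training loss per step.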


def train(args, accelerator, model, tokenizer, train_dataloader, optimizer, lr_scheduler, eval_dataloader=None):
    """Train a model on the given training data, optionally evaluating and checkpointing as training progresses."""

    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", args.num_examples[Split.TRAIN.value])
    logger.info("  Instantaneous batch size per device = %d", args.per_device_train_batch_size)
    logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_batch_size)
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", args.max_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_steps), disable=not accelerator.is_local_main_process)

    checkpoints = None
    eval_results = None
    best_checkpoint = None
    best_eval_result = None
    early_stopping_patience_counter = 0
    should_training_stop = False
    epoch = 0
    completed_steps = 0
    train_loss = 0.0
    model.zero_grad()

    for _ in range(args.num_train_epochs):
        epoch += 1
        model.train()
        for step, batch in enumerate(train_dataloader):
            outputs = model(**batch)
            loss = outputs.loss
            loss = loss / args.gradient_accumulation_steps
            accelerator.backward(loss)
            train_loss += loss.item()

            # Step the optimizer on gradient-accumulation boundaries and on the last batch of the epoch.
            if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1

                # Evaluate every args.eval_steps optimization steps during training
                if (
                    eval_dataloader is not None
                    and args.eval_strategy == IntervalStrategy.STEPS.value
                    and args.eval_steps > 0
                    and completed_steps % args.eval_steps == 0
                ):
                    accelerator.wait_for_everyone()
                    new_checkpoint = f"checkpoint-{IntervalStrategy.STEPS.value}-{completed_steps}"
                    new_eval_result = evaluate(args, accelerator, eval_dataloader, "eval", model, new_checkpoint)[
                        args.eval_metric
                    ]
                    logger.info(
                        "Evaluation result at step %d: %s = %f", completed_steps, args.eval_metric, new_eval_result
                    )
                    if checkpoints is None:
                        checkpoints = np.array([new_checkpoint])
                        eval_results = np.array([new_eval_result])
                        best_checkpoint = new_checkpoint
                        best_eval_result = new_eval_result
                    else:
                        if new_eval_result - best_eval_result > args.early_stopping_threshold:
                            best_checkpoint = new_checkpoint
                            best_eval_result = new_eval_result
                            early_stopping_patience_counter = 0
                        else:
                            if new_eval_result == best_eval_result:
                                best_checkpoint = new_checkpoint
                                best_eval_result = new_eval_result
                            early_stopping_patience_counter += 1

                        if early_stopping_patience_counter >= args.early_stopping_patience:
                            should_training_stop = True

                        checkpoints = np.append(checkpoints, [new_checkpoint], axis=0)
                        eval_results = np.append(eval_results, [new_eval_result], axis=0)
                        # Keep checkpoints and eval results sorted by metric (ascending) so the worst checkpoint comes first
                        sorted_ids = np.argsort(eval_results)
                        eval_results = eval_results[sorted_ids]
                        checkpoints = checkpoints[sorted_ids]

                    if len(checkpoints) > args.keep_checkpoint_max:
                        # Delete the current worst checkpoint (the first entry of the sorted arrays)
                        checkpoint_to_remove, *checkpoints = checkpoints
                        eval_results = eval_results[1:]
                        if checkpoint_to_remove != new_checkpoint:
                            if accelerator.is_main_process:
                                shutil.rmtree(os.path.join(args.output_dir, checkpoint_to_remove), ignore_errors=True)
                            accelerator.wait_for_everyone()

                    if new_checkpoint in checkpoints:
                        # Save the model checkpoint only if it survived the pruning above
                        checkpoint_output_dir = os.path.join(args.output_dir, new_checkpoint)
                        if accelerator.is_main_process:
                            if not os.path.exists(checkpoint_output_dir):
                                os.makedirs(checkpoint_output_dir)
                        accelerator.wait_for_everyone()
                        unwrapped_model = accelerator.unwrap_model(model)
                        unwrapped_model.save_pretrained(checkpoint_output_dir, save_function=accelerator.save)
                        if accelerator.is_main_process:
                            tokenizer.save_pretrained(checkpoint_output_dir)
                            logger.info("Saving model checkpoint to %s", checkpoint_output_dir)

            if completed_steps >= args.max_steps:
                break

            if should_training_stop:
                break

        # Evaluate at the end of each epoch
        if eval_dataloader is not None and args.eval_strategy == IntervalStrategy.EPOCH.value:
            accelerator.wait_for_everyone()
            new_checkpoint = f"checkpoint-{IntervalStrategy.EPOCH.value}-{epoch}"
            new_eval_result = evaluate(args, accelerator, eval_dataloader, "eval", model, new_checkpoint)[
                args.eval_metric
            ]
            logger.info("Evaluation result at epoch %d: %s = %f", epoch, args.eval_metric, new_eval_result)

            if checkpoints is None:
                checkpoints = np.array([new_checkpoint])
                eval_results = np.array([new_eval_result])
                best_checkpoint = new_checkpoint
                best_eval_result = new_eval_result
            else:
                if new_eval_result - best_eval_result > args.early_stopping_threshold:
                    best_checkpoint = new_checkpoint
                    best_eval_result = new_eval_result
                    early_stopping_patience_counter = 0
                else:
                    if new_eval_result == best_eval_result:
                        best_checkpoint = new_checkpoint
                        best_eval_result = new_eval_result
                    early_stopping_patience_counter += 1

                if early_stopping_patience_counter >= args.early_stopping_patience:
                    should_training_stop = True

                checkpoints = np.append(checkpoints, [new_checkpoint], axis=0)
                eval_results = np.append(eval_results, [new_eval_result], axis=0)
                sorted_ids = np.argsort(eval_results)
                eval_results = eval_results[sorted_ids]
                checkpoints = checkpoints[sorted_ids]

            if len(checkpoints) > args.keep_checkpoint_max:
                # Delete the current worst checkpoint (the first entry of the sorted arrays)
                checkpoint_to_remove, *checkpoints = checkpoints
                eval_results = eval_results[1:]
                if checkpoint_to_remove != new_checkpoint:
                    if accelerator.is_main_process:
                        shutil.rmtree(os.path.join(args.output_dir, checkpoint_to_remove), ignore_errors=True)
                    accelerator.wait_for_everyone()

            if new_checkpoint in checkpoints:
                # Save the model checkpoint only if it survived the pruning above
                checkpoint_output_dir = os.path.join(args.output_dir, new_checkpoint)
                if accelerator.is_main_process:
                    if not os.path.exists(checkpoint_output_dir):
                        os.makedirs(checkpoint_output_dir)
                accelerator.wait_for_everyone()
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.save_pretrained(checkpoint_output_dir, save_function=accelerator.save)
                if accelerator.is_main_process:
                    tokenizer.save_pretrained(checkpoint_output_dir)
                    logger.info("Saving model checkpoint to %s", checkpoint_output_dir)

        if completed_steps >= args.max_steps:
            break

        if should_training_stop:
            break

    if best_checkpoint is not None:
        # Save the best checkpoint
        logger.info("Best checkpoint: %s", best_checkpoint)
        logger.info("Best evaluation result: %s = %f", args.eval_metric, best_eval_result)
        best_checkpoint_output_dir = os.path.join(args.output_dir, best_checkpoint)
        if accelerator.is_main_process:
            shutil.move(best_checkpoint_output_dir, os.path.join(args.output_dir, "best-checkpoint"))
            shutil.rmtree(best_checkpoint_output_dir, ignore_errors=True)
        accelerator.wait_for_everyone()

    else:
        # Assume that the last checkpoint is the best checkpoint and save it
        checkpoint_output_dir = os.path.join(args.output_dir, "best-checkpoint")
        if not os.path.exists(checkpoint_output_dir):
            os.makedirs(checkpoint_output_dir)

        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(checkpoint_output_dir, save_function=accelerator.save)
        if accelerator.is_main_process:
            tokenizer.save_pretrained(checkpoint_output_dir)
            logger.info("Saving model checkpoint to %s", checkpoint_output_dir)
    return completed_steps, train_loss / completed_steps
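
A minimal usage sketch, not taken from the repository: it assumes it runs in a context where finetuning.py's module-level names (logger, Split, IntervalStrategy, evaluate, np, tqdm, os, shutil) are available alongside train, that Split.TRAIN.value is "train", and that a plain namespace is enough to stand in for the script's parsed arguments. The model name, output directory, and toy data are placeholders.

from types import SimpleNamespace

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    default_data_collator,
    get_linear_schedule_with_warmup,
)

accelerator = Accelerator()
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

# Toy tokenized dataset; the repository builds this from the task's real data instead.
texts, labels = ["a good example", "a bad example"] * 8, [1, 0] * 8
encodings = tokenizer(texts, truncation=True, padding="max_length", max_length=32)
features = [
    {"input_ids": encodings["input_ids"][i], "attention_mask": encodings["attention_mask"][i], "labels": labels[i]}
    for i in range(len(labels))
]
train_dataloader = DataLoader(features, batch_size=4, shuffle=True, collate_fn=default_data_collator)

optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)

# Hypothetical stand-in for the script's parsed arguments; only the fields read by train() are set.
args = SimpleNamespace(
    per_device_train_batch_size=4,
    gradient_accumulation_steps=1,
    num_train_epochs=1,
    max_steps=4,
    num_examples={"train": len(features)},  # keyed by Split.TRAIN.value, assumed to be "train"
    eval_strategy="no",                     # no eval_dataloader below, so in-training evaluation is skipped
    eval_steps=0,
    eval_metric="accuracy",
    early_stopping_patience=1,
    early_stopping_threshold=0.0,
    keep_checkpoint_max=1,
    output_dir="/tmp/finetuning-sketch",    # placeholder output directory
)

lr_scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=args.max_steps)
model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, lr_scheduler
)

completed_steps, avg_train_loss = train(
    args, accelerator, model, tokenizer, train_dataloader, optimizer, lr_scheduler
)

When an eval_dataloader is passed, train() additionally relies on the module's evaluate() helper returning a dict of metrics keyed by args.eval_metric; the sketch sidesteps that by omitting the eval_dataloader, so the final model state is saved directly as output_dir/best-checkpoint.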