in self-training-text-classification/finetuning.py [0:0]
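# NOTE: this excerpt relies on imports and module-level helpers that live at the
# top of finetuning.py and are not shown here. Listed as a reading aid; the exact
# import style in the original file may differ:
#
#   import logging, os, shutil
#   import numpy as np
#   from tqdm.auto import tqdm
#   from transformers.trainer_utils import IntervalStrategy
#
# `logger`, `Split`, and `evaluate(...)` are defined elsewhere in this module.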
def train(args, accelerator, model, tokenizer, train_dataloader, optimizer, lr_scheduler, eval_dataloader=None):
    """Train a model on the given training data.

    Returns a tuple of (number of completed optimization steps, average training loss).
    """
    total_batch_size = args.per_device_train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps

    logger.info("***** Running training *****")
    logger.info(" Num examples = %d", args.num_examples[Split.TRAIN.value])
    logger.info(" Instantaneous batch size per device = %d", args.per_device_train_batch_size)
    logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d", total_batch_size)
    logger.info(" Gradient Accumulation steps = %d", args.gradient_accumulation_steps)
    logger.info(" Total optimization steps = %d", args.max_steps)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_steps), disable=not accelerator.is_local_main_process)
    checkpoints = None
    eval_results = None
    best_checkpoint = None
    best_eval_result = None
    early_stopping_patience_counter = 0
    should_training_stop = False
    epoch = 0
    completed_steps = 0
    train_loss = 0.0
    model.zero_grad()
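
    # Main training loop: for each epoch, iterate over micro-batches, accumulate
    # gradients for `gradient_accumulation_steps` batches per optimizer step,
    # and (depending on `eval_strategy`) periodically evaluate, track the best
    # checkpoint, apply early stopping, and prune older checkpoints.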
    for _ in range(args.num_train_epochs):
        epoch += 1
        model.train()
        for step, batch in enumerate(train_dataloader):
            outputs = model(**batch)
            loss = outputs.loss
            # Scale the loss so that gradients accumulate to the mean over the
            # accumulation window.
            loss = loss / args.gradient_accumulation_steps
            accelerator.backward(loss)
            train_loss += loss.item()

            # Take an optimizer step once per accumulation window (and on the
            # last batch of the epoch, even if the window is incomplete).
            if step % args.gradient_accumulation_steps == 0 or step == len(train_dataloader) - 1:
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                progress_bar.update(1)
                completed_steps += 1
                # Evaluate every `eval_steps` optimization steps.
                if (
                    eval_dataloader is not None
                    and args.eval_strategy == IntervalStrategy.STEPS.value
                    and args.eval_steps > 0
                    and completed_steps % args.eval_steps == 0
                ):
                    accelerator.wait_for_everyone()
                    new_checkpoint = f"checkpoint-{IntervalStrategy.STEPS.value}-{completed_steps}"
                    new_eval_result = evaluate(args, accelerator, eval_dataloader, "eval", model, new_checkpoint)[
                        args.eval_metric
                    ]
                    logger.info(
                        "Evaluation result at step %d: %s = %f", completed_steps, args.eval_metric, new_eval_result
                    )

                    if checkpoints is None:
                        # First evaluation: initialize checkpoint tracking.
                        checkpoints = np.array([new_checkpoint])
                        eval_results = np.array([new_eval_result])
                        best_checkpoint = new_checkpoint
                        best_eval_result = new_eval_result
                    else:
                        # An improvement larger than `early_stopping_threshold`
                        # resets the early-stopping patience counter.
                        if new_eval_result - best_eval_result > args.early_stopping_threshold:
                            best_checkpoint = new_checkpoint
                            best_eval_result = new_eval_result
                            early_stopping_patience_counter = 0
                        else:
                            # Ties still refresh the best checkpoint but count
                            # against the patience budget.
                            if new_eval_result == best_eval_result:
                                best_checkpoint = new_checkpoint
                                best_eval_result = new_eval_result
                            early_stopping_patience_counter += 1

                        if early_stopping_patience_counter >= args.early_stopping_patience:
                            should_training_stop = True

                        # Keep checkpoints sorted by metric (ascending), so the
                        # worst-performing checkpoint is always first.
                        checkpoints = np.append(checkpoints, [new_checkpoint], axis=0)
                        eval_results = np.append(eval_results, [new_eval_result], axis=0)
                        sorted_ids = np.argsort(eval_results)
                        eval_results = eval_results[sorted_ids]
                        checkpoints = checkpoints[sorted_ids]
                    if len(checkpoints) > args.keep_checkpoint_max:
                        # Delete the current worst checkpoint.
                        checkpoint_to_remove, *checkpoints = checkpoints
                        eval_results = eval_results[1:]
                        if checkpoint_to_remove != new_checkpoint:
                            if accelerator.is_main_process:
                                shutil.rmtree(os.path.join(args.output_dir, checkpoint_to_remove), ignore_errors=True)
                            accelerator.wait_for_everyone()

                    # Only save the new checkpoint if it survived the pruning above.
                    if new_checkpoint in checkpoints:
                        # Save model checkpoint
                        checkpoint_output_dir = os.path.join(args.output_dir, new_checkpoint)
                        if accelerator.is_main_process:
                            if not os.path.exists(checkpoint_output_dir):
                                os.makedirs(checkpoint_output_dir)
                        accelerator.wait_for_everyone()
                        unwrapped_model = accelerator.unwrap_model(model)
                        unwrapped_model.save_pretrained(checkpoint_output_dir, save_function=accelerator.save)
                        if accelerator.is_main_process:
                            tokenizer.save_pretrained(checkpoint_output_dir)
                            logger.info("Saving model checkpoint to %s", checkpoint_output_dir)
            if completed_steps >= args.max_steps:
                break

            if should_training_stop:
                break
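
        # The epoch-level branch below mirrors the step-level logic above:
        # evaluate, update the best checkpoint and early-stopping counters,
        # prune checkpoints beyond `keep_checkpoint_max`, and save the new one.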
        # Evaluate at the end of the epoch.
        if eval_dataloader is not None and args.eval_strategy == IntervalStrategy.EPOCH.value:
            accelerator.wait_for_everyone()
            new_checkpoint = f"checkpoint-{IntervalStrategy.EPOCH.value}-{epoch}"
            new_eval_result = evaluate(args, accelerator, eval_dataloader, "eval", model, new_checkpoint)[
                args.eval_metric
            ]
            logger.info("Evaluation result at epoch %d: %s = %f", epoch, args.eval_metric, new_eval_result)

            if checkpoints is None:
                checkpoints = np.array([new_checkpoint])
                eval_results = np.array([new_eval_result])
                best_checkpoint = new_checkpoint
                best_eval_result = new_eval_result
            else:
                if new_eval_result - best_eval_result > args.early_stopping_threshold:
                    best_checkpoint = new_checkpoint
                    best_eval_result = new_eval_result
                    early_stopping_patience_counter = 0
                else:
                    if new_eval_result == best_eval_result:
                        best_checkpoint = new_checkpoint
                        best_eval_result = new_eval_result
                    early_stopping_patience_counter += 1

                if early_stopping_patience_counter >= args.early_stopping_patience:
                    should_training_stop = True

                checkpoints = np.append(checkpoints, [new_checkpoint], axis=0)
                eval_results = np.append(eval_results, [new_eval_result], axis=0)
                sorted_ids = np.argsort(eval_results)
                eval_results = eval_results[sorted_ids]
                checkpoints = checkpoints[sorted_ids]

            if len(checkpoints) > args.keep_checkpoint_max:
                # Delete the current worst checkpoint.
                checkpoint_to_remove, *checkpoints = checkpoints
                eval_results = eval_results[1:]
                if checkpoint_to_remove != new_checkpoint:
                    if accelerator.is_main_process:
                        shutil.rmtree(os.path.join(args.output_dir, checkpoint_to_remove), ignore_errors=True)
                    accelerator.wait_for_everyone()

            if new_checkpoint in checkpoints:
                # Save model checkpoint
                checkpoint_output_dir = os.path.join(args.output_dir, new_checkpoint)
                if accelerator.is_main_process:
                    if not os.path.exists(checkpoint_output_dir):
                        os.makedirs(checkpoint_output_dir)
                accelerator.wait_for_everyone()
                unwrapped_model = accelerator.unwrap_model(model)
                unwrapped_model.save_pretrained(checkpoint_output_dir, save_function=accelerator.save)
                if accelerator.is_main_process:
                    tokenizer.save_pretrained(checkpoint_output_dir)
                    logger.info("Saving model checkpoint to %s", checkpoint_output_dir)
        if completed_steps >= args.max_steps:
            break

        if should_training_stop:
            break
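
    # Training is finished: keep a single `best-checkpoint` directory under
    # `output_dir` for downstream use.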
    if best_checkpoint is not None:
        # Save the best checkpoint
        logger.info("Best checkpoint: %s", best_checkpoint)
        logger.info("Best evaluation result: %s = %f", args.eval_metric, best_eval_result)
        best_checkpoint_output_dir = os.path.join(args.output_dir, best_checkpoint)
        if accelerator.is_main_process:
            # Rename the best checkpoint directory and clean up whatever may be
            # left of the original one.
            shutil.move(best_checkpoint_output_dir, os.path.join(args.output_dir, "best-checkpoint"))
            shutil.rmtree(best_checkpoint_output_dir, ignore_errors=True)
        accelerator.wait_for_everyone()
    else:
        # Assume that the last checkpoint is the best checkpoint and save it
        checkpoint_output_dir = os.path.join(args.output_dir, "best-checkpoint")
        if not os.path.exists(checkpoint_output_dir):
            os.makedirs(checkpoint_output_dir)

        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        unwrapped_model.save_pretrained(checkpoint_output_dir, save_function=accelerator.save)
        if accelerator.is_main_process:
            tokenizer.save_pretrained(checkpoint_output_dir)
            logger.info("Saving model checkpoint to %s", checkpoint_output_dir)

    return completed_steps, train_loss / completed_steps
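

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original file): how `train` is typically
# wired together with `accelerate`. It assumes the surrounding script has
# already built `args`, the tokenized dataloaders, the model/tokenizer, the
# optimizer, and the LR scheduler; the names below are placeholders.
# ---------------------------------------------------------------------------
#
#   from accelerate import Accelerator
#
#   accelerator = Accelerator()
#   model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
#       model, optimizer, train_dataloader, eval_dataloader
#   )
#   completed_steps, avg_train_loss = train(
#       args,
#       accelerator,
#       model,
#       tokenizer,
#       train_dataloader,
#       optimizer,
#       lr_scheduler,
#       eval_dataloader=eval_dataloader,
#   )
#   logger.info("Finished %d steps, average loss %f", completed_steps, avg_train_loss)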