in src/engine/step4/model_dev/t5_training.py [0:0]
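# Module-level imports assumed by this excerpt (a sketch; the file's actual
# import block is not shown here). T5FineTuner and LoggingCallback are project
# classes expected to be defined in, or imported by, this module.
import os
import time

import torch
import pytorch_lightning as pl
from pytorch_lightning import loggers as pl_loggers

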
def train(args):
"""
Fine-tunes a pretrained T5 models based on the given parameters.
Args:
args(argparse.Namespace): Model parameters and other input arguments.
Returns:
None
"""
    # LightningModule wrapping the pretrained T5 model and the training logic.
    model = T5FineTuner(args)

    # Save a checkpoint at every validation pass (save_top_k=-1 keeps them all)
    # plus the most recent one, named with the epoch and validation loss.
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor="val_loss",
        filename="model_checkpoint-{epoch:02d}-{val_loss:.3f}",
        save_top_k=-1,
        save_last=True,
    )

    # Write training/validation metrics to TensorBoard under the output directory.
    tb_logger = pl_loggers.TensorBoardLogger(args.output_dir)
    # Trainer arguments follow the PyTorch Lightning 1.x API; several of them
    # (checkpoint_callback, amp_level, resume_from_checkpoint, accelerator="dp")
    # were renamed or removed in later releases.
    train_params = dict(
        accumulate_grad_batches=args.gradient_accumulation_steps,
        gpus=-1,  # use every available GPU
        max_epochs=args.num_train_epochs,
        precision=16 if args.fp_16 else 32,  # mixed precision when fp_16 is set
        amp_level=args.opt_level,
        resume_from_checkpoint=args.resume_from_checkpoint,
        gradient_clip_val=args.max_grad_norm,
        checkpoint_callback=checkpoint_callback,
        val_check_interval=args.val_check_interval,
        logger=tb_logger,
        callbacks=[LoggingCallback()],
        accelerator="dp",  # single-node DataParallel across the visible GPUs
    )
    trainer = pl.Trainer(**train_params)
    trainer.fit(model)
    # After training, save the model's state_dict into args.output_dir under a
    # timestamped file name.
    with open(
        os.path.join(
            args.output_dir, "model_{0}.pth".format(time.strftime("%Y%m%d-%H%M%S"))
        ),
        "wb",
    ) as f:
        torch.save(model.state_dict(), f)
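

# ---------------------------------------------------------------------------
# Minimal invocation sketch (illustrative only). It lists just the fields that
# train() itself reads; the project's real argument parser lives elsewhere, and
# T5FineTuner will typically require further fields (model name, data paths,
# learning rate, batch sizes) that are not shown here.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import argparse

    example_args = argparse.Namespace(
        output_dir="./t5_output",
        num_train_epochs=2,
        gradient_accumulation_steps=1,
        max_grad_norm=1.0,
        val_check_interval=1.0,
        fp_16=False,
        opt_level="O1",
        resume_from_checkpoint=None,
    )
    train(example_args)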