# kilt/readers/t5/base_transformer.py
def generic_train(model: BaseTransformer, args: argparse.Namespace):
    # fix the random seeds for reproducibility
    set_seed(args)

    # refuse to overwrite an existing, non-empty output directory when training
    if (
        os.path.exists(args.output_dir)
        and os.listdir(args.output_dir)
        and args.do_train
    ):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir
            )
        )

    # keep the five best checkpoints, ranked by validation loss
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        filepath=args.output_dir,
        prefix="checkpoint",
        monitor="val_loss",
        mode="min",
        save_top_k=5,
    )

    train_params = dict(
        accumulate_grad_batches=args.gradient_accumulation_steps,
        gpus=args.n_gpu,
        max_epochs=args.num_train_epochs,
        early_stop_callback=False,
        gradient_clip_val=args.max_grad_norm,
        checkpoint_callback=checkpoint_callback,
        callbacks=[LoggingCallback()],
    )

    # enable automatic mixed precision at the requested optimization level
    if args.fp16:
        train_params["use_amp"] = args.fp16
        train_params["amp_level"] = args.fp16_opt_level

    if args.n_tpu_cores > 0:
        # lazily import torch_xla only when TPU training is requested
        global xm
        import torch_xla.core.xla_model as xm

        train_params["num_tpu_cores"] = args.n_tpu_cores
        train_params["gpus"] = 0

    if args.n_gpu > 1:
        train_params["distributed_backend"] = "ddp"

    trainer = pl.Trainer(**train_params)

    if args.do_train:
        trainer.fit(model)

    return trainer
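

# Illustrative sketch (an assumption, not part of the original module): one way a
# driver script could build the Namespace that generic_train reads. The flag names
# mirror the attributes accessed above (output_dir, do_train, n_gpu, fp16, ...);
# the defaults and the "MyT5Reader" subclass mentioned below are hypothetical.
if __name__ == "__main__":
    example_parser = argparse.ArgumentParser()
    example_parser.add_argument("--output_dir", type=str, default="t5_output")
    example_parser.add_argument("--do_train", action="store_true")
    example_parser.add_argument("--seed", type=int, default=42)
    example_parser.add_argument("--n_gpu", type=int, default=1)
    example_parser.add_argument("--n_tpu_cores", type=int, default=0)
    example_parser.add_argument("--fp16", action="store_true")
    example_parser.add_argument("--fp16_opt_level", type=str, default="O1")
    example_parser.add_argument("--max_grad_norm", type=float, default=1.0)
    example_parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
    example_parser.add_argument("--num_train_epochs", type=int, default=3)
    example_args = example_parser.parse_args()

    # generic_train only needs a BaseTransformer instance; constructing one is
    # model-specific and omitted here.
    # model = MyT5Reader(example_args)              # hypothetical subclass
    # trainer = generic_train(model, example_args)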