in modeling/main.py [0:0]
def train(args, tokenizer, model):
    """Run the full training loop: train with teacher forcing, periodically
    evaluate on dev (both teacher-forced loss and free generation), save the
    best checkpoint, and early-stop after ``args.no_improve_max`` evaluations
    without improvement.

    Args:
        args: parsed command-line namespace (hyperparameters, paths, flags).
        tokenizer: tokenizer passed through to the dataloaders / eval helper.
        model: the model to optimize; its forward is expected to return a
            loss dict first (see the unpacking below).
    """
    # set dataloader
    train_dataloader = set_dataloader(args, tokenizer, 'train', 'teacher_force', data_size=args.train_size)
    dev_dataloader = set_dataloader(args, tokenizer, 'dev', 'teacher_force')
    dev_gen_dataloader = set_dataloader(args, tokenizer, 'dev', 'generation')
    # set optimizer, lr scheduler
    optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=args.adam_epsilon)
    if args.use_scheduler:
        # Total optimizer steps = batches per epoch / accumulation steps, over all epochs.
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.max_epoch
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
        )
    else:
        scheduler = None
    # Bundle passed to run_one_epoch / save_best_model; presumably unpacked there — verify order against callee.
    trainer = (model, optimizer, scheduler, tokenizer)
    # Sanity-check evaluation before any weight update (epoch index -1 marks "pre-training").
    print('Test before training!')
    model.eval()
    with torch.no_grad():
        _ = run_one_epoch('dev', dev_dataloader, trainer, -1, 'teacher_force')
    print('Start training!\n{}'.format('***'*30))
    # Evaluate every `eval_interval` examples, expressed in batches.
    # NOTE(review): if eval_interval < train_batch_size this is 0 and the `%` below
    # raises ZeroDivisionError — confirm callers always pass eval_interval >= batch size.
    eval_step = args.eval_interval // args.train_batch_size
    # score of query rewrite, corerference resolution, and joint learning (average of two)
    best_score = {'best-QR': -10000, 'best-COREF': -10000, 'best-JOINT': -10000}
    global_step = 0           # batches seen across all epochs (counts every batch, not optimizer steps)
    no_improve_count = 0      # consecutive dev evaluations without a new best score
    for epoch in range(args.max_epoch):
        t0 = time.time()
        model.train()
        model.zero_grad()
        # Per-epoch running loss sums, keyed by loss component; `match` is unused here.
        LOSS, match = {'bi': 0, 'lm': 0, 'mention': 0, 'reference': 0, 'total': 0}, []
        iterator = enumerate(tqdm(train_dataloader, desc="Epoch {}".format(epoch), disable=args.disable_display))
        if args.disable_display:
            print('Training progress is not showing')
        for local_step, batch in iterator:
            # Forward pass: model returns a dict of loss components first; the
            # remaining six outputs are unused during training.
            loss, _, _, _, _, _, _ = model(input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], \
                token_type_ids=batch['token_type_ids'], labels=batch['label_ids'], step=None, \
                mention_labels=batch['mention_label_ids'], batch=batch, coref_links=batch['coref_label'])
            # Accumulate each loss component for epoch-level reporting.
            for k, v in loss.items():
                LOSS[k] += v.item()
            global_step += 1
            # update model
            # Skip backward when total loss is exactly 0 (nothing to propagate).
            if loss['total'].item() != 0:
                # Scale down so accumulated gradients average over the accumulation window.
                loss['total'] = loss['total'] / args.gradient_accumulation_steps
                loss['total'].backward()
            # accumulate gradients
            if global_step % args.gradient_accumulation_steps == 0:
                # Clip, step, then clear gradients once per accumulation window.
                # NOTE(review): `norm` is computed but never used — consider logging or dropping it.
                norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                optimizer.step()
                if args.use_scheduler:
                    scheduler.step()
                optimizer.zero_grad()
            # evaluate model
            # Mid-epoch evaluation every `eval_step` batches; skipped during the first epoch.
            if global_step % eval_step == 0 and epoch > 0:
                model.eval()
                with torch.no_grad():
                    # NOTE(review): this rebinds `loss` (the train-loss dict) to the dev loss;
                    # harmless since `loss` is reassigned next batch, but shadowing is fragile.
                    loss = run_one_epoch('dev', dev_dataloader, trainer, epoch, 'teacher_force') # get dev loss
                    res = run_one_epoch('dev', dev_gen_dataloader, trainer, epoch, 'generation') # get dev result
                model.train()
                # save model
                # save_best_model mutates `best_score` in place and reports whether a checkpoint
                # was written — TODO confirm against its definition.
                save_model = save_best_model(args.task, res, best_score)
                if save_model:
                    no_improve_count = 0
                else:
                    no_improve_count += 1
                # early stop
                if no_improve_count == args.no_improve_max:
                    print('Early stop!')
                    return
        # get train loss
        # Average each component over the number of batches in this epoch.
        for k, v in LOSS.items():
            LOSS[k] /= (local_step+1)
        print_loss(epoch, 'train', LOSS, t0)
        print('***'*30)
    print('Reach max epoch: {}!'.format(args.max_epoch))