in abstractive_summarization/src/trainer.py [0:0]
def multitask_train(model_lm, model_cnn, cnn_train_data, cnn_valid_data, tgtdomain_data, optimizer_lm, optimizer_cnn, checkpoint, args):
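    """Jointly train model_cnn on CNN/DailyMail summarization batches and model_lm on
    denoising language modeling over raw target-domain sentences; both optimizers are
    stepped every args.accumulation_steps iterations (gradient accumulation)."""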
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
    logger.info('Start multitask training')
    iteration = 0
    if args.train_from != '':
        iteration = checkpoint['iteration']
    cnn_loss = 0
    lm_loss = 0
    F1 = 0
    while iteration < args.max_iter:
        iteration += 1
        model_lm.train()
        model_cnn.train()

        ## cnn news summarization training part
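        # cnn_train_data yields already-batched tensors (source ids, decoder inputs,
        # attention mask, labels); the target-domain iterator below yields raw sentences
        # that are tokenized and noised on the fly.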
        src_ids, decoder_ids, mask, label_ids = next(cnn_train_data)
        src_ids = src_ids.cuda()
        decoder_ids = decoder_ids.cuda()
        mask = mask.cuda()
        label_ids = label_ids.cuda()
        loss = model_cnn(input_ids=src_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0]
        cnn_loss += loss.item()
        loss = loss / args.accumulation_steps
        # backward
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model_cnn.parameters(), args.clip)

        ## denoising language modeling part
        sents = next(tgtdomain_data)
        tokenized_sents = [tokenizer.encode(sent, add_special_tokens=False) for sent in sents]
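        # Teacher forcing: decoder inputs are the clean tokens prefixed with <s> (BOS),
        # labels are the same tokens suffixed with </s> (EOS), i.e. shifted by one position.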
        decoder_ids = [[tokenizer.bos_token_id] + item for item in tokenized_sents]
        label_ids = [item + [tokenizer.eos_token_id] for item in tokenized_sents]
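        # The encoder sees a corrupted version of each sentence; add_noise presumably
        # masks/perturbs tokens with probability args.mask_prob (denoising objective).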
        noisy_text = add_noise(sents, args.mask_prob)
        inputs_ids = [tokenizer.encode(sent, add_special_tokens=False) for sent in noisy_text]
        # pad to rectangular batches; labels use -100 so the loss ignores padding positions
        inputs_ids = torch.tensor(pad_sents(inputs_ids, pad_token=tokenizer.pad_token_id)[0]).cuda()
        mask = torch.tensor(get_mask(inputs_ids)).cuda()
        decoder_ids = torch.tensor(pad_sents(decoder_ids, pad_token=tokenizer.pad_token_id)[0]).cuda()
        label_ids = torch.tensor(pad_sents(label_ids, pad_token=-100)[0]).cuda()
        # forward/backward pass for the denoising LM
        loss = model_lm(input_ids=inputs_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0]
        lm_loss += loss.item()
        loss = loss / args.accumulation_steps
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model_lm.parameters(), args.clip)

        # gradient accumulation: step both optimizers every accumulation_steps iterations
        if iteration % args.accumulation_steps == 0:
            optimizer_lm.step()
            optimizer_cnn.step()
            model_lm.zero_grad()
            model_cnn.zero_grad()
        # write averaged losses and learning rates to the log
        if iteration % 20 == 0:
            logger.info("iteration: {} loss_cnn: {:.6f} loss_lm: {:.6f} learning rate lm: {:.9f} learning rate cnn: {:.9f}".format(iteration, cnn_loss/20, lm_loss/20, optimizer_lm.learning_rate, optimizer_cnn.learning_rate))
            cnn_loss = 0
            lm_loss = 0
        if iteration % 50000 == 0:
            # eval_F1 = evaluation(model_cnn, cnn_valid_data, args)
            # logger.info("Iteration: {}. F1 score: {:.4f}".format(iteration, eval_F1))
            logger.info("saving model")
            if not os.path.exists(args.saving_path + args.data_name):
                os.makedirs(args.saving_path + args.data_name)
            model_name = make_file_name(args, iteration)
            # checkpoint_1 = {'iteration': iteration, 'settings': args, 'optim': optimizer_lm.optimizer.state_dict(), 'model_lm': model_lm.state_dict()}
            # checkpoint_2 = {'iteration': iteration, 'settings': args, 'optim': optimizer_cnn.optimizer.state_dict(), 'model': model_cnn.state_dict()}
            torch.save(model_lm, model_name[0])
            torch.save(model_cnn, model_name[1])
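

# ---------------------------------------------------------------------------
# Illustrative sketches only. The loop above imports add_noise, pad_sents,
# get_mask, and make_file_name from elsewhere in the repo; the minimal
# versions below are assumptions about their behavior, included so the
# training step can be read (or smoke-tested) in isolation. They are not the
# project's actual implementations.
# ---------------------------------------------------------------------------
import os
import random


def add_noise(sents, mask_prob):
    # assumed: replace each word with BART's <mask> token with probability mask_prob
    noisy = []
    for sent in sents:
        words = ['<mask>' if random.random() < mask_prob else w for w in sent.split()]
        noisy.append(' '.join(words))
    return noisy


def pad_sents(sents, pad_token):
    # assumed: right-pad token-id lists to the batch maximum; return (padded, lengths)
    lengths = [len(s) for s in sents]
    max_len = max(lengths)
    return [s + [pad_token] * (max_len - len(s)) for s in sents], lengths


def get_mask(input_ids, pad_token_id=1):
    # assumed: attention mask with 1 for real tokens, 0 for padding (1 is BART's pad id)
    return (input_ids != pad_token_id).long().tolist()


def make_file_name(args, iteration):
    # assumed: one checkpoint path per model under saving_path + data_name
    base = os.path.join(args.saving_path + args.data_name, 'iter_{}'.format(iteration))
    return base + '_lm.pt', base + '_cnn.pt'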