in abstractive_summarization/src/tapt_pretraining.py [0:0]
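# Assumes at module level: `import torch`, `import numpy as np`,
# `from tqdm import tqdm`, and the repo's `add_noise`, `get_mask`,
# and `pad_sents` helpers (sketched at the end of this excerpt).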
def train(self):
    print('Start finetuning BART language model')
    iteration = 0
    for epoch_i in range(self.epoch):
        self.model.train()
        if self.pretrained_model is not None:
            self.pretrained_model.eval()
        print('[ Epoch : {}]'.format(epoch_i))
        loss_list = []
        dist_sum, dist_num = 0.0, 0
        pbar = tqdm(self.dataloader, total=len(self.dataloader))
        for sents in pbar:
            sents = [self.shorten_sent(sent) for sent in sents]
            iteration += 1
            tokenized_sents = self.tokenize(sents)
            decoder_ids = [[self.tokenizer.bos_token_id] + item for item in tokenized_sents]
            label_ids = [item + [self.tokenizer.eos_token_id] for item in tokenized_sents]
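            # Worked example of the shifted targets (teacher forcing):
            #   tokenized:   [8774, 232]
            #   decoder_ids: [<s>, 8774, 232]   (input fed to the decoder)
            #   label_ids:   [8774, 232, </s>]  (what the decoder must emit)
            # i.e. the labels are the decoder inputs shifted left by one.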
# print("before:")
# print(sents[0])
# print("tokenized sents:")
# print(tokenized_sents[0])
# sents: a list of sentence, each item inside is a string
noisy_text = add_noise(sents, self.mask_probability)
# noisy_text: a list of sentence, each item inside is a string
# print("after:")
# print(noisy_text[0])
inputs_ids = self.tokenize(noisy_text)
# print("tokenized noisy text:")
# print(inputs_ids[0])
            # prepare data for training
            mask = torch.tensor(get_mask(inputs_ids, max_len=512)).cuda()
            inputs_ids = torch.tensor(pad_sents(inputs_ids, pad_token=self.tokenizer.pad_token_id, max_len=512)[0]).cuda()
            decoder_ids = torch.tensor(pad_sents(decoder_ids, pad_token=self.tokenizer.pad_token_id, max_len=512)[0]).cuda()
            label_ids = torch.tensor(pad_sents(label_ids, pad_token=-100, max_len=512)[0]).cuda()
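            # Note: -100 is the default ignore_index of the cross-entropy
            # loss in HuggingFace models, so padded label positions
            # contribute nothing to the loss.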
            # optimize model
            loss = self.model(input_ids=inputs_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0]
            loss_list.append(loss.item())
            loss = loss / self.accumulation_steps
            loss.backward()
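            # (the division above makes gradients summed over
            # `accumulation_steps` micro-batches equal to one averaged
            # large-batch gradient)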
            if self.args.logging_Euclid_dist:
                dist = torch.sum(torch.abs(torch.cat(
                    [p.view(-1) for n, p in self.model.named_parameters()]) - torch.cat(
                    [p.view(-1) for n, p in self.pretrained_model.named_parameters()])) ** 2).item()
                dist_sum += dist
                dist_num += 1
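            # `dist` is the squared L2 (Euclidean) distance between the
            # current and the frozen pretrained parameters. Presumably this
            # monitors how far fine-tuning drifts from the pretrained point,
            # the quantity that RecAdam's quadratic recall penalty controls
            # (Chen et al., "Recall and Learn", 2020).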
            if iteration % self.accumulation_steps == 0:
                if self.args.recadam:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                self.optimizer.step()
                if self.args.recadam:
                    self.scheduler.step()
                self.model.zero_grad()
                # collapse the accumulated losses to their running mean so the
                # progress bar shows a smoothed value rather than per-batch noise
                loss_list = [np.mean(loss_list)]
            if self.args.logging_Euclid_dist:
                pbar.set_description("(Epoch {}) LOSS: {:.6f} Euclid dist: {:.6f}".format(
                    epoch_i, np.mean(loss_list), dist_sum / dist_num))
            else:
                # read the current learning rate from the optimizer's param groups
                pbar.set_description("(Epoch {}) LOSS: {:.6f} LearningRate: {:.10f}".format(
                    epoch_i, np.mean(loss_list), self.optimizer.param_groups[0]['lr']))
            if iteration % self.args.save_interval == 0:
                self.save_model(iteration)
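
# ----------------------------------------------------------------------
# For reference, minimal sketches of the helpers used above. `add_noise`,
# `pad_sents`, and `get_mask` are imported from elsewhere in this repo;
# the versions below are assumptions that illustrate the expected
# contracts (word-level masking, right-padding, 0/1 attention masks),
# not the project's actual implementations.
# ----------------------------------------------------------------------
import random

def add_noise(sents, mask_probability, mask_token='<mask>'):
    """Corrupt each sentence for the denoising objective (sketch).

    Real BART-style text infilling masks whole spans with lengths drawn
    from a Poisson distribution; here, each word is independently
    replaced by `mask_token` with probability `mask_probability`.
    """
    noisy = []
    for sent in sents:
        words = [mask_token if random.random() < mask_probability else w
                 for w in sent.split()]
        noisy.append(' '.join(words))
    return noisy

def pad_sents(sents, pad_token, max_len=512):
    """Truncate each id sequence to `max_len`, then right-pad to the
    batch maximum. Returns (padded_ids, original_lengths), matching the
    `pad_sents(...)[0]` indexing used in train()."""
    sents = [s[:max_len] for s in sents]
    lens = [len(s) for s in sents]
    batch_max = max(lens)
    padded = [s + [pad_token] * (batch_max - len(s)) for s in sents]
    return padded, lens

def get_mask(sents, max_len=512):
    """Attention mask aligned with pad_sents: 1 for real tokens,
    0 for padding."""
    lens = [min(len(s), max_len) for s in sents]
    batch_max = max(lens)
    return [[1] * l + [0] * (batch_max - l) for l in lens]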