def train()

in abstractive_summarization/src/dapt_pretraining.py [0:0]


    def train(self):
        print('Start fine-tuning the BART language model')
        iteration = 0
        for epoch_i in range(self.epoch):
            self.model.train()
            if self.pretrained_model is not None:
                self.pretrained_model.eval()
            print('[ Epoch : {}]'.format(epoch_i))
            loss_list = []
            dist_sum, dist_num = 0.0, 0
            pbar = tqdm(self.dataloader, total=len(self.dataloader))
            for sents in pbar:
                sents = [self.shorten_sent(sent) for sent in sents]
                iteration += 1
                tokenized_sents = self.tokenize(sents)
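                # teacher forcing on the clean text: the decoder reads [BOS] + tokens
                # and is trained to predict tokens + [EOS] (labels shifted by one position)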
                decoder_ids = [[self.tokenizer.bos_token_id] + item for item in tokenized_sents]
                label_ids = [item + [self.tokenizer.eos_token_id] for item in tokenized_sents]
                # print("before:")
                # print(sents[0])
                # print("tokenized sents:")
                # print(tokenized_sents[0])
                # sents: a list of sentence, each item inside is a string
                noisy_text = add_noise(sents, self.mask_probability)
                # noisy_text: a list of sentence, each item inside is a string
                # print("after:")
                # print(noisy_text[0])
                inputs_ids = self.tokenize(noisy_text)
                # print("tokenized noisy text:")
                # print(inputs_ids[0])

                # prepare data for training
                mask = torch.tensor(get_mask(inputs_ids, max_len=512)).cuda()
                inputs_ids = torch.tensor(pad_sents(inputs_ids, pad_token=self.tokenizer.pad_token_id, max_len=512)[0]).cuda()
                decoder_ids = torch.tensor(pad_sents(decoder_ids, pad_token=self.tokenizer.pad_token_id, max_len=512)[0]).cuda()
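                # labels are padded with -100, the ignore index of the cross-entropy
                # loss, so padding positions do not contribute to the loss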
                label_ids = torch.tensor(pad_sents(label_ids, pad_token=-100, max_len=512)[0]).cuda()
                # optimize the model: compute the LM loss, scale it for gradient accumulation, and backpropagate
                loss = self.model(input_ids=inputs_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0]
                loss_list.append(loss.item())
                loss = loss / self.accumulation_steps
                loss.backward()
                if self.args.logging_Euclid_dist:
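                    # squared Euclidean (L2) distance between the current weights and
                    # the frozen pretrained weights, logged to track parameter drift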
                    dist = torch.sum(torch.abs(torch.cat(
                        [p.view(-1) for n, p in self.model.named_parameters()]) - torch.cat(
                        [p.view(-1) for n, p in self.pretrained_model.named_parameters()])) ** 2).item()

                    dist_sum += dist
                    dist_num += 1

                if iteration % self.accumulation_steps == 0:
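                    # gradient accumulation: the optimizer only updates the weights
                    # once every `accumulation_steps` mini-batches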
                    if self.args.recadam:
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.optimizer.step()
                    if self.args.recadam:
                        self.scheduler.step()
                    self.model.zero_grad()
                    loss_list = [np.mean(loss_list)]

                if self.args.logging_Euclid_dist:
                    pbar.set_description("(Epoch {}) LOSS: {:.6f} Euclid dist: {:.6f}".format(epoch_i, np.mean(loss_list), dist_sum / dist_num))
                else:
                    pbar.set_description("(Epoch {}) LOSS: {:.6f} LearningRate: {:.10f}".format(epoch_i, np.mean(loss_list), self.optimizer.learning_rate))
                if iteration % args.save_interval == 0:
                    self.save_model(iteration)
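
The helpers called above (shorten_sent, tokenize, add_noise, get_mask, pad_sents) are defined elsewhere in the repository. As a reference point, below is a minimal sketch of what pad_sents and get_mask might look like, inferred only from their call sites in train() (pad_sents returning a (padded_ids, lengths) tuple, get_mask returning a 0/1 attention mask); the actual implementations may differ.

    # Illustrative sketches only -- assumed implementations, not the repository's code.
    def pad_sents(sents, pad_token, max_len=512):
        """Truncate each token-id list to max_len, then pad to the batch maximum.

        Returns (padded_sents, lengths); train() uses only the first element.
        """
        sents = [sent[:max_len] for sent in sents]
        lengths = [len(sent) for sent in sents]
        batch_max = max(lengths)
        padded = [sent + [pad_token] * (batch_max - len(sent)) for sent in sents]
        return padded, lengths

    def get_mask(sents, max_len=512):
        """Build a 0/1 attention mask matching the padding produced by pad_sents."""
        sents = [sent[:max_len] for sent in sents]
        batch_max = max(len(sent) for sent in sents)
        return [[1] * len(sent) + [0] * (batch_max - len(sent)) for sent in sents]

Since train() only consumes the padded ids, the [0] index after each pad_sents call discards the lengths under this assumed signature.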