def multitask_train()

in abstractive_summarization/src/trainer.py [0:0]


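# Assumed imports for this excerpt (a sketch; the repo's own helpers add_noise,
# pad_sents, get_mask, and make_file_name are defined elsewhere in
# abstractive_summarization/src):
import os
import logging

import torch
from transformers import BartTokenizer

logger = logging.getLogger(__name__)
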
def multitask_train(model_lm, model_cnn, cnn_train_data, cnn_valid_data, tgtdomain_data, optimizer_lm, optimizer_cnn, checkpoint, args):
    '''Run the multitask training loop: supervised summarization on CNN data plus denoising language modeling on target-domain text.'''
    tokenizer = BartTokenizer.from_pretrained('facebook/bart-base')
    logger.info('Start multitask training')
    iteration = 0
    if args.train_from != '':
        iteration = checkpoint['iteration']
    cnn_loss = 0
    lm_loss = 0
    F1 = 0  # currently unused; the validation step below is commented out
    while iteration < args.max_iter:
        iteration += 1
        model_lm.train()
        model_cnn.train()

        ## cnn news summarization training part
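        # one supervised batch for the summarization task: encoder source ids,
        # teacher-forced decoder inputs, attention mask, and target labels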
        src_ids, decoder_ids, mask, label_ids = next(cnn_train_data)
        src_ids = src_ids.cuda()
        decoder_ids = decoder_ids.cuda()
        mask = mask.cuda()
        label_ids = label_ids.cuda()

        loss = model_cnn(input_ids=src_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0]
        cnn_loss += loss.item()
        loss = loss / args.accumulation_steps
        # backward
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model_cnn.parameters(), args.clip)

        ## denoising language modeling part
        sents = next(tgtdomain_data)
        tokenized_sents = [tokenizer.encode(sent, add_special_tokens=False) for sent in sents]
        decoder_ids = [[tokenizer.bos_token_id] + item for item in tokenized_sents]
        label_ids = [item + [tokenizer.eos_token_id] for item in tokenized_sents]

        noisy_text = add_noise(sents, args.mask_prob)
        inputs_ids = [tokenizer.encode(sent, add_special_tokens=False) for sent in noisy_text]

        # prepare data for training
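        # pad each batch to a common length; labels are padded with -100 so the
        # padded positions are ignored by the cross-entropy loss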
        inputs_ids = torch.tensor(pad_sents(inputs_ids, pad_token=tokenizer.pad_token_id)[0]).cuda()
        mask = torch.tensor(get_mask(inputs_ids)).cuda()
        decoder_ids = torch.tensor(pad_sents(decoder_ids, pad_token=tokenizer.pad_token_id)[0]).cuda()
        label_ids = torch.tensor(pad_sents(label_ids, pad_token=-100)[0]).cuda()

        # optimize model
        loss = model_lm(input_ids=inputs_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0]
        lm_loss += loss.item()
        loss = loss / args.accumulation_steps
        loss.backward()
        # torch.nn.utils.clip_grad_norm_(model_lm.parameters(), args.clip)

        # gradient accumulation: step both optimizers every accumulation_steps iterations, then reset gradients
        if (iteration+1) % args.accumulation_steps == 0:
            optimizer_lm.step()
            optimizer_cnn.step()
            model_lm.zero_grad()
            model_cnn.zero_grad()
        # write to log file
        if iteration % 20 == 0:
            logger.info("iteration: {} loss_per_word: {:.6f} loss_lm: {:.6f} learning rate lm: {:.9f} learning rate cnn: {:.9f}".format(iteration, cnn_loss/20, lm_loss/20, optimizer_lm.learning_rate, optimizer_cnn.learning_rate))
            cnn_loss = 0
            lm_loss = 0

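        # periodically save both models (whole model objects rather than state_dicts)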
        if iteration % 50000 == 0:
            # eval_F1 = evaluation(model, cnn_valid_data, args)
            # logger.info("Iteration: {}. F1 score: {:.4f}".format(iteration, eval_F1))
            logger.info("saving model")
            os.makedirs(args.saving_path + args.data_name, exist_ok=True)
            model_name = make_file_name(args, iteration)
#             checkpoint_1 = {'iteration': iteration, 'settings': args, 'optim': optimizer_lm.optimizer.state_dict(), 'model_lm': model_lm.state_dict()}
#             checkpoint_2 = {'iteration': iteration, 'settings': args, 'optim': optimizer_cnn.optimizer.state_dict(), 'model': model_cnn.state_dict()}
            torch.save(model_lm, model_name[0])
            torch.save(model_cnn, model_name[1])