def train()

in abstractive_summarization/src/trainer.py [0:0]
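
This page shows only the train function; the module-level imports it depends on are not part of the excerpt. A minimal sketch of what the top of trainer.py would need, assuming the scheduler comes from Hugging Face transformers and that evaluation and make_file_name are project-local helpers defined elsewhere in the repository, might look like this:

import os
import logging

import torch
from transformers import get_linear_schedule_with_warmup

# project-local helpers assumed to be defined elsewhere in the repository,
# e.g. the validation routine and the checkpoint-name builder
# from .evaluate import evaluation
# from .utils import make_file_name

logger = logging.getLogger(__name__)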


def train(model, training_data, validation_data, optimizer, checkpoint, args, pretrained_model):
    ''' Run the fine-tuning loop with gradient accumulation, optional RecAdam scheduling, and periodic validation-based checkpointing. '''
    if args.recadam:
        # linear warmup/decay schedule used together with RecAdam; the scheduler must exist
        # whenever scheduler.step() is called below (total steps assume the hard-coded 10 epochs)
        t_total = len(training_data) // args.accumulation_steps * 10
        scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
    logger.info('Start training')
    iteration = 0
    if args.break_point_continue:
        # resume the iteration counter from a saved checkpoint
        iteration = checkpoint['iteration']
    total_loss = 0
    F1 = 0  # best validation F1 seen so far
    for epoch_i in range(args.epoch):
        logger.info('[ Epoch : {}]'.format(epoch_i))
        dist_sum, dist_num = 0.0, 0
        # training part
        model.train()
        for src_ids, decoder_ids, mask, label_ids in training_data:
            iteration += 1
            # move the batch to the GPU
            src_ids = src_ids.cuda()
            decoder_ids = decoder_ids.cuda()
            mask = mask.cuda()
            label_ids = label_ids.cuda()
            # forward
            # optimizer.optimizer.zero_grad()
            loss = model(input_ids=src_ids, attention_mask=mask, decoder_input_ids=decoder_ids, labels=label_ids)[0]
            total_loss += loss.item()
            loss = loss / args.accumulation_steps
            # backward
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
            # gradient accumulation: update weights every accumulation_steps batches
            if (iteration+1) % args.accumulation_steps == 0:
                optimizer.step()
                if args.recadam:
                    scheduler.step()
                model.zero_grad()

            if args.logging_Euclid_dist:
                # squared Euclidean (L2) distance between the fine-tuned and pretrained parameter vectors
                dist = torch.sum(torch.abs(torch.cat(
                    [p.view(-1) for n, p in model.named_parameters()]) - torch.cat(
                    [p.view(-1) for n, p in pretrained_model.named_parameters()])) ** 2).item()
                dist_sum += dist
                dist_num += 1
            # write to log file
            if iteration % 20 == 0:
                # report the mean loss over the last 20 iterations
                if args.logging_Euclid_dist:
                    logger.info("iteration: {} loss_per_word: {:.4f} Euclid dist: {:.6f}".format(iteration, total_loss/20, dist_sum / dist_num))
                else:
                    logger.info("iteration: {} loss_per_word: {:.4f} learning rate: {:.4f}".format(iteration, total_loss/20, optimizer.learning_rate))
                total_loss = 0
            # save model
            if iteration % args.save_interval == 0 and iteration > args.start_to_save_iter:
                temp_F1 = evaluation(model, validation_data, args)
                model.train()  # back to training mode after validation
                if temp_F1 > F1:
                    # keep only checkpoints that improve validation F1
                    logger.info("saving model")
                    if not os.path.exists(args.saving_path + args.data_name):
                        os.makedirs(args.saving_path + args.data_name)
                    model_name = make_file_name(args, iteration)
#                     checkpoint = {'iteration': iteration, 'settings': args, 'optim': optimizer.optimizer.state_dict(), 'model': model.state_dict()}
                    torch.save(model, model_name)
                    F1 = temp_F1
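
As a rough, hypothetical usage sketch (not taken from the repository): the argument names mirror the attributes the loop reads from args, the loaders are assumed to yield (src_ids, decoder_ids, mask, label_ids) batches, and logging_Euclid_dist is enabled because the other logging branch expects an optimizer wrapper exposing a learning_rate attribute. A GPU is assumed since the loop calls .cuda().

from argparse import Namespace

import torch
from transformers import BartForConditionalGeneration

# hypothetical hyper-parameters; every name matches an attribute read inside train()
args = Namespace(
    epoch=10, accumulation_steps=4, clip=1.0, warmup_steps=500,
    recadam=False, logging_Euclid_dist=True, break_point_continue=False,
    save_interval=1000, start_to_save_iter=5000,
    saving_path='checkpoints/', data_name='xsum',
)

model = BartForConditionalGeneration.from_pretrained('facebook/bart-base').cuda()
pretrained_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base').cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

# tiny dummy batches so the sketch runs end to end; real loaders come from the data pipeline
src = torch.randint(0, 50264, (2, 16))
batch = (src, src.clone(), torch.ones_like(src), src.clone())
training_data = [batch] * 8
validation_data = [batch]

train(model, training_data, validation_data, optimizer,
      checkpoint=None, args=args, pretrained_model=pretrained_model)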