in abstractive_summarization/src/sdpt_pretraining.py [0:0]
# Module-level context this excerpt relies on (defined elsewhere in the file):
import logging

import torch
from transformers import get_linear_schedule_with_warmup

logger = logging.getLogger(__name__)
def train(model, training_data, optimizer, checkpoint, args, pretrained_model):
    ''' Start training '''
    if args.recadam:
        # the LR scheduler is only ever stepped under args.recadam below, so it
        # is created under the same flag (guarding it with logging_Euclid_dist
        # would leave `scheduler` undefined when RecAdam is enabled)
        t_total = len(training_data) // args.accumulation_steps * 10  # 10: hard-coded schedule horizon in epochs
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)
    logger.info('Start training')
    iteration = 0
    if args.break_point_continue:
        # resume only the iteration counter; model/optimizer state is assumed
        # to be restored by the caller
        iteration = checkpoint['iteration']
    total_loss = 0
    for epoch_i in range(args.epoch):
        logger.info('[ Epoch : {}]'.format(epoch_i))
        dist_sum, dist_num = 0.0, 0
        # training part
        model.train()
        for src_ids, decoder_ids, mask, label_ids in training_data:
            iteration += 1
            src_ids = src_ids.cuda()
            decoder_ids = decoder_ids.cuda()
            mask = mask.cuda()
            label_ids = label_ids.cuda()
            # forward
            # optimizer.optimizer.zero_grad()
            loss = model(input_ids=src_ids, attention_mask=mask,
                         decoder_input_ids=decoder_ids, labels=label_ids)[0]
            total_loss += loss.item()
            # scale so the accumulated gradient equals the mean over one virtual batch
            loss = loss / args.accumulation_steps
            # backward
            loss.backward()
            # note: this clips the partially accumulated gradient on every
            # micro-batch, not once per optimizer step
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)
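            # With gradient accumulation the effective batch size is
            # accumulation_steps * loader batch size: the .grad buffers keep
            # summing across micro-batches until model.zero_grad() below.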
            # gradient accumulation: update the weights every accumulation_steps batches
            if iteration % args.accumulation_steps == 0:  # iteration starts at 1, so this fires exactly every accumulation_steps batches
                optimizer.step()
                if args.recadam:
                    scheduler.step()
                model.zero_grad()
            if args.logging_Euclid_dist:
                # squared L2 (Euclidean) distance between the current weights
                # and the frozen pretrained reference model
                dist = torch.sum((torch.cat([p.view(-1) for n, p in model.named_parameters()])
                                  - torch.cat([p.view(-1) for n, p in pretrained_model.named_parameters()])) ** 2).item()
                dist_sum += dist
                dist_num += 1
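            # Rationale for this metric: RecAdam regularizes the model toward
            # its pretrained weights, so the (squared) distance to
            # pretrained_model tracks how far training has drifted from the
            # initialization; small values mean the model stays close to the
            # pretrained parameters.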
            # write to log file (total_loss / 20 = mean loss over the last 20 iterations)
            if iteration % 20 == 0:
                if args.logging_Euclid_dist:
                    logger.info("iteration: {} loss_per_word: {:.4f} Euclid dist: {:.6f}".format(
                        iteration, total_loss / 20, dist_sum / dist_num))
                else:
                    # standard torch optimizers expose the current LR via param_groups
                    logger.info("iteration: {} loss_per_word: {:.4f} learning rate: {:.6f}".format(
                        iteration, total_loss / 20, optimizer.param_groups[0]['lr']))
                total_loss = 0
            # save model
            if iteration % args.save_interval == 0 and iteration > args.start_to_save_iter:
                print('=====Saving checkpoint=====')
                model_name = args.saving_path + "/{}_{}.chkpt".format(args.data_name, iteration)
                # serializes the whole model object; note that break_point_continue
                # above expects a dict with an 'iteration' key, which would have to
                # be written by a separate checkpointing path
                torch.save(model, model_name)
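

# ---------------------------------------------------------------------------
# Usage sketch: a minimal, hypothetical driver for train(). The hyperparameter
# values, toy dataset, and bart-base checkpoint below are illustrative
# assumptions, not this repository's actual entry point or defaults; a CUDA
# device is required because train() moves every batch to the GPU.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    from argparse import Namespace

    from torch.utils.data import DataLoader, TensorDataset
    from transformers import BartForConditionalGeneration

    args = Namespace(
        epoch=1, accumulation_steps=1, clip=1.0, warmup_steps=100,
        recadam=False, logging_Euclid_dist=False, break_point_continue=False,
        save_interval=1000, start_to_save_iter=500,
        saving_path='./checkpoints', data_name='toy')

    model = BartForConditionalGeneration.from_pretrained('facebook/bart-base').cuda()
    # frozen reference copy, only consulted when logging_Euclid_dist is set
    pretrained_model = BartForConditionalGeneration.from_pretrained('facebook/bart-base').cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-5)

    # one toy batch of (src_ids, decoder_ids, mask, label_ids)
    B, L = 2, 16
    ids = torch.randint(0, 50000, (B, L))
    loader = DataLoader(
        TensorDataset(ids, ids, torch.ones(B, L, dtype=torch.long), ids),
        batch_size=B)

    train(model, loader, optimizer, checkpoint=None, args=args,
          pretrained_model=pretrained_model)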