in low_rank_comparisons/src/gpt2_ft.py [0:0]
def train_validate(model, optimizer, scheduler, train_loader, valid_loader, args, train_step=0, epoch=0):
    model.train()
    avg_lm_loss = AverageMeter()
    print('start to train the model................', epoch)

    log_start_time = time.time()
    best_val_ppl = None

    # re-shuffle the distributed sampler so each epoch sees a different ordering
    train_loader.sampler.set_epoch(epoch)
    for idx, data in enumerate(train_loader):
        data = {key: value for key, value in data.items()}

        _input = data['input'].to(args.device)
        _target = data['target'].to(args.device)
        _msk = data['mask'].to(args.device)

        _lm_logits, _lm_loss = model(_input, lm_labels=_target, lm_mask=_msk, label_smooth=args.label_smooth)
        _lm_loss = _lm_loss.mean()

        train_step += 1
        is_update = (train_step % args.grad_acc == 0)
        avg_lm_loss.update(_lm_loss.item())
        # the loss is divided by args.grad_acc so gradients average correctly over
        # the accumulated micro-batches; is_update signals when to apply the update
        optimizer_step(_lm_loss / args.grad_acc, optimizer, model, scheduler, args, is_update=is_update)
        # logging interval
        if train_step % args.log_interval == 0:
            elapsed = time.time() - log_start_time

            log_str = '| epoch {:3d} step {:>8d} | {:>6d} batches | lr {:.3g} ' \
                      '| ms/batch {:5.2f} | loss {:5.2f} | avg loss {:5.2f} | ppl {:5.2f}'.format(
                epoch, train_step, idx + 1, optimizer.param_groups[0]['lr'],
                elapsed * 1000 / args.log_interval, avg_lm_loss.val, avg_lm_loss.avg,
                math.exp(avg_lm_loss.avg))

            if args.rank == 0:
                print(log_str)
            log_start_time = time.time()
            avg_lm_loss.reset()
        # checkpoint saving interval
        if train_step % args.save_interval == 0:
            if args.rank == 0:
                model_path = os.path.join(args.work_dir, 'model.' + str(train_step) + '.pt')
                print('saving checkpoint', model_path)
                torch.save({'model_state_dict': model.state_dict()}, model_path)
            distributed_sync(args)
        # evaluation interval
        if train_step % args.eval_interval == 0:
            eval_start_time = time.time()
            valid_loss, valid_ppl = evaluate(model, valid_loader, args)

            if best_val_ppl is None or valid_ppl < best_val_ppl:
                best_val_ppl = valid_ppl

            log_str = '| Eval {:3d} at step {:>8d} | time: {:5.2f}s ' \
                      '| valid loss {:5.2f} | valid ppl {:5.2f} | best ppl {:5.2f} '.format(
                train_step // args.eval_interval, train_step,
                time.time() - eval_start_time, valid_loss, valid_ppl, best_val_ppl)

            if args.rank == 0:
                print('-' * 100)
                print(log_str)
                print('-' * 100)

            # switch back to training mode after evaluation
            model.train()
            distributed_sync(args)
        if train_step == args.max_step:
            break
    # save a checkpoint at the end of the epoch (or at max_step)
    if args.rank == 0:
        model_path = os.path.join(args.work_dir, 'model.' + str(train_step) + '.pt')
        print('saving checkpoint', model_path)
        torch.save({'model_state_dict': model.state_dict()}, model_path)
    distributed_sync(args)
    return train_step
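The helpers used above (AverageMeter, optimizer_step, evaluate, distributed_sync) are defined elsewhere in the repo and are not shown here. As a minimal sketch of the gradient-accumulation contract implied by the call site, where the loss arrives pre-divided by args.grad_acc and the optimizer only steps when is_update is true, an optimizer_step-style helper could look roughly like the code below. This is an assumption for illustration, not the repo's actual implementation, and the args.clip gradient-clipping threshold is hypothetical.

import torch

def optimizer_step_sketch(scaled_loss, optimizer, model, scheduler, args, is_update=True):
    # accumulate gradients from this micro-batch (loss already scaled by 1 / grad_acc)
    scaled_loss.backward()

    if is_update:
        # hypothetical clipping threshold; not taken from gpt2_ft.py
        if getattr(args, 'clip', 0.0) > 0.0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip)

        # apply the accumulated update, advance the LR schedule, and clear gradients
        optimizer.step()
        if scheduler is not None:
            scheduler.step()
        optimizer.zero_grad()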