in train_retriever.py [0:0]
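# Note: this function relies on names defined elsewhere in the file (not shown in
# this excerpt): `torch`, `Path` from pathlib, `DataLoader`, `RandomSampler` and
# `DistributedSampler` from torch.utils.data, the project module `src.util`, a
# module-level `logger`, the `evaluate` function, and the checkpoint directory
# `dir_path`, all of which are assumed to be imported or defined at module scope.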
def train(model, optimizer, scheduler, global_step,
          train_dataset, dev_dataset, opt, collator, best_eval_loss):

    if opt.is_main:
        try:
            tb_logger = torch.utils.tensorboard.SummaryWriter(Path(opt.checkpoint_dir)/opt.name)
        except Exception:
            tb_logger = None
            logger.warning('Tensorboard is not available.')

    # Shard batches across workers when training is distributed, otherwise shuffle.
    train_sampler = DistributedSampler(train_dataset) if opt.is_distributed else RandomSampler(train_dataset)
    train_dataloader = DataLoader(
        train_dataset,
        sampler=train_sampler,
        batch_size=opt.per_gpu_batch_size,
        drop_last=True,
        num_workers=10,
        collate_fn=collator,
    )

    curr_loss = 0.0  # running training loss, reset after each logging interval
    epoch = 1
    model.train()
    while global_step < opt.total_steps:
        if opt.is_distributed:
            # Reshuffle the shards each epoch so every worker sees a new ordering.
            train_sampler.set_epoch(epoch)
        epoch += 1
        for batch in train_dataloader:
            global_step += 1
            (idx, question_ids, question_mask, passage_ids, passage_mask, gold_score) = batch
            # Only the loss (last output of the forward pass) is used during training.
            _, _, _, train_loss = model(
                question_ids=question_ids.cuda(),
                question_mask=question_mask.cuda(),
                passage_ids=passage_ids.cuda(),
                passage_mask=passage_mask.cuda(),
                gold_score=gold_score.cuda(),
            )

            train_loss.backward()

            # Gradient accumulation: clip and apply the update only every
            # `accumulation_steps` batches.
            if global_step % opt.accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), opt.clip)
                optimizer.step()
                scheduler.step()
                model.zero_grad()

            # Average the loss across distributed workers before logging it.
            train_loss = src.util.average_main(train_loss, opt)
            curr_loss += train_loss.item()

            if global_step % opt.eval_freq == 0:
                eval_loss, inversions, avg_topk, idx_topk = evaluate(model, dev_dataset, collator, opt)
                if eval_loss < best_eval_loss:
                    best_eval_loss = eval_loss
                    if opt.is_main:
                        # `dir_path` is expected to be defined at module level.
                        src.util.save(model, optimizer, scheduler, global_step, best_eval_loss, opt, dir_path, 'best_dev')
                model.train()
                if opt.is_main:
                    log = f"{global_step} / {opt.total_steps}"
                    log += f" -- train: {curr_loss/opt.eval_freq:.6f}"
                    log += f", eval: {eval_loss:.6f}"
                    log += f", inv: {inversions:.1f}"
                    log += f", lr: {scheduler.get_last_lr()[0]:.6f}"
                    for k in avg_topk:
                        log += f" | avg top{k}: {100*avg_topk[k]:.1f}"
                    for k in idx_topk:
                        log += f" | idx top{k}: {idx_topk[k]:.1f}"
                    logger.info(log)
                    if tb_logger is not None:
                        tb_logger.add_scalar("Evaluation", eval_loss, global_step)
                        tb_logger.add_scalar("Training", curr_loss / opt.eval_freq, global_step)
                    curr_loss = 0.0

            if opt.is_main and global_step % opt.save_freq == 0:
                src.util.save(model, optimizer, scheduler, global_step, best_eval_loss, opt, dir_path, f"step-{global_step}")
            # Stop as soon as the step budget is reached (the outer while is only
            # re-checked at epoch boundaries).
            if global_step >= opt.total_steps:
                break
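
# A minimal sketch of how train() is typically driven from the script's main
# block (assumed, not part of this excerpt); `Options` and `load_retriever` are
# hypothetical placeholders for whatever the rest of the file provides:
#
#     if __name__ == "__main__":
#         opt = Options().parse()  # hypothetical option parser exposing the fields used above
#         dir_path = Path(opt.checkpoint_dir) / opt.name
#         dir_path.mkdir(parents=True, exist_ok=True)
#         model, optimizer, scheduler, global_step, best_eval_loss = load_retriever(opt)  # hypothetical helper
#         train(model, optimizer, scheduler, global_step,
#               train_dataset, dev_dataset, opt, collator, best_eval_loss)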