in train_dense_encoder.py
def run_train(self):
    cfg = self.cfg
    train_iterator = self.get_data_iterator(
        cfg.train.batch_size,
        True,
        shuffle=True,
        shuffle_seed=cfg.seed,
        offset=self.start_batch,
        rank=cfg.local_rank,
    )

    max_iterations = train_iterator.get_max_iterations()
    logger.info(" Total iterations per epoch=%d", max_iterations)
    if max_iterations == 0:
        logger.warning("No data found for training.")
        return

    # Effective optimizer updates per epoch after gradient accumulation
    updates_per_epoch = train_iterator.max_iterations // cfg.train.gradient_accumulation_steps
    total_updates = updates_per_epoch * cfg.train.num_train_epochs
    logger.info(" Total updates=%d", total_updates)

    warmup_steps = cfg.train.warmup_steps
    if self.scheduler_state:
        # TODO: ideally we'd want to just call
        # scheduler.load_state_dict(self.scheduler_state)
        # but it doesn't work properly as of now
        logger.info("Loading scheduler state %s", self.scheduler_state)
        shift = int(self.scheduler_state["last_epoch"])
        logger.info("Steps shift %d", shift)
        scheduler = get_schedule_linear(
            self.optimizer,
            warmup_steps,
            total_updates,
            steps_shift=shift,
        )
    else:
        scheduler = get_schedule_linear(self.optimizer, warmup_steps, total_updates)

    # Evaluate eval_per_epoch times within each epoch
    eval_step = math.ceil(updates_per_epoch / cfg.train.eval_per_epoch)
    logger.info(" Eval step = %d", eval_step)
    logger.info("***** Training *****")

    for epoch in range(self.start_epoch, int(cfg.train.num_train_epochs)):
        logger.info("***** Epoch %d *****", epoch)
        self._train_epoch(scheduler, epoch, eval_step, train_iterator)

    if cfg.local_rank in [-1, 0]:
        logger.info("Training finished. Best validation checkpoint %s", self.best_cp_name)