in train_extractive_reader.py
def run_train(self):
    cfg = self.cfg

    train_iterator = self.get_data_iterator(
        cfg.train_files,
        cfg.train.batch_size,
        True,  # is_train
        shuffle=True,
        shuffle_seed=cfg.seed,
        offset=self.start_batch,  # skip batches already consumed before a restart
    )

    # num_train_epochs = cfg.train.num_train_epochs - self.start_epoch
    logger.info("Total iterations per epoch=%d", train_iterator.max_iterations)

    # One optimizer update happens every gradient_accumulation_steps iterations.
    updates_per_epoch = train_iterator.max_iterations // cfg.train.gradient_accumulation_steps
    total_updates = updates_per_epoch * cfg.train.num_train_epochs
    logger.info(" Total updates=%d", total_updates)
    warmup_steps = cfg.train.warmup_steps

    if self.scheduler_state:
        # Resuming: shift the LR schedule forward by the number of steps
        # already taken, as recorded in the saved scheduler state.
        logger.info("Loading scheduler state %s", self.scheduler_state)
        shift = int(self.scheduler_state["last_epoch"])
        logger.info("Steps shift %d", shift)
        scheduler = get_schedule_linear(
            self.optimizer,
            warmup_steps,
            total_updates,
            steps_shift=shift,  # without this, shift is computed but never applied
        )
    else:
        scheduler = get_schedule_linear(self.optimizer, warmup_steps, total_updates)
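    # Example (hypothetical numbers): warmup_steps=1,000 over 50,000 total
    # updates ramps the LR from 0 to its base value during the first 1,000
    # updates, then decays it linearly (see the schedule sketch below).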
    eval_step = cfg.train.eval_step
    logger.info(" Eval step = %d", eval_step)
    logger.info("***** Training *****")

    # Resume position: full epochs completed plus batches already consumed.
    global_step = self.start_epoch * updates_per_epoch + self.start_batch

    for epoch in range(self.start_epoch, int(cfg.train.num_train_epochs)):
        logger.info("***** Epoch %d *****", epoch)
        global_step = self._train_epoch(scheduler, epoch, eval_step, train_iterator, global_step)

    # Only the main process (rank 0, or -1 for a single-process run) reports.
    if cfg.local_rank in [-1, 0]:
        logger.info("Training finished. Best validation checkpoint %s", self.best_cp_name)

    return
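
For context, a minimal sketch of the linear warmup-then-decay schedule that
get_schedule_linear is expected to build, modeled on DPR's helper over
torch.optim.lr_scheduler.LambdaLR; treat the exact signature and the
steps_shift handling as assumptions rather than a verbatim copy:

from torch.optim.lr_scheduler import LambdaLR

def get_schedule_linear(optimizer, warmup_steps, total_training_steps,
                        steps_shift=0, last_epoch=-1):
    # The LR multiplier rises linearly from 0 to 1 over warmup_steps, then
    # decays linearly toward 0 at total_training_steps. steps_shift offsets
    # the step counter so a resumed run continues mid-schedule.
    def lr_lambda(current_step):
        current_step += steps_shift
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return max(
            1e-7,
            float(total_training_steps - current_step)
            / float(max(1, total_training_steps - warmup_steps)),
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)

Returning a small positive floor (1e-7) instead of 0 keeps the learning rate
strictly positive if training runs slightly past total_training_steps.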