in pretrain/PyTorch/train.py [0:0]
def train(index):
    """Run one training pass over the Wikipedia pretraining dataset.

    Builds the dataset and its loader, lays out a shuffled per-step schedule
    of dataset ids (each id repeated ``gradient_accumulation_steps`` times),
    then runs forward/backward for every step and applies an optimizer update
    at each gradient-accumulation boundary. Relies on module-level state
    (``model``, ``optimizer``, ``args``, ``job_config``, ``logger``, ...) and
    increments the global ``global_step`` once per optimizer update.

    Args:
        index: Forwarded unchanged to ``PreTrainingDataset`` and
            ``pretrain_validation``. NOTE(review): presumably the epoch or
            data-shard index -- confirm against the caller.

    Returns:
        The result of ``pretrain_validation(index)`` when ``max_seq_length``
        is 512, otherwise ``None``.
    """
    global global_step

    model.train()

    # Pretraining datasets
    wiki_pretrain_dataset = PreTrainingDataset(
        tokenizer=tokenizer,
        folder=args.train_path,
        logger=logger,
        max_seq_length=max_seq_length,
        index=index,
        data_type=PretrainDataType.WIKIPEDIA,
        max_predictions_per_seq=max_predictions_per_seq,
        masked_lm_prob=masked_lm_prob,
    )
    num_samples = len(wiki_pretrain_dataset)

    # Loaders are keyed by dataset id so the schedule below can pick one per
    # step; only the Wikipedia dataset (id 0) is registered here.
    dataloaders = {0: get_dataloader(wiki_pretrain_dataset)}

    num_batches_in_dataset = get_effective_batch(num_samples)
    logger.info('Wikpedia data file: Number of samples {}, number of batches required to process these samples: {}'.format(num_samples, num_batches_in_dataset))
    batches_per_dataset = [num_batches_in_dataset]

    logger.info("Training on Wikipedia dataset")

    # One entry per batch, holding the id of the dataset it is drawn from.
    dataset_batches = []
    for dataset_id, batch_count in enumerate(batches_per_dataset):
        dataset_batches.extend([dataset_id] * batch_count)
    logger.info('Number of batches to process *all* data samples in this epoch: {}'.format(len(dataset_batches)))
    # shuffle
    random.shuffle(dataset_batches)

    # We don't want the dataset to be in the form of alternate chunks if we
    # have more than one dataset type; instead we want contiguous chunks of
    # each data type, hence every entry is repeated
    # gradient_accumulation_steps times.
    dataset_picker = []
    for dataset_batch_type in dataset_batches:
        dataset_picker.extend([dataset_batch_type] * gradient_accumulation_steps)
    logger.info('Number of steps to process all batches in this epoch: {}'.format(len(dataset_picker)))

    # Counter of sequences in an "epoch"
    sequences_counter = 0
    global_step_loss = 0
    # Only the fp16 path computes the learning rate here; stays None otherwise
    # so the logging code below never reads an unbound name.
    lr_this_step = None
    # Keeps the final "Completed" log well-defined when the schedule is empty.
    step = 0

    for step, dataset_type in enumerate(dataset_picker):
        # Only the fetch is expected to raise StopIteration; keeping the try
        # this narrow avoids hiding the same exception from unrelated code.
        try:
            batch = next(dataloaders[dataset_type])
        except StopIteration:
            # Loader ran dry before the schedule did: skip this slot.
            continue

        # NOTE(review): `batch` is iterated as a collection of tensors below,
        # so len(batch) may count tensors per batch rather than sequences --
        # confirm against get_dataloader.
        sequences_counter += len(batch)

        if n_gpu == 1:
            batch = tuple(t.to(device) for t in batch)  # Move to GPU

        if step > 1 and step % 1000 == 0:
            logger.info("{} Number of sequences processed so far: {} (cumulative in {} steps)".format(datetime.utcnow(), sequences_counter, step))

        # Calculate forward pass
        loss = model.network(batch)

        if n_gpu > 1:
            # This is to average loss for multi-gpu. In the
            # DistributedDataParallel setting, we get a tuple of losses from
            # all processes.
            loss = loss.mean()

        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps

        at_accumulation_boundary = (step + 1) % gradient_accumulation_steps == 0

        # Enabling optimized reduction: when using the distributed module,
        # reduction only happens in backward if enable_need_reduction() was
        # called beforehand.
        if accumulate_gradients:
            if use_multigpu_with_single_device_per_process and at_accumulation_boundary:
                model.network.enable_need_reduction()
            else:
                model.network.disable_need_reduction()

        if fp16:
            optimizer.backward(loss)
        else:
            loss.backward()

        # Accumulate a detached value so the running total does not carry
        # autograd history from one step to the next.
        global_step_loss += loss.detach()

        if at_accumulation_boundary:
            if fp16:
                # Modify learning rate with the special warm up BERT uses.
                # If fp16 is False, BertAdam is used, which handles this
                # automatically.
                lr_this_step = job_config.get_learning_rate() * warmup_linear_decay_exp(
                    global_step,
                    job_config.get_decay_rate(),
                    job_config.get_decay_step(),
                    job_config.get_total_training_steps(),
                    job_config.get_warmup_proportion(),
                )
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step

                # Record the LR against global_step on tensorboard
                if check_write_log():
                    summary_writer.add_scalar('Train/lr', lr_this_step, global_step)

            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

            if check_write_log() and (global_step % args.log_steps == 0):
                # Builtin float() replaces the np.float alias, which newer
                # NumPy releases no longer provide.
                step_loss = float(global_step_loss)
                lr_known = lr_this_step is not None
                run.log("training_loss", step_loss)
                if lr_known:
                    run.log("lr_this_step", float(lr_this_step))
                run.log_row("loss over steps", global_step=global_step, loss=step_loss)
                if lr_known:
                    run.log_row("lr over steps", global_step=global_step, lr=float(lr_this_step))

            # Start a fresh loss total for the next optimizer update.
            global_step_loss = 0

    logger.info("Completed {} steps".format(step))
    logger.info("Completed processing {} sequences".format(sequences_counter))

    # Run Validation Loss
    if max_seq_length == 512:
        logger.info(f"TRAIN BATCH SIZE: {train_batch_size}")
        return pretrain_validation(index)
    return None