def train()

in pretrain/PyTorch/train.py [0:0]


def train(index):
    """Run one training pass over the Wikipedia pretraining data.

    Builds a ``PreTrainingDataset`` from ``args.train_path``, derives a
    per-step schedule from the number of batches reported by
    ``get_effective_batch`` (each batch contributes
    ``gradient_accumulation_steps`` consecutive steps), then runs the
    forward/backward pass for every step and applies the optimizer once
    every ``gradient_accumulation_steps`` steps.

    The function relies on module-level state (``model``, ``optimizer``,
    ``args``, ``logger``, ``run``, ...) and increments the module-level
    ``global_step`` once per optimizer step.

    Args:
        index: Forwarded unchanged to ``PreTrainingDataset`` and to
            ``pretrain_validation``.

    Returns:
        The result of ``pretrain_validation(index)`` when ``max_seq_length``
        is 512, otherwise ``None``.
    """
    global global_step

    model.train()
    dataloaders = {}
    batchs_per_dataset = []

    # Pretraining datasets
    wiki_pretrain_dataset = PreTrainingDataset(tokenizer=tokenizer,
                                               folder=args.train_path,
                                               logger=logger, max_seq_length=max_seq_length,
                                               index=index, data_type=PretrainDataType.WIKIPEDIA,
                                               max_predictions_per_seq=max_predictions_per_seq,
                                               masked_lm_prob=masked_lm_prob)
    wiki_sample_count = len(wiki_pretrain_dataset)

    # Datasets are keyed by their insertion order (0, 1, ...); the schedule
    # built below refers to them by that key.
    dataloaders[len(dataloaders)] = get_dataloader(wiki_pretrain_dataset)

    num_batches_in_dataset = get_effective_batch(wiki_sample_count)
    logger.info('Wikpedia data file: Number of samples {}, number of batches required to process these samples: {}'.format(wiki_sample_count, num_batches_in_dataset))
    batchs_per_dataset.append(num_batches_in_dataset)

    logger.info("Training on Wikipedia dataset")

    # One entry per required batch, tagged with the key of its dataset.
    dataset_batches = [
        dataset_id
        for dataset_id, batch_count in enumerate(batchs_per_dataset)
        for _ in range(batch_count)
    ]
    logger.info('Number of batches to process *all* data samples in this epoch: {}'.format(len(dataset_batches)))
    # shuffle
    random.shuffle(dataset_batches)

    # We don't want the dataset to be in the form of alternating chunks if we
    # have more than one dataset type; instead we want contiguous chunks of
    # each data type, hence every entry is repeated
    # gradient_accumulation_steps times.
    dataset_picker = [
        dataset_batch_type
        for dataset_batch_type in dataset_batches
        for _ in range(gradient_accumulation_steps)
    ]

    logger.info('Number of steps to process all batches in this epoch: {}'.format(len(dataset_picker)))

    # Counter of sequences in an "epoch"
    sequences_counter = 0
    global_step_loss = 0
    # Only computed on the fp16 path; stays None otherwise so the logging
    # code below can tell whether there is a value to report.
    lr_this_step = None

    for step, dataset_type in enumerate(dataset_picker):
        # Guard only the fetch: an exhausted loader skips this step, while a
        # StopIteration raised anywhere else is a real error and must surface.
        try:
            batch = next(dataloaders[dataset_type])
        except StopIteration:
            continue

        # NOTE(review): `batch` is iterated as a collection of tensors just
        # below, so len(batch) may count tensors rather than sequences --
        # confirm against what the dataloader yields.
        sequences_counter += len(batch)

        if n_gpu == 1:
            batch = tuple(t.to(device) for t in batch)  # Move to GPU

        if step > 1 and step % 1000 == 0:
            logger.info("{} Number of sequences processed so far: {} (cumulative in {} steps)".format(datetime.utcnow(), sequences_counter, step))

        # Calculate forward pass
        loss = model.network(batch)

        if n_gpu > 1:
            # this is to average loss for multi-gpu. In DistributedDataParallel
            # setting, we get tuple of losses from all processes
            loss = loss.mean()

        if gradient_accumulation_steps > 1:
            loss = loss / gradient_accumulation_steps

        is_update_step = (step + 1) % gradient_accumulation_steps == 0

        # Enabling optimized reduction:
        # reduction only happens in backward if this method is called before
        # when using the distributed module
        if accumulate_gradients:
            if use_multigpu_with_single_device_per_process and is_update_step:
                model.network.enable_need_reduction()
            else:
                model.network.disable_need_reduction()

        if fp16:
            optimizer.backward(loss)
        else:
            loss.backward()

        # Accumulate a detached value so the running total does not keep
        # autograd history alive across steps.
        global_step_loss += loss.detach()

        if not is_update_step:
            continue

        if fp16:
            # modify learning rate with special warm up BERT uses
            # if fp16 is False, BertAdam is used that handles this automatically
            lr_this_step = \
                job_config.get_learning_rate() * warmup_linear_decay_exp(global_step,
                                                                         job_config.get_decay_rate(),
                                                                         job_config.get_decay_step(),
                                                                         job_config.get_total_training_steps(),
                                                                         job_config.get_warmup_proportion())
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr_this_step

            # Record the LR against global_step on tensorboard
            if check_write_log():
                summary_writer.add_scalar('Train/lr', lr_this_step, global_step)

        optimizer.step()
        optimizer.zero_grad()
        global_step += 1

        if check_write_log() and (global_step % args.log_steps == 0):
            # Builtin float() replaces the removed np.float alias.
            step_loss = float(global_step_loss)
            lr_known = lr_this_step is not None
            run.log("training_loss", step_loss)
            if lr_known:
                run.log("lr_this_step", float(lr_this_step))
            run.log_row("loss over steps", global_step=global_step, loss=step_loss)
            if lr_known:
                run.log_row("lr over steps", global_step=global_step, lr=float(lr_this_step))
        global_step_loss = 0

    # Report the number of scheduled steps; this is also well defined when the
    # schedule is empty and the loop body never ran.
    logger.info("Completed {} steps".format(len(dataset_picker)))
    logger.info("Completed processing {} sequences".format(sequences_counter))

    # Run Validation Loss
    if max_seq_length == 512:
        logger.info(f"TRAIN BATCH SIZE: {train_batch_size}")
        return pretrain_validation(index)
    return None