def run_train()

in src/bert_train.py [0:0]


    def run_train(self, train_iter, validation_iter, model_network, loss_function, optimizer, pos_label):
        """
    Runs train...
        :param pos_label:
        :param validation_iter: Validation set
        :param train_iter: Train Data
        :param model_network: A neural network
        :param loss_function: Pytorch loss function
        :param optimizer: Optimiser
        """
        best_results = None
        start = datetime.datetime.now()
        iterations = 0
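        # Columns of the epoch-summary log line: elapsed seconds, epoch, iterations,
        # batch progress/total batches, percent of epoch, train loss, validation loss,
        # train accuracy, validation accuracy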
        val_log_template = ' '.join(
            '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(','))
        val_log_template = "Run {}".format(val_log_template)

        best_score = None

        no_improvement_epochs = 0

        if self._is_multigpu:
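            # nn.DataParallel splits each input batch across device_ids and
            # gathers the outputs on output_device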
            model_network = nn.DataParallel(model_network, device_ids=self.device, output_device=self._default_device)
            self._logger.info("Using multi gpu with devices {}, default {} ".format(self.device, self._default_device))

        model_network.to(device=self._default_device)

        for epoch in range(self.epochs):
            losses_train = []
            actual_train = []
            predicted_train = []

            self._logger.debug("Running epoch {}".format(self.epochs))
            
            model_network.zero_grad()
            for idx, batch in enumerate(train_iter):
                self._logger.debug("Running batch {}".format(idx))
                batch_x = batch[0].to(device=self._default_device)
                batch_y = batch[1].to(device=self._default_device)

                self._logger.debug("batch x shape is {}".format(batch_x.shape))

                iterations += 1

                # Step 1. Put the network in training mode (enables dropout etc.)
                model_network.train()
               
                # Step 2. Run the forward pass
                self._logger.debug("Running forward")
                predicted = model_network(batch_x)[0]

                # Step 3. Compute loss
                self._logger.debug("Running loss")
                loss = loss_function(predicted, batch_y) / self.accumulation_steps
                loss.backward()

                losses_train.append(loss.item())
                actual_train.extend(batch_y.cpu().tolist())
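                # torch.max(predicted, 1)[1] gives the index of the highest logit,
                # i.e. the predicted class for each example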
                predicted_train.extend(torch.max(predicted, 1)[1].view(-1).cpu().tolist())

                # Step 4. Only update weights after gradients are accumulated for n steps
                if (idx + 1) % self.accumulation_steps == 0:
                    self._logger.debug("Running optimiser")
                    optimizer.step()
                    model_network.zero_grad()

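            # Note: if the number of batches is not a multiple of accumulation_steps,
            # the gradients of the trailing partial window are never applied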
            # Compute and log training set results
            train_loss = sum(losses_train) / len(losses_train)
            train_score = accuracy_score(actual_train, predicted_train)
            self._logger.info("Train set accuracy: {}".format(train_score))

            # Compute and log validation set results
            val_actuals, val_predicted, val_loss = self.validate(loss_function, model_network, validation_iter)
            val_score = accuracy_score(val_actuals, val_predicted)
            self._logger.info("Validation set accuracy: {}".format(val_score))

            # Snapshot best score
            if best_score is None or val_score > best_score:
                best_results = (val_score, val_actuals, val_predicted)
                self._logger.info(
                    "Snapshotting because the current score {} is greater than {} ".format(val_score, best_score))
                self.snapshot(model_network, model_dir=self.model_dir)
                best_score = val_score
                no_improvement_epochs = 0
            else:
                no_improvement_epochs += 1

            # Checkpoint
            if self.checkpoint_dir and (epoch % self.checkpoint_frequency == 0):
                self.create_checkpoint(model_network, self.checkpoint_dir)

            # Log the epoch summary (time, iterations, progress, losses and scores)
            self._logger.info(val_log_template.format((datetime.datetime.now() - start).seconds,
                                                      epoch, iterations, 1 + idx, len(train_iter),
                                                      100. * (1 + idx) / len(train_iter), train_loss,
                                                      val_loss, train_score,
                                                      val_score))

            print("###score: train_loss### {}".format(train_loss))
            print("###score: val_loss### {}".format(val_loss))
            print("###score: train_score### {}".format(train_score))
            print("###score: val_score### {}".format(val_score))

            if no_improvement_epochs > self.early_stopping_patience:
                self._logger.info("Early stopping.. with no improvement in {}".format(no_improvement_epochs))
                break

        return best_results
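
Below is a minimal usage sketch. The wrapping class name (BertTrain), its constructor arguments, the model and the dummy data loaders are assumptions for illustration only; the real class definition and its validate, snapshot and create_checkpoint helpers live elsewhere in src/bert_train.py.

# Hypothetical usage sketch -- class name, constructor arguments and data
# below are assumptions, not taken from the excerpt above.
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Dummy data: sequences of 128 token ids with binary labels
train_x = torch.randint(0, 30000, (64, 128))
train_y = torch.randint(0, 2, (64,))
val_x = torch.randint(0, 30000, (16, 128))
val_y = torch.randint(0, 2, (16,))

train_iter = DataLoader(TensorDataset(train_x, train_y), batch_size=8, shuffle=True)
validation_iter = DataLoader(TensorDataset(val_x, val_y), batch_size=8)

trainer = BertTrain(epochs=10, early_stopping_patience=3, accumulation_steps=4,
                    model_dir="models", checkpoint_dir=None)
model = some_model()                  # any module whose forward returns (logits, ...)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=2e-5)

# run_train returns the best (val_score, val_actuals, val_predicted) tuple
val_score, val_actuals, val_predicted = trainer.run_train(
    train_iter, validation_iter, model, loss_function, optimizer, pos_label=1)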