in src/bert_train.py [0:0]
def run_train(self, train_iter, validation_iter, model_network, loss_function, optimizer, pos_label):
    """
    Run the full training loop: gradient-accumulated optimisation, per-epoch
    validation, best-model snapshotting, periodic checkpointing and early stopping.

    :param train_iter: Train data — iterable of (batch_x, batch_y) pairs; must support len()
    :param validation_iter: Validation set, consumed by self.validate
    :param model_network: A neural network (wrapped in nn.DataParallel when multi-gpu)
    :param loss_function: Pytorch loss function
    :param optimizer: Optimiser
    :param pos_label: Positive class label (unused here; kept for interface compatibility)
    :return: (best_val_score, val_actuals, val_predicted) from the best epoch,
             or None if self.epochs is 0
    """
    best_results = None
    start = datetime.datetime.now()
    iterations = 0
    # Columns: elapsed-sec, epoch, iterations, batch/total, pct, train_loss, val_loss, train_score, val_score
    val_log_template = ' '.join(
        '{:>6.0f},{:>5.0f},{:>9.0f},{:>5.0f}/{:<5.0f} {:>7.0f}%,{:>8.6f},{:8.6f},{:12.4f},{:12.4f}'.split(','))
    val_log_template = "Run {}".format(val_log_template)
    best_score = None
    no_improvement_epochs = 0

    if self._is_multigpu:
        model_network = nn.DataParallel(model_network, device_ids=self.device, output_device=self._default_device)
        self._logger.info("Using multi gpu with devices {}, default {} ".format(self.device, self._default_device))
    model_network.to(device=self._default_device)

    n_batches = len(train_iter)
    for epoch in range(self.epochs):
        losses_train = []
        actual_train = []
        predicted_train = []
        # Fix: log the current epoch, not the total epoch count
        self._logger.debug("Running epoch {}".format(epoch))
        model_network.zero_grad()
        for idx, batch in enumerate(train_iter):
            self._logger.debug("Running batch {}".format(idx))
            batch_x = batch[0].to(device=self._default_device)
            batch_y = batch[1].to(device=self._default_device)
            self._logger.debug("batch x shape is {}".format(batch_x.shape))
            iterations += 1
            # Step 1. train mode (enables dropout etc.)
            model_network.train()
            # Step 2. Run the forward pass; [0] assumes the model returns a
            # tuple whose first element is the logits (BERT-style) — confirmed
            # by the original indexing
            self._logger.debug("Running forward")
            predicted = model_network(batch_x)[0]
            # Step 3. Compute loss, scaled so accumulated gradients average
            # correctly over accumulation_steps micro-batches
            self._logger.debug("Running loss")
            loss = loss_function(predicted, batch_y) / self.accumulation_steps
            loss.backward()
            losses_train.append(loss.item())
            actual_train.extend(batch_y.cpu().tolist())
            predicted_train.extend(torch.max(predicted, 1)[1].view(-1).cpu().tolist())
            # Step 4. Only update weights after gradients are accumulated for
            # n steps. Fix: also step on the final batch so a trailing partial
            # accumulation is not silently discarded at the next zero_grad()
            if (idx + 1) % self.accumulation_steps == 0 or (idx + 1) == n_batches:
                self._logger.debug("Running optimiser")
                optimizer.step()
                model_network.zero_grad()

        # Print training set results
        self._logger.info("Train set result details:")
        train_loss = sum(losses_train) / len(losses_train)
        train_score = accuracy_score(actual_train, predicted_train)
        self._logger.info("Train set result details: {}".format(train_score))

        # Print validation set results
        self._logger.info("Validation set result details:")
        val_actuals, val_predicted, val_loss = self.validate(loss_function, model_network, validation_iter)
        val_score = accuracy_score(val_actuals, val_predicted)
        self._logger.info("Validation set result details: {} ".format(val_score))

        # Snapshot best score
        if best_score is None or val_score > best_score:
            best_results = (val_score, val_actuals, val_predicted)
            self._logger.info(
                "Snapshotting because the current score {} is greater than {} ".format(val_score, best_score))
            self.snapshot(model_network, model_dir=self.model_dir)
            best_score = val_score
            no_improvement_epochs = 0
        else:
            no_improvement_epochs += 1

        # Checkpoint
        if self.checkpoint_dir and (epoch % self.checkpoint_frequency == 0):
            self.create_checkpoint(model_network, self.checkpoint_dir)

        # evaluate performance on validation set periodically.
        # Fix: report batch progress as (last batch index + 1) / total batches,
        # not the batch *size* (len(batch_x)) as before — at epoch end this is
        # n_batches/n_batches = 100%
        self._logger.info(val_log_template.format((datetime.datetime.now() - start).seconds,
                                                  epoch, iterations, 1 + idx, n_batches,
                                                  100. * (1 + idx) / n_batches, train_loss,
                                                  val_loss, train_score,
                                                  val_score))
        print("###score: train_loss### {}".format(train_loss))
        print("###score: val_loss### {}".format(val_loss))
        print("###score: train_score### {}".format(train_score))
        print("###score: val_score### {}".format(val_score))

        if no_improvement_epochs > self.early_stopping_patience:
            self._logger.info("Early stopping.. with no improvement in {}".format(no_improvement_epochs))
            break
    return best_results