in crop_yield_prediction/models/deep_gaussian_process/base.py [0:0]
def _train(self, train_images, train_yields, val_images, val_yields, train_steps,
           batch_size, starter_learning_rate, weight_decay, l1_weight, patience):
    """Run the training loop for ``self.model`` and restore its best weights.

    The model is trained with Adam for a fixed number of epochs. After every
    epoch the mean validation loss is computed; the weights that achieved the
    lowest value are snapshotted and loaded back into the model before
    returning.

    Args:
        train_images: Training inputs, wrapped together with ``train_yields``
            in a ``TensorDataset``.
        train_yields: Training targets.
        val_images: Validation inputs, wrapped together with ``val_yields``
            in a ``TensorDataset``.
        val_yields: Validation targets.
        train_steps: Accepted for interface compatibility; currently unused
            (the number of epochs is fixed at 50).
        batch_size: Batch size for both the training and validation loaders.
        starter_learning_rate: Initial Adam learning rate. It is divided by
            10 once the optimizer has taken 4000 steps and again at 20000.
        weight_decay: Weight decay passed to Adam.
        l1_weight: Forwarded to ``l1_l2_loss``.
        patience: Number of consecutive epochs without a new best validation
            loss after which training stops early. ``None`` disables early
            stopping.

    Returns:
        A ``(train_scores, val_scores)`` tuple of ``defaultdict(list)``
        objects whose ``'loss'`` entry holds the per-batch loss values.
    """
    # Local import so the block does not depend on a file-level import.
    import copy

    train_dataloader = DataLoader(TensorDataset(train_images, train_yields),
                                  batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(TensorDataset(val_images, val_yields),
                                batch_size=batch_size)

    optimizer = torch.optim.Adam(self.model.parameters(),
                                 lr=starter_learning_rate,
                                 weight_decay=weight_decay)

    num_epochs = 50
    print(f'Training for {num_epochs} epochs')

    # Optimizer step counts at which the learning rate is divided by 10.
    lr_decay_steps = (4000, 20000)

    train_scores = defaultdict(list)
    val_scores = defaultdict(list)

    step_number = 0
    min_loss = np.inf
    # state_dict() returns references to the live parameter tensors, which
    # the optimizer updates in place. A deep copy is required for the
    # snapshot to actually preserve the best weights.
    best_state = copy.deepcopy(self.model.state_dict())
    epochs_without_improvement = 0

    for _ in range(num_epochs):
        self.model.train()

        # running train and val scores are only for printing out
        # information
        running_train_scores = defaultdict(list)

        for train_x, train_y in train_dataloader:
            optimizer.zero_grad()
            pred_y = self.model(train_x)

            loss, running_train_scores = l1_l2_loss(pred_y, train_y, l1_weight,
                                                    running_train_scores)
            loss.backward()
            optimizer.step()

            train_scores['loss'].append(loss.item())

            step_number += 1
            if step_number in lr_decay_steps:
                for param_group in optimizer.param_groups:
                    param_group['lr'] /= 10

        train_output_strings = [
            '{}: {}'.format(key, round(np.array(val).mean(), 5))
            for key, val in running_train_scores.items()
        ]

        running_val_scores = defaultdict(list)
        self.model.eval()
        with torch.no_grad():
            for val_x, val_y in val_dataloader:
                val_pred_y = self.model(val_x)

                val_loss, running_val_scores = l1_l2_loss(val_pred_y, val_y, l1_weight,
                                                          running_val_scores)

                val_scores['loss'].append(val_loss.item())

        val_output_strings = [
            '{}: {}'.format(key, round(np.array(val).mean(), 5))
            for key, val in running_val_scores.items()
        ]

        print('TRAINING: {}'.format(', '.join(train_output_strings)))
        print('VALIDATION: {}'.format(', '.join(val_output_strings)))

        epoch_val_loss = np.array(running_val_scores['loss']).mean()

        if epoch_val_loss < min_loss:
            # New best epoch: snapshot the weights (by value, not reference).
            best_state = copy.deepcopy(self.model.state_dict())
            min_loss = epoch_val_loss
            epochs_without_improvement = 0
        elif patience is not None:
            epochs_without_improvement += 1

            if epochs_without_improvement == patience:
                # The best weights are restored right after the loop.
                print('Early stopping!')
                break

    # revert to the best state dict
    self.model.load_state_dict(best_state)
    return train_scores, val_scores