in crop_yield_prediction/models/deep_gaussian_process/base.py [0:0]
def _predict(self, train_images, train_yields, train_locations, train_indices,
             train_years, test_images, test_yields, test_locations, test_indices,
             test_years, batch_size):
    """
    Run the model over the training and test data and collect its outputs.

    Each split is iterated in order (``DataLoader`` without shuffling) in
    batches of ``batch_size``, with the model switched to eval mode and
    gradients disabled. The model is left in eval mode afterwards.

    When ``self.gp`` is not None the model is called with
    ``return_last_dense=True`` and the returned last-dense-layer features are
    collected too (presumably as input for the Gaussian process -- confirm
    against the caller).

    Returns a ``defaultdict`` whose keys are ``'<split>_<field>'`` with
    ``split`` in ``('train', 'test')`` and ``field`` one of:

    * ``pred``    -- model predictions, ``pred.squeeze(1)`` per sample
    * ``real``    -- ground-truth yields, ``yields.squeeze(1)`` per sample
    * ``loc``     -- locations, batches concatenated along axis 0
    * ``indices`` -- indices, batches concatenated along axis 0
    * ``years``   -- years, one entry per sample
    * ``feat``    -- last-dense features, batches concatenated along axis 0
      (only present when ``self.gp`` is not None)

    All values are ``np.ndarray``. A split with no samples contributes no keys.
    """
    return_feat = self.gp is not None
    splits = (
        ('train', TensorDataset(train_images, train_yields, train_locations,
                                train_indices, train_years)),
        ('test', TensorDataset(test_images, test_yields, test_locations,
                               test_indices, test_years)),
    )
    # Fields collected as per-batch arrays (leading batch axis) that must be
    # concatenated at the end; all other fields are flat per-sample lists.
    concat_fields = ('feat', 'loc', 'indices')

    results = defaultdict(list)
    self.model.eval()
    with torch.no_grad():
        for split, dataset in splits:
            dataloader = DataLoader(dataset, batch_size=batch_size)
            for images, yields, locs, idxs, years in dataloader:
                model_output = self.model(images, return_last_dense=return_feat)
                if return_feat:
                    pred, feat = model_output
                    # .cpu() returns the tensor itself when it is already on
                    # the CPU, so no device check is needed before .numpy().
                    results[split + '_feat'].append(feat.cpu().numpy())
                else:
                    pred = model_output
                results[split + '_pred'].extend(pred.squeeze(1).tolist())
                results[split + '_real'].extend(yields.squeeze(1).tolist())
                # .numpy() only works on CPU tensors; move first in case the
                # inputs were placed on an accelerator.
                results[split + '_loc'].append(locs.cpu().numpy())
                results[split + '_indices'].append(idxs.cpu().numpy())
                results[split + '_years'].extend(years.tolist())

    # Iterate over a snapshot of the keys while replacing the values.
    for key in list(results):
        if key.split('_', 1)[1] in concat_fields:
            results[key] = np.concatenate(results[key], axis=0)
        else:
            results[key] = np.array(results[key])
    return results