in crop_yield_prediction/train_semi_transformer.py [0:0]
def eval_test(X_dir, X_test_indices, y_test, n_tsteps, max_index, n_triplets_per_file, batch_size, model_dir, model, epochs, year,
              exp_idx, log_file):
    """Evaluate every saved checkpoint of one experiment on the test set.

    The checkpoints ``{model_dir}/{exp_idx}_{year}_epoch{i}.tar`` for
    ``i in range(epochs)`` are evaluated in order, followed by
    ``{model_dir}/{exp_idx}_{year}_best.tar``. For each one the weights stored
    under ``checkpoint['model_state_dict']`` are loaded into ``model``, the
    test set is run through it without gradients, and one line with
    rmse / r2 / corr is appended to ``log_file``.

    Args:
        X_dir: Forwarded unchanged to ``semi_cropyield_dataloader``.
        X_test_indices: Indexable; elements ``[0]`` and ``[1]`` are forwarded
            to ``semi_cropyield_dataloader``.
        y_test: Test targets; must support ``.shape[0]`` and ``len()``
            (presumably a NumPy array -- confirm against the caller).
        n_tsteps, max_index, n_triplets_per_file: Forwarded unchanged to
            ``semi_cropyield_dataloader``.
        batch_size: Loader batch size; also used to compute where each
            batch's predictions are written in the output buffer.
        model_dir: Directory containing the ``.tar`` checkpoints.
        model: Module whose weights are overwritten by each checkpoint. It is
            called as ``model(batch_X, unsup_weight=0)`` and must return a
            pair whose second element holds the predictions.
        epochs: Number of per-epoch checkpoints to evaluate.
        year: Year used in the checkpoint file names and in the log header.
        exp_idx: Experiment index used in the checkpoint file names and in the
            log header.
        log_file: Path of the text log; opened in append mode.

    Returns:
        Tuple ``(predictions, rmse, r2, corr)`` for the LAST checkpoint
        evaluated, i.e. the ``*_best.tar`` one. ``predictions`` is a NumPy
        array of length ``len(y_test)``.
    """
    with open(log_file, 'a') as f:
        print('Predict year {}'.format(year), file=f, flush=True)
        print('Test size {}'.format(y_test.shape[0]), file=f, flush=True)
        print('Experiment {}'.format(exp_idx), file=f, flush=True)

        cuda = torch.cuda.is_available()

        best_model = '{}/{}_{}_best.tar'.format(model_dir, exp_idx, year)
        models = ['{}/{}_{}_epoch{}.tar'.format(model_dir, exp_idx, year, epoch_i)
                  for epoch_i in range(epochs)]
        models.append(best_model)

        # Wrap and place the model exactly ONCE. Doing this inside the
        # checkpoint loop nests DataParallel deeper on every iteration, which
        # changes the state_dict key prefix ('module.', 'module.module.', ...)
        # and makes load_state_dict fail from the second checkpoint onwards.
        if cuda:
            if torch.cuda.device_count() > 1 and not isinstance(model, torch.nn.DataParallel):
                print("Let's use", torch.cuda.device_count(), "GPUs!")
                model = torch.nn.DataParallel(model)
            model.to(torch.device("cuda:0"))

        n_samples = len(y_test)

        for model_file in models:
            # Without a GPU, remap tensors that were saved from CUDA onto CPU.
            checkpoint = torch.load(model_file, map_location=None if cuda else 'cpu')
            # load_state_dict copies into the existing parameters, so the
            # model stays on the device chosen above.
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()

            test_dataloader = semi_cropyield_dataloader(X_dir, X_test_indices[0], X_test_indices[1], y_test, n_tsteps,
                                                        max_index, n_triplets_per_file, batch_size, shuffle=False,
                                                        num_workers=4)
            n_batches = len(test_dataloader)

            predictions = torch.zeros(n_samples)

            with torch.no_grad():
                for i, (batch_X, batch_y) in enumerate(test_dataloader):
                    batch_X, _ = prep_data(batch_X, batch_y, cuda)

                    # forward
                    _, pred = model(batch_X, unsup_weight=0)

                    # The loader is not shuffled, so batch i fills a
                    # contiguous slice; the final batch may be short.
                    start = i * batch_size
                    end = start + batch_size if i != n_batches - 1 else n_samples
                    predictions[start:end] = pred

            predictions = predictions.detach().cpu().numpy()

            rmse, r2, corr = cal_performance(predictions, y_test)

            # Compare against the known best-checkpoint path rather than
            # searching for the substring 'epoch', which would also match a
            # model_dir or exp_idx that happens to contain that word.
            if model_file != best_model:
                print('  - {header:12} epoch: {epoch: 5}, rmse: {rmse: 8.3f}, r2: {r2: 8.3f}, corr: {corr: 8.3f}'.
                      format(header='(Test)', epoch=checkpoint['epoch'], rmse=rmse, r2=r2, corr=corr),
                      file=f, flush=True)
            else:
                print('  - {header:12} best selected based on validation set, '
                      'rmse: {rmse: 8.3f}, r2: {r2: 8.3f}, corr: {corr: 8.3f}'.
                      format(header='(Test)', rmse=rmse, r2=r2, corr=corr), file=f, flush=True)

    return predictions, rmse, r2, corr