in crop_yield_prediction/train_cross_location.py [0:0]
def eval_epoch(model, validation_dataloader, tilenet_margin, tilenet_l2, tilenet_ltn, unsup_weight, cuda):
    ''' Epoch operation in evaluation phase.

    Runs ``model`` over every batch of ``validation_dataloader`` without
    gradient tracking, accumulates per-batch losses, and computes regression
    metrics over all collected predictions.

    Args:
        model: module called as ``model(batch_X, unsup_weight)``; must return
            a pair ``(emb_triplets, pred)``.
        validation_dataloader: iterable of ``(batch_X, batch_y)`` pairs that
            also supports ``len()`` and exposes a ``.dataset`` with ``len()``.
        tilenet_margin, tilenet_l2, tilenet_ltn: forwarded to ``triplet_loss``.
            When ``tilenet_l2`` is 0, ``norm_loss`` is not accumulated.
        unsup_weight: weight of the unsupervised (triplet) loss. The total
            loss is ``(1 - unsup_weight) * supervised + unsup_weight *
            unsupervised``. When 0, ``triplet_loss`` is not called and only
            the supervised MSE loss is used.
        cuda: if true, moves the model to the GPU and is passed to
            ``prep_data``.

    Returns:
        ``(avg_loss_dic, rmse, r2, corr)``: ``avg_loss_dic`` maps each loss
        name to its per-batch sum divided by the number of batches; the
        metrics come from ``cal_performance(predictions, y)``.

    Raises:
        ValueError: if ``validation_dataloader`` yields no batches.
    '''
    model.eval()
    if cuda:
        model.cuda()
    n_batches = len(validation_dataloader)
    if n_batches == 0:
        raise ValueError('validation_dataloader yields no batches')
    n_samples = len(validation_dataloader.dataset)
    # Buffers are allocated on the CPU; slice assignment copies device
    # outputs into them.
    predictions = torch.zeros(n_samples)
    # Collect y alongside predictions because the loader may shuffle, so the
    # dataset's order does not match iteration order.
    y = torch.zeros(n_samples)
    sum_loss_dic = dict.fromkeys(
        ['loss', 'loss_supervised', 'loss_unsupervised',
         'l_n', 'l_d', 'l_nd', 'sn_loss', 'tn_loss', 'norm_loss'], 0)
    loss_func = torch.nn.MSELoss()
    # Write position advanced by each batch's actual size. This stays correct
    # when the loader has no fixed batch_size (custom batch_sampler, where
    # DataLoader.batch_size is None) or drops the last incomplete batch.
    offset = 0
    with torch.no_grad():
        for batch_X, batch_y in validation_dataloader:
            batch_X, batch_y = prep_data(batch_X, batch_y, cuda)
            # forward
            emb_triplets, pred = model(batch_X, unsup_weight)
            loss_supervised = loss_func(pred, batch_y)
            if unsup_weight != 0:
                loss_unsupervised, l_n, l_d, l_nd, sn_loss, tn_loss, norm_loss = triplet_loss(
                    emb_triplets, tilenet_margin, tilenet_l2, tilenet_ltn)
                loss = (1 - unsup_weight) * loss_supervised + unsup_weight * loss_unsupervised
            else:
                loss = loss_supervised
            end = offset + len(batch_y)
            predictions[offset:end] = pred
            y[offset:end] = batch_y
            offset = end
            sum_loss_dic['loss'] += loss.item()
            sum_loss_dic['loss_supervised'] += loss_supervised.item()
            if unsup_weight != 0:
                sum_loss_dic['loss_unsupervised'] += loss_unsupervised.item()
                sum_loss_dic['l_n'] += l_n.item()
                sum_loss_dic['l_d'] += l_d.item()
                sum_loss_dic['l_nd'] += l_nd.item()
                sum_loss_dic['sn_loss'] += sn_loss.item()
                sum_loss_dic['tn_loss'] += tn_loss.item()
                if tilenet_l2 != 0:
                    sum_loss_dic['norm_loss'] += norm_loss.item()
    # Keep only the filled prefix so samples skipped by the loader (e.g.
    # drop_last=True) are not scored as zero-valued predictions/targets.
    predictions, y = predictions[:offset].numpy(), y[:offset].numpy()
    rmse, r2, corr = cal_performance(predictions, y)
    avg_loss_dic = {loss_type: total / n_batches for loss_type, total in sum_loss_dic.items()}
    return avg_loss_dic, rmse, r2, corr