in crop_yield_prediction/train_cross_location.py [0:0]
def train_attention(model, X_dir, X_train_indices_dic, y_train, X_valid_indices_dic, y_valid, X_test_indices_dic, y_test, n_tsteps,
                    max_index, n_triplets_per_file, tilenet_margin, tilenet_l2, tilenet_ltn, unsup_weight, patience,
                    optimizer, batch_size, test_batch_size, n_epochs, out_dir, year, exp_idx, log_file, num_workers=4):
    """Train ``model`` for one prediction year, checkpointing every epoch and tracking the best one.

    Each epoch runs ``train_epoch`` on the training split and ``eval_epoch`` on the validation split. It appends
    the losses/metrics to ``log_file`` and saves ``{out_dir}/{exp_idx}_{year}_epoch{epoch}.tar``. Whenever the
    validation RMSE improves on the best seen so far, the test split is evaluated with
    ``eval_test_best_only``. The same checkpoint is then also saved as ``{out_dir}/{exp_idx}_{year}_best.tar``.

    Args:
        model: network trained/evaluated by ``train_epoch``/``eval_epoch``; its ``state_dict()`` is checkpointed.
        X_dir: data location forwarded to ``cross_location_dataloader``.
        X_train_indices_dic, X_valid_indices_dic, X_test_indices_dic: per-split index dictionaries forwarded
            to ``cross_location_dataloader``.
        y_train, y_valid, y_test: per-split targets (array-like with ``.shape``; presumably numpy arrays —
            confirm against caller). ``y_test`` is also passed to ``eval_test_best_only``.
        n_tsteps, max_index, n_triplets_per_file: forwarded to ``cross_location_dataloader``.
        tilenet_margin, tilenet_l2, tilenet_ltn, unsup_weight: loss hyper-parameters forwarded to
            ``train_epoch`` and ``eval_epoch``.
        patience: number of consecutive epochs without a validation-RMSE improvement after which training
            stops early; ``None`` disables early stopping.
        optimizer: optimizer forwarded to ``train_epoch``.
        batch_size: batch size of the training and validation loaders.
        test_batch_size: batch size of the test loader (also passed to ``eval_test_best_only``).
        n_epochs: maximum number of epochs to run.
        out_dir: directory checkpoints are written to.
        year, exp_idx: identifiers used in log lines and checkpoint file names.
        log_file: path of the text log; opened in append mode.
        num_workers: worker count for every dataloader (defaults to the previously hard-coded 4).

    Returns:
        int: the number of epochs actually run — ``n_epochs`` unless early stopping triggered.
    """
    with open(log_file, 'a') as f:
        print('Predict year {}......'.format(year), file=f, flush=True)
        print('Train size {}, valid size {}'.format(y_train.shape[0], y_valid.shape[0]), file=f, flush=True)
        print('Experiment {}'.format(exp_idx), file=f, flush=True)
        cuda = torch.cuda.is_available()

        def make_loader(indices_dic, targets, loader_batch_size, shuffle):
            # All three splits share the same data source and sampling settings; only these vary.
            return cross_location_dataloader(X_dir, indices_dic, targets, n_tsteps,
                                             max_index, n_triplets_per_file, loader_batch_size, shuffle=shuffle,
                                             num_workers=num_workers)

        train_dataloader = make_loader(X_train_indices_dic, y_train, batch_size, True)
        validation_dataloader = make_loader(X_valid_indices_dic, y_valid, batch_size, False)
        test_dataloader = make_loader(X_test_indices_dic, y_test, test_batch_size, False)

        valid_rmse_min = np.inf
        # Only consulted when patience is not None; initialised unconditionally for simplicity.
        epochs_without_improvement = 0
        for epoch_i in range(n_epochs):
            print('[ Epoch', epoch_i, ']', file=f, flush=True)

            start = time.time()
            train_loss = train_epoch(model, train_dataloader, tilenet_margin, tilenet_l2, tilenet_ltn, unsup_weight,
                                     optimizer, cuda)
            print(' - {header:12} avg loss: {loss: 8.3f}, supervised loss: {supervised_loss: 8.3f}, '
                  'unsupervised loss: {unsupervised_loss: 8.3f}, elapse: {elapse:3.3f} min'.
                  format(header='(Training)', loss=train_loss['loss'], supervised_loss=train_loss['loss_supervised'],
                         unsupervised_loss=train_loss['loss_unsupervised'],
                         elapse=(time.time() - start) / 60), file=f, flush=True)

            start = time.time()
            valid_loss, valid_rmse, valid_r2, valid_corr = eval_epoch(model, validation_dataloader,
                                                                      tilenet_margin, tilenet_l2, tilenet_ltn,
                                                                      unsup_weight, cuda)
            print(' - {header:12} loss: {loss: 8.3f}, supervised loss: {supervised_loss: 8.3f}, '
                  'unsupervised loss: {unsupervised_loss: 8.3f}, l_n loss: {l_n: 8.3f}, l_d loss: {l_d: 8.3f}, '
                  'l_nd loss: {l_nd: 8.3f}, sn_loss: {sn_loss: 8.3f}, tn_loss: {tn_loss: 8.3f}, norm_loss: {norm_loss: 8.3f}, '
                  'rmse: {rmse: 8.3f}, r2: {r2: 8.3f}, corr: {corr: 8.3f}, elapse: {elapse:3.3f} min'.
                  format(header='(Validation)', loss=valid_loss['loss'], supervised_loss=valid_loss['loss_supervised'],
                         unsupervised_loss=valid_loss['loss_unsupervised'], l_n=valid_loss['l_n'], l_d=valid_loss['l_d'],
                         l_nd=valid_loss['l_nd'], sn_loss=valid_loss['sn_loss'], tn_loss=valid_loss['tn_loss'],
                         norm_loss=valid_loss['norm_loss'],
                         rmse=valid_rmse, r2=valid_r2, corr=valid_corr, elapse=(time.time() - start) / 60),
                  file=f, flush=True)

            # Every epoch is checkpointed, so any epoch can be reloaded later, not just the best one.
            checkpoint = {'epoch': epoch_i, 'model_state_dict': model.state_dict()}
            torch.save(checkpoint, '{}/{}_{}_epoch{}.tar'.format(out_dir, exp_idx, year, epoch_i))

            if valid_rmse < valid_rmse_min:
                # Test metrics are only computed for new best-on-validation models.
                eval_test_best_only(test_dataloader, y_test, test_batch_size, model, epoch_i, f)
                torch.save(checkpoint, '{}/{}_{}_best.tar'.format(out_dir, exp_idx, year))
                print(' - [Info] The checkpoint file has been updated at epoch {}.'.format(epoch_i), file=f, flush=True)
                valid_rmse_min = valid_rmse
                epochs_without_improvement = 0
            elif patience is not None:
                epochs_without_improvement += 1
                # '>=' (not '==') so a non-positive patience still terminates instead of never matching.
                if epochs_without_improvement >= patience:
                    # Record the stop in the log file like every other progress line; keep the stdout notice too.
                    print('Early stopping!', file=f, flush=True)
                    print('Early stopping!')
                    return epoch_i + 1
    return n_epochs