def train_attention()

in crop_yield_prediction/train_semi_transformer.py [0:0]


def train_attention(model, X_dir, X_train_indices, y_train, X_valid_indices, y_valid, X_test_indices, y_test, n_tsteps,
                    max_index, n_triplets_per_file, tilenet_margin, tilenet_l2, tilenet_ltn, unsup_weight, patience,
                    optimizer, batch_size, test_batch_size, n_epochs, out_dir, year, exp_idx, log_file):
    """Train ``model`` with the semi-supervised objective, checkpointing each epoch and early-stopping on valid RMSE.

    All progress is appended to ``log_file``. After every epoch the model state is saved to
    ``{out_dir}/{exp_idx}_{year}_epoch{epoch}.tar``. Whenever validation RMSE improves on the best
    value seen so far, the test set is evaluated via ``eval_test_best_only`` and the same
    checkpoint is also written to ``{out_dir}/{exp_idx}_{year}_best.tar``.

    Args:
        model: network passed to ``train_epoch`` / ``eval_epoch``; its ``state_dict()`` is checkpointed.
        X_dir: data directory forwarded to ``semi_cropyield_dataloader``.
        X_train_indices, X_valid_indices, X_test_indices: indexable pairs; elements ``[0]`` and ``[1]``
            are forwarded as the two index arguments of ``semi_cropyield_dataloader``.
        y_train, y_valid, y_test: targets; ``shape[0]`` of the train/valid targets is logged as set size.
        n_tsteps, max_index, n_triplets_per_file: forwarded unchanged to every dataloader.
        tilenet_margin, tilenet_l2, tilenet_ltn, unsup_weight: loss hyper-parameters forwarded to
            ``train_epoch`` and ``eval_epoch``.
        patience: number of consecutive non-improving epochs after which training stops;
            ``None`` disables early stopping.
        optimizer: optimizer forwarded to ``train_epoch``.
        batch_size: batch size of the training and validation loaders.
        test_batch_size: batch size of the test loader.
        n_epochs: maximum number of epochs to run.
        out_dir: directory that receives the checkpoint files.
        year, exp_idx: identifiers used in log lines and checkpoint file names.
        log_file: path of the log file, opened in append mode.

    Returns:
        The number of epochs actually run: ``n_epochs``, or fewer if early stopping triggered.
    """
    with open(log_file, 'a') as f:
        print('Predict year {}......'.format(year), file=f, flush=True)
        print('Train size {}, valid size {}'.format(y_train.shape[0], y_valid.shape[0]), file=f, flush=True)
        print('Experiment {}'.format(exp_idx), file=f, flush=True)

        cuda = torch.cuda.is_available()

        train_dataloader = semi_cropyield_dataloader(X_dir, X_train_indices[0], X_train_indices[1], y_train, n_tsteps,
                                                     max_index, n_triplets_per_file, batch_size, shuffle=True,
                                                     num_workers=4)
        validation_dataloader = semi_cropyield_dataloader(X_dir, X_valid_indices[0], X_valid_indices[1], y_valid, n_tsteps,
                                                          max_index, n_triplets_per_file, batch_size, shuffle=False,
                                                          num_workers=4)
        test_dataloader = semi_cropyield_dataloader(X_dir, X_test_indices[0], X_test_indices[1], y_test, n_tsteps,
                                                    max_index, n_triplets_per_file, test_batch_size, shuffle=False,
                                                    num_workers=4)

        valid_rmse_min = np.inf
        # Always initialised so the name is bound on every path; it is only consulted when
        # ``patience`` is not None.
        epochs_without_improvement = 0

        for epoch_i in range(n_epochs):
            print('[ Epoch', epoch_i, ']', file=f, flush=True)

            start = time.time()
            train_loss = train_epoch(model, train_dataloader, tilenet_margin, tilenet_l2, tilenet_ltn, unsup_weight,
                                     optimizer, cuda)
            print('  - {header:12} avg loss: {loss: 8.3f}, supervised loss: {supervised_loss: 8.3f}, '
                  'unsupervised loss: {unsupervised_loss: 8.3f}, elapse: {elapse:3.3f} min'.
                  format(header=f"({'Training'})", loss=train_loss['loss'], supervised_loss=train_loss['loss_supervised'],
                         unsupervised_loss=train_loss['loss_unsupervised'],
                         elapse=(time.time() - start) / 60), file=f, flush=True)

            start = time.time()
            valid_loss, valid_rmse, valid_r2, valid_corr = eval_epoch(model, validation_dataloader,
                                                                      tilenet_margin, tilenet_l2, tilenet_ltn, unsup_weight,
                                                                      cuda)
            print('  - {header:12} loss: {loss: 8.3f}, supervised loss: {supervised_loss: 8.3f}, '
                  'unsupervised loss: {unsupervised_loss: 8.3f}, l_n loss: {l_n: 8.3f}, l_d loss: {l_d: 8.3f}, '
                  'l_nd loss: {l_nd: 8.3f}, sn_loss: {sn_loss: 8.3f}, tn_loss: {tn_loss: 8.3f}, norm_loss: {norm_loss: 8.3f}, '
                  'rmse: {rmse: 8.3f}, r2: {r2: 8.3f}, corr: {corr: 8.3f}, elapse: {elapse:3.3f} min'.
                  format(header=f"({'Validation'})", loss=valid_loss['loss'], supervised_loss=valid_loss['loss_supervised'],
                         unsupervised_loss=valid_loss['loss_unsupervised'], l_n=valid_loss['l_n'], l_d=valid_loss['l_d'],
                         l_nd=valid_loss['l_nd'], sn_loss=valid_loss['sn_loss'], tn_loss=valid_loss['tn_loss'], norm_loss=valid_loss['norm_loss'],
                         rmse=valid_rmse, r2=valid_r2, corr=valid_corr, elapse=(time.time() - start) / 60), file=f, flush=True)

            # Per-epoch checkpoint, written regardless of whether this epoch improved.
            checkpoint = {'epoch': epoch_i, 'model_state_dict': model.state_dict()}
            torch.save(checkpoint, '{}/{}_{}_epoch{}.tar'.format(out_dir, exp_idx, year, epoch_i))

            if valid_rmse < valid_rmse_min:
                # New best on validation: report test metrics for this model and keep it as "best".
                eval_test_best_only(test_dataloader, y_test, test_batch_size, model, epoch_i, f)

                torch.save(checkpoint, '{}/{}_{}_best.tar'.format(out_dir, exp_idx, year))
                print('    - [Info] The checkpoint file has been updated at epoch {}.'.format(epoch_i), file=f, flush=True)
                valid_rmse_min = valid_rmse
                epochs_without_improvement = 0
            elif patience is not None:
                epochs_without_improvement += 1

                # ``>=`` rather than ``==`` so a non-positive patience still stops instead of never firing.
                if epochs_without_improvement >= patience:
                    # Log to the experiment log like every other progress line (was stdout).
                    print('Early stopping!', file=f, flush=True)
                    return epoch_i + 1

        return n_epochs