crop_yield_prediction/train_c3d.py [214:260]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                                         max_index, n_triplets_per_file, test_batch_size, shuffle=False, num_workers=4)

        valid_rmse_min = np.inf

        if patience is not None:
            epochs_without_improvement = 0
        for epoch_i in range(n_epochs):
            print('[ Epoch', epoch_i, ']', file=f, flush=True)

            start = time.time()
            train_loss = train_epoch(model, train_dataloader, optimizer, cuda)
            print('  - {header:12} avg loss: {loss: 8.3f}, elapse: {elapse:3.3f} min'.
                  format(header='(Training)', loss=train_loss,
                         elapse=(time.time() - start) / 60), file=f, flush=True)

            # if epoch_i in [20, 40]:
            #     for param_group in optimizer.param_groups:
            #         param_group['lr'] /= 10

            start = time.time()
            valid_loss, valid_rmse, valid_r2, valid_corr = eval_epoch(model, validation_dataloader, cuda)
            print('  - {header:12} loss: {loss: 8.3f}, rmse: {rmse: 8.3f}, r2: {r2: 8.3f}, corr: {corr: 8.3f}, '
                  'elapse: {elapse:3.3f} min'.
                  format(header='(Validation)', loss=valid_loss,
                         rmse=valid_rmse, r2=valid_r2, corr=valid_corr, elapse=(time.time() - start) / 60),
                  file=f, flush=True)

            # save a per-epoch checkpoint regardless of validation performance
            checkpoint = {'epoch': epoch_i, 'model_state_dict': model.state_dict()}
            torch.save(checkpoint, '{}/{}_{}_epoch{}.tar'.format(out_dir, exp_idx, year, epoch_i))

            # new best validation RMSE: evaluate on the test set and refresh the 'best' checkpoint
            if valid_rmse < valid_rmse_min:
                eval_test_best_only(test_dataloader, y_test, test_batch_size, model, epoch_i, f)

                torch.save(checkpoint, '{}/{}_{}_best.tar'.format(out_dir, exp_idx, year))
                print('    - [Info] The checkpoint file has been updated at epoch {}.'.format(epoch_i), file=f, flush=True)
                valid_rmse_min = valid_rmse

                if patience is not None:
                    epochs_without_improvement = 0
            elif patience is not None:
                epochs_without_improvement += 1

                # stop once validation RMSE has not improved for `patience` consecutive epochs
                if epochs_without_improvement == patience:
                    print('Early stopping!', file=f, flush=True)
                    return epoch_i + 1

        return n_epochs
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
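
Both scripts carry the same best-checkpoint tracking and patience-based early stopping inline in the epoch loop. For reference, here is a minimal sketch of that bookkeeping pulled into a small helper; the class name EarlyStopping and its step() method are illustrative and not part of the repository.

import numpy as np


class EarlyStopping:
    """Track the lowest validation RMSE seen so far and count epochs without improvement."""

    def __init__(self, patience=None):
        self.patience = patience                  # None disables early stopping
        self.best_rmse = np.inf
        self.epochs_without_improvement = 0

    def step(self, valid_rmse):
        """Return (improved, should_stop) for the current epoch's validation RMSE."""
        if valid_rmse < self.best_rmse:
            self.best_rmse = valid_rmse
            self.epochs_without_improvement = 0
            return True, False
        if self.patience is None:
            return False, False
        self.epochs_without_improvement += 1
        return False, self.epochs_without_improvement >= self.patience

Inside the loop, `improved, should_stop = stopper.step(valid_rmse)` would replace the valid_rmse_min / epochs_without_improvement bookkeeping shown above, leaving the checkpointing and logging unchanged.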



crop_yield_prediction/train_cnn_lstm.py [214:260]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
                                         max_index, n_triplets_per_file, test_batch_size, shuffle=False, num_workers=4)

        valid_rmse_min = np.inf

        if patience is not None:
            epochs_without_improvement = 0
        for epoch_i in range(n_epochs):
            print('[ Epoch', epoch_i, ']', file=f, flush=True)

            start = time.time()
            train_loss = train_epoch(model, train_dataloader, optimizer, cuda)
            print('  - {header:12} avg loss: {loss: 8.3f}, elapse: {elapse:3.3f} min'.
                  format(header='(Training)', loss=train_loss,
                         elapse=(time.time() - start) / 60), file=f, flush=True)

            # if epoch_i in [20, 40]:
            #     for param_group in optimizer.param_groups:
            #         param_group['lr'] /= 10

            start = time.time()
            valid_loss, valid_rmse, valid_r2, valid_corr = eval_epoch(model, validation_dataloader, cuda)
            print('  - {header:12} loss: {loss: 8.3f}, rmse: {rmse: 8.3f}, r2: {r2: 8.3f}, corr: {corr: 8.3f}, '
                  'elapse: {elapse:3.3f} min'.
                  format(header='(Validation)', loss=valid_loss,
                         rmse=valid_rmse, r2=valid_r2, corr=valid_corr, elapse=(time.time() - start) / 60),
                  file=f, flush=True)

            # save a per-epoch checkpoint regardless of validation performance
            checkpoint = {'epoch': epoch_i, 'model_state_dict': model.state_dict()}
            torch.save(checkpoint, '{}/{}_{}_epoch{}.tar'.format(out_dir, exp_idx, year, epoch_i))

            # new best validation RMSE: evaluate on the test set and refresh the 'best' checkpoint
            if valid_rmse < valid_rmse_min:
                eval_test_best_only(test_dataloader, y_test, test_batch_size, model, epoch_i, f)

                torch.save(checkpoint, '{}/{}_{}_best.tar'.format(out_dir, exp_idx, year))
                print('    - [Info] The checkpoint file has been updated at epoch {}.'.format(epoch_i), file=f, flush=True)
                valid_rmse_min = valid_rmse

                if patience is not None:
                    epochs_without_improvement = 0
            elif patience is not None:
                epochs_without_improvement += 1

                # stop once validation RMSE has not improved for `patience` consecutive epochs
                if epochs_without_improvement == patience:
                    print('Early stopping!', file=f, flush=True)
                    return epoch_i + 1

        return n_epochs
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
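
The two fragments above are line-for-line identical, so the epoch loop is a natural candidate for a shared utility that both train_c3d.py and train_cnn_lstm.py import. The sketch below assumes the existing train_epoch, eval_epoch, and eval_test_best_only helpers keep their current signatures and are passed in as callables; the module and function names (training_loop.fit) are hypothetical.

# training_loop.py -- hypothetical shared module, not present in the repository
import time

import numpy as np
import torch


def fit(model, optimizer, train_dataloader, validation_dataloader, test_dataloader,
        y_test, test_batch_size, cuda, n_epochs, patience, out_dir, exp_idx, year, f,
        train_epoch, eval_epoch, eval_test_best_only):
    """Shared epoch loop: per-epoch checkpoints, best-on-validation-RMSE checkpoint, early stopping."""
    valid_rmse_min = np.inf
    epochs_without_improvement = 0

    for epoch_i in range(n_epochs):
        start = time.time()
        train_loss = train_epoch(model, train_dataloader, optimizer, cuda)
        valid_loss, valid_rmse, valid_r2, valid_corr = eval_epoch(model, validation_dataloader, cuda)
        print('[ Epoch {} ] train loss {:8.3f}, valid loss {:8.3f}, rmse {:8.3f}, r2 {:8.3f}, '
              'corr {:8.3f}, elapse {:3.3f} min'.format(epoch_i, train_loss, valid_loss, valid_rmse,
                                                        valid_r2, valid_corr, (time.time() - start) / 60),
              file=f, flush=True)

        # per-epoch checkpoint
        checkpoint = {'epoch': epoch_i, 'model_state_dict': model.state_dict()}
        torch.save(checkpoint, '{}/{}_{}_epoch{}.tar'.format(out_dir, exp_idx, year, epoch_i))

        if valid_rmse < valid_rmse_min:
            # best validation RMSE so far: evaluate on the test set and refresh the 'best' checkpoint
            eval_test_best_only(test_dataloader, y_test, test_batch_size, model, epoch_i, f)
            torch.save(checkpoint, '{}/{}_{}_best.tar'.format(out_dir, exp_idx, year))
            valid_rmse_min = valid_rmse
            epochs_without_improvement = 0
        elif patience is not None:
            epochs_without_improvement += 1
            if epochs_without_improvement == patience:
                print('Early stopping!', file=f, flush=True)
                return epoch_i + 1

    return n_epochs

Passing the three helpers in as arguments keeps the sketch independent of which model-specific module defines them; each script would then only construct its own model, optimizer, and dataloaders before calling fit().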



