def crop_yield_train_semi_transformer()

in crop_yield_train_semi_transformer.py [0:0]

import os
from pathlib import Path

import numpy as np
import pandas as pd
import torch.optim as optim

# SemiTransformer, train_attention, eval_test, plot_predict,
# plot_predict_error, and output_to_csv_simple are project-local helpers
# imported from elsewhere in this repository.


def crop_yield_train_semi_transformer(args, data_dir, model_out_dir, result_out_dir, log_out_dir, start_year, end_year,
                                      n_tsteps, train_years=None):
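    # Fixed training hyperparameters; the settings that follow are read from
    # the caller-supplied args namespace.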
    batch_size = 64
    test_batch_size = 128
    n_triplets_per_file = 1
    epochs = 50
    attention_nhead = 8
    adam_lr = 0.001
    adam_betas = (0.9, 0.999)
    n_experiment = 2

    neighborhood_radius = args.neighborhood_radius
    distant_radius = args.distant_radius
    weight_decay = args.weight_decay
    tilenet_margin = args.tilenet_margin
    tilenet_l2 = args.tilenet_l2
    tilenet_ltn = args.tilenet_ltn
    tilenet_zdim = args.tilenet_zdim
    attention_layer = args.attention_layer
    attention_dff = args.attention_dff
    sentence_embedding = args.sentence_embedding
    dropout = args.dropout
    unsup_weight = args.unsup_weight
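    # A patience of 9999 is the sentinel value that disables early stopping.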
    patience = args.patience if args.patience != 9999 else None
    feature = args.feature
    feature_len = args.feature_len
    query_type = args.query_type

    # The embedding size must split evenly across attention heads, since d_k
    # and d_v below are computed as tilenet_zdim // attention_nhead.
    assert tilenet_zdim % attention_nhead == 0

    # Encode the full hyperparameter configuration into a tag used to name the
    # model, result, and log outputs ('dr' appears twice: distant_radius and
    # dropout).
    params = ('{}_nt{}_nr{}_dr{}_wd{}_mar{}_l2{}_ltn{}_zd{}_al{}_adff{}_se{}'
              '_dr{}_uw{}_es{}_{}_tyear{}_qt{}').format(
        start_year, n_tsteps, neighborhood_radius, distant_radius,
        weight_decay, tilenet_margin, tilenet_l2, tilenet_ltn, tilenet_zdim,
        attention_layer, attention_dff, sentence_embedding, dropout,
        unsup_weight, patience, feature, train_years, query_type)

    os.makedirs(log_out_dir, exist_ok=True)
    param_model_out_dir = '{}/{}'.format(model_out_dir, params)
    os.makedirs(param_model_out_dir, exist_ok=True)
    param_result_out_dir = '{}/{}'.format(result_out_dir, params)
    os.makedirs(param_result_out_dir, exist_ok=True)

    # Select the input directory matching the sampling configuration: the
    # neighborhood radius, the optional distant radius, and (unless all
    # features are used) the requested feature subset.
    if feature == 'all':
        X_dir = '{}/nr_{}'.format(data_dir, neighborhood_radius) if distant_radius is None else \
            '{}/nr_{}_dr{}'.format(data_dir, neighborhood_radius, distant_radius)
    else:
        X_dir = '{}/nr_{}_{}'.format(data_dir, neighborhood_radius, feature) if distant_radius is None else \
            '{}/nr_{}_dr{}_{}'.format(data_dir, neighborhood_radius, distant_radius, feature)

    # dim_y holds one labelled sample per row: state/county/year keys, the
    # yield value, and the sample's location (lat, lon).
    dim_y = pd.read_csv('{}/dim_y.csv'.format(data_dir))
    dim_y = dim_y.astype({'state': int, 'county': int, 'year': int, 'value': float, 'lat': float, 'lon': float})
    max_index = len(dim_y) - 1

    results = dict()
    for year in range(start_year, end_year + 1):
        print('Predict year {}......'.format(year))

        # Walk-forward split: test on `year`, validate on the year before it,
        # and train on all earlier years (optionally capped at the most recent
        # `train_years` of them).
        test_idx = (dim_y['year'] == year)
        valid_idx = (dim_y['year'] == (year - 1))
        if train_years is None:
            train_idx = (dim_y['year'] < (year - 1))
        else:
            train_idx = (dim_y['year'] < (year - 1)) & (dim_y['year'] >= (year - 1 - train_years))

        y_valid, y_train = np.array(dim_y.loc[valid_idx]['value']), np.array(dim_y.loc[train_idx]['value'])
        y_test, dim_test = np.array(dim_y.loc[test_idx]['value']), np.array(dim_y.loc[test_idx][['state', 'county']])

        test_indices = [i for i, x in enumerate(test_idx) if x]
        valid_indices = [i for i, x in enumerate(valid_idx) if x]
        train_indices = [i for i, x in enumerate(train_idx) if x]

        # The loaders below receive (first, last) index ranges, so each split
        # must occupy a contiguous block of rows in dim_y.
        assert all(b - a == 1 for a, b in zip(test_indices[:-1], test_indices[1:]))
        assert all(b - a == 1 for a, b in zip(valid_indices[:-1], valid_indices[1:]))
        assert all(b - a == 1 for a, b in zip(train_indices[:-1], train_indices[1:]))
        print('Train size {}, valid size {}, test size {}'.format(y_train.shape[0], y_valid.shape[0], y_test.shape[0]))

        test_corr_lis, test_r2_lis, test_rmse_lis = [], [], []
        test_prediction_lis = []
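        # Repeat training n_experiment times with fresh random initialization,
        # collecting per-run test metrics and predictions for averaging below.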
        for i in range(n_experiment):
            print('Experiment {}'.format(i))

            semi_transformer = SemiTransformer(
                tn_in_channels=feature_len,
                tn_z_dim=tilenet_zdim,
                tn_warm_start_model=None,
                sentence_embedding=sentence_embedding,
                output_pred=True,
                query_type=query_type,
                attn_n_tsteps=n_tsteps,
                d_word_vec=tilenet_zdim,
                d_model=tilenet_zdim,
                d_inner=attention_dff,
                n_layers=attention_layer,
                n_head=attention_nhead,
                d_k=tilenet_zdim//attention_nhead,
                d_v=tilenet_zdim//attention_nhead,
                dropout=dropout,
                apply_position_enc=True)

            optimizer = optim.Adam(semi_transformer.parameters(), lr=adam_lr,
                                   betas=adam_betas, weight_decay=weight_decay)

            trained_epochs = train_attention(model=semi_transformer,
                                             X_dir=X_dir,
                                             X_train_indices=(train_indices[0], train_indices[-1]),
                                             y_train=y_train,
                                             X_valid_indices=(valid_indices[0], valid_indices[-1]),
                                             y_valid=y_valid,
                                             X_test_indices=(test_indices[0], test_indices[-1]),
                                             y_test=y_test,
                                             n_tsteps=n_tsteps,
                                             max_index=max_index,
                                             n_triplets_per_file=n_triplets_per_file,
                                             tilenet_margin=tilenet_margin,
                                             tilenet_l2=tilenet_l2,
                                             tilenet_ltn=tilenet_ltn,
                                             unsup_weight=unsup_weight,
                                             patience=patience,
                                             optimizer=optimizer,
                                             batch_size=batch_size,
                                             test_batch_size=test_batch_size,
                                             n_epochs=epochs,
                                             out_dir=param_model_out_dir,
                                             year=year,
                                             exp_idx=i,
                                             log_file='{}/{}.txt'.format(log_out_dir, params))

            test_prediction, rmse, r2, corr = eval_test(X_dir,
                                                        X_test_indices=(test_indices[0], test_indices[-1]),
                                                        y_test=y_test,
                                                        n_tsteps=n_tsteps,
                                                        max_index=max_index,
                                                        n_triplets_per_file=n_triplets_per_file,
                                                        batch_size=test_batch_size,
                                                        model_dir=param_model_out_dir,
                                                        model=semi_transformer,
                                                        epochs=trained_epochs,
                                                        year=year,
                                                        exp_idx=i,
                                                        log_file='{}/{}.txt'.format(log_out_dir, params))
            test_corr_lis.append(corr)
            test_r2_lis.append(r2)
            test_rmse_lis.append(rmse)

            test_prediction_lis.append(test_prediction)

        # Ensemble: average the per-run predictions into a single estimate.
        test_prediction = np.mean(np.asarray(test_prediction_lis), axis=0)
        np.save('{}/{}.npy'.format(param_result_out_dir, year), test_prediction)
        plot_predict(test_prediction, dim_test, Path('{}/pred_{}.html'.format(param_result_out_dir, year)))
        plot_predict_error(test_prediction, y_test, dim_test, Path('{}/err_{}.html'.format(param_result_out_dir, year)))

        results[year] = {'test_rmse': np.around(np.mean(test_rmse_lis), 3),
                         'test_r2': np.around(np.mean(test_r2_lis), 3),
                         'test_corr': np.around(np.mean(test_corr_lis), 3)}

        # Rewrite the aggregated CSV after every year so partial results are
        # available before the remaining years finish.
        output_to_csv_simple(results, param_result_out_dir)
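
# A minimal invocation sketch (all paths and argument values below are
# hypothetical; in the repository this function is presumably driven by an
# argparse CLI). Note that tilenet_zdim must stay divisible by the 8
# attention heads.
from types import SimpleNamespace

args = SimpleNamespace(
    neighborhood_radius=25, distant_radius=100, weight_decay=0.0,
    tilenet_margin=50.0, tilenet_l2=0.01, tilenet_ltn=0.1, tilenet_zdim=256,
    attention_layer=2, attention_dff=512, sentence_embedding='simple_average',
    dropout=0.1, unsup_weight=0.2, patience=9999, feature='all',
    feature_len=9, query_type='fixed')

crop_yield_train_semi_transformer(
    args, data_dir='data', model_out_dir='models', result_out_dir='results',
    log_out_dir='logs', start_year=2014, end_year=2018, n_tsteps=7,
    train_years=None)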