def eval_lin_models()

in neural/linear/__main__.py


def eval_lin_models(subject,
                    data_path,
                    results_path_reg,
                    results_path_autoreg,
                    n_init=40,
                    tune_models=True,
                    with_init=True,
                    with_forcing=True,
                    shuffle=False):
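    """Fit, tune, and evaluate two linear baselines for one subject.

    A receptive-field regression (RField / TRF) predicts MEG from the
    forcing features alone; an autoregressive model with exogenous input
    (ARX) additionally regresses on past MEG samples. Both models are
    tuned on the valid set (when tune_models is True), refit, and
    evaluated on the test set. Fitted models, correlation plots, and
    prediction samples are written under results_path_reg and
    results_path_autoreg. Returns Pearson-R scores for both models and a
    dict of permutation-feature-importance scores (empty unless shuffle
    is True).
    """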
    # Load dataset
    data = load_torch_megs(data_path, subject=subject)

    # Load what is needed to reverse the PCA
    pca_mat = data.pca_mats[0]
    mean = data.means[0]
    scaler = data.meg_scalers[0]

    # Get train / valid / test sets, tensor shape [N, C, T]
    meg_train = data.train_sets[0].meg.numpy()
    meg_valid = data.valid_sets[0].meg.numpy()
    meg_test = data.test_sets[0].meg.numpy()

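    # Stack every forcing feature along the channel axis (shape [N, C, T])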
    forcing_keys = data.train_sets[0].forcings.keys()
    forcing_train = np.concatenate([data.train_sets[0].forcings[k]
                                    for k in forcing_keys], axis=1)
    forcing_valid = np.concatenate([data.valid_sets[0].forcings[k]
                                    for k in forcing_keys], axis=1)
    forcing_test = np.concatenate([data.test_sets[0].forcings[k]
                                   for k in forcing_keys], axis=1)

    if not with_forcing:
        forcing_train = np.zeros_like(forcing_train)
        forcing_valid = np.zeros_like(forcing_valid)
        forcing_test = np.zeros_like(forcing_test)

    # Reformat from [N, C, T] to [N, T, C]
    (meg_train, meg_valid, meg_test,
     forcing_train, forcing_valid, forcing_test) = [
        np.swapaxes(elem, 1, 2)
        for elem in [meg_train, meg_valid, meg_test,
                     forcing_train, forcing_valid, forcing_test]
    ]

    ######################
    # LIN REG
    ######################

    # Instantiate
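    # RField is the package's linear receptive-field (TRF) model; lag_u is
    # presumably the stimulus lag window in samples, and the float assigned
    # to rfield.model.estimator below is presumably its ridge penalty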
    rfield = RField(lag_u=260, penal_weight=1.8)

    # Tune the regularization strength on the valid set
    alpha_scores = list()
    alphas = np.logspace(-3, 3, 5)

    if tune_models:

        for alpha in alphas:
            rfield.model.estimator = alpha
            rfield.fit(forcing_train, meg_train)
            meg_pred = rfield.predict(forcing_valid)
            meg_true = meg_valid
            # Compute the validation metric
            alpha_score = get_metrics(meg_true, meg_pred)
            alpha_scores.append(alpha_score.mean())

        alpha = alphas[np.argmax(alpha_scores)]
        rfield.model.estimator = alpha

    # Refit on the train set, save model
    rfield.fit(forcing_train, meg_train)
    torch.save(rfield, results_path_reg / f"model_trf_subject_{subject}.th")

    # Predict on test set
    meg_pred = rfield.predict(forcing_test)
    meg_true = meg_test

    # Reverse PCA
    meg_pred = inverse(mean, scaler, pca_mat, meg_pred)
    meg_true = inverse(mean, scaler, pca_mat, meg_true)

    # Save plot
    report_correl(meg_true, meg_pred, results_path_reg / "reg.png", 0)

    # Save a prediction sample for this subject
    torch.save({"meg_pred_epoch": meg_pred[0],
                "meg_true_epoch": meg_true[0],
                "meg_pred_evoked": meg_pred.mean(0),
                "meg_true_evoked": meg_true.mean(0)},
               results_path_reg / f"meg_prediction_subject_{subject}.th")

    # Compute metric (Pearson R)
    score_linreg = get_metrics(meg_true, meg_pred)

    # Permutation Feature Importance
    shuffled = {}
    to_shuffle = ["word_lengths", "word_freqs"] if shuffle else []
    for name in to_shuffle:

        # Permute one forcing feature. Clone the original torch forcings so
        # each feature is shuffled independently and the dataset is untouched
        forcing_test_torch = {k: v.clone()
                              for k, v in data.test_sets[0].forcings.items()}
        shuffle_forcings(forcing_test_torch, name)
        forcing_test_shuffle = np.concatenate([forcing_test_torch[k]
                                               for k in forcing_keys], axis=1)
        forcing_test_shuffle = np.swapaxes(forcing_test_shuffle, 1, 2)
        if not with_forcing:
            forcing_test_shuffle = np.zeros_like(forcing_test_shuffle)

        # Predict on test set
        meg_pred = rfield.predict(forcing_test_shuffle)
        meg_true = meg_test

        # Reverse PCA
        meg_pred = inverse(mean, scaler, pca_mat, meg_pred)
        meg_true = inverse(mean, scaler, pca_mat, meg_true)

        # Compute metric (Pearson R)
        score_tmp = get_metrics(meg_true, meg_pred)
        shuffled[name] = score_tmp

    ######################
    # LIN AUTOREG
    ######################

    # Instantiate
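    # ARX: autoregressive model with exogenous input. lag_y presumably makes
    # the model regress on its own past MEG samples in addition to the
    # forcings; predict(..., eval="unrolled") then appears to feed predictions
    # back in, free-running after the first n_init warm-up samples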
    ridge = ARX(lag_u=n_init, lag_y=n_init, solver="ridge",
                penal_weight=1.8, scaling=False)

    # Tune the penalty weight on the valid set
    alpha_scores = list()

    if tune_models:

        for alpha in alphas:
            ridge.penal_weight = alpha
            ridge.fit(forcing_train, meg_train)
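            # Warm start: seed the first n_init time steps with the true MEG
            # so the unrolled autoregression begins from real context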
            meg_init = np.zeros_like(meg_valid)
            if with_init:
                meg_init[:, :n_init, :] = meg_valid[:, :n_init, :]
            meg_pred = ridge.predict(
                forcing_valid, meg_init, start=n_init, eval="unrolled")
            meg_true = meg_valid
            # Compute the validation metric
            alpha_score = get_metrics(meg_true, meg_pred)
            alpha_scores.append(alpha_score.mean())

        alpha = alphas[np.argmax(alpha_scores)]
        ridge.penal_weight = alpha

    # Refit on the train set, save model
    ridge.fit(forcing_train, meg_train)
    torch.save(ridge, results_path_autoreg / f"model_rtrf_subject_{subject}.th")

    # Predict on test set
    meg_init = np.zeros_like(meg_test)
    if with_init:
        meg_init[:, :n_init, :] = meg_test[:, :n_init, :]
    meg_pred = ridge.predict(forcing_test, meg_init, start=n_init, eval="unrolled")
    meg_true = meg_test

    # Reverse PCA
    meg_pred = inverse(mean, scaler, pca_mat, meg_pred)
    meg_true = inverse(mean, scaler, pca_mat, meg_true)

    # Save plot
    report_correl(meg_true, meg_pred, results_path_autoreg / "autoreg.png", n_init)

    # Save a prediction sample for this subject
    torch.save({"meg_pred_epoch": meg_pred[0],
                "meg_true_epoch": meg_true[0],
                "meg_pred_evoked": meg_pred.mean(0),
                "meg_true_evoked": meg_true.mean(0)},
               results_path_autoreg / f"meg_prediction_subject_{subject}.th")

    # Compute metric (Pearson R)
    score_linautoreg = get_metrics(meg_true, meg_pred)

    # TODO: add Permutation Feature Importance for the linear autoregressive
    # model

    return score_linreg, score_linautoreg, shuffled
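
Example usage (a sketch: the subject id and paths below are illustrative, and
the results directories are assumed to already exist as pathlib.Path objects):

    from pathlib import Path

    score_reg, score_autoreg, shuffled = eval_lin_models(
        subject=0,
        data_path=Path("data"),
        results_path_reg=Path("results/trf"),
        results_path_autoreg=Path("results/rtrf"),
        shuffle=True,
    )

get_metrics is defined elsewhere in the package; judging only from how it is
used above (called on arrays of shape [N, T, C], with .mean() taken over the
result), a plausible per-channel Pearson-R implementation might look like the
hypothetical sketch below. It is not the package's actual code.

    import numpy as np

    def get_metrics_sketch(meg_true, meg_pred):
        # Pearson R per channel, pooling the epoch and time axes
        true = meg_true.reshape(-1, meg_true.shape[-1]).astype(float)
        pred = meg_pred.reshape(-1, meg_pred.shape[-1]).astype(float)
        true = true - true.mean(axis=0)
        pred = pred - pred.mean(axis=0)
        num = (true * pred).sum(axis=0)
        den = np.sqrt((true ** 2).sum(axis=0) * (pred ** 2).sum(axis=0))
        return num / den  # shape [C]: one correlation per channel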