def evaluate_r2_benchmark()

in compert/api.py

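Builds a baseline ("benchmark") score table for one out-of-distribution perturbation category: the gene-wise means and variances of the true 'ood' cells in pert_category are compared, via R2, against the means and variances of training cells for each category in pert_category_list. For entries joined by '+', the statistics of the two referenced training conditions are linearly interpolated to the dose of pert_category using linear_interp.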

import numpy as np
import pandas as pd
from sklearn.metrics import r2_score  # assumed source of r2_score(y_true, y_pred)

# linear_interp is resolved elsewhere in compert/api.py (sketched after the listing)


def evaluate_r2_benchmark(
    compert_api,
    datasets,
    pert_category,
    pert_category_list
    ):
        # one row per baseline prediction: covariate / perturbation / dose of the
        # target category, R2 of means and variances (all genes and DE genes only),
        # number of true OOD cells, the training category used, and a method label
        scores = pd.DataFrame(columns=[compert_api.covars_key,
                                        compert_api.perturbation_key,
                                        compert_api.dose_key,
                                        'R2_mean', 'R2_mean_DE',
                                        'R2_var', 'R2_var_DE',
                                        'num_cells', 'benchmark', 'method'])

        # indices of the differentially expressed (DE) genes for this category
        de_idx = np.where(
                datasets['ood'].var_names.isin(
                    np.array(datasets['ood'].de_genes[pert_category])))[0]
        # true OOD cells belonging to the target category
        idx = np.where(datasets['ood'].pert_categories == pert_category)[0]
        y_true = datasets['ood'].genes[idx, :].numpy()
        # true means and variances
        yt_m = y_true.mean(axis=0)
        yt_v = y_true.var(axis=0)

        icond = 0
        if len(idx) > 0:
            for pert_category_predict in pert_category_list:
                if '+' in pert_category_predict:
                    # two training conditions: interpolate their statistics
                    # to the dose of the target category
                    pert1, pert2 = pert_category_predict.split('+')
                    idx_pred1 = np.where(
                        datasets['training'].pert_categories == pert1)[0]
                    idx_pred2 = np.where(
                        datasets['training'].pert_categories == pert2)[0]

                    y_pred1 = datasets['training'].genes[idx_pred1, :].numpy()
                    y_pred2 = datasets['training'].genes[idx_pred2, :].numpy()

                    # doses of the two training conditions and of the target
                    x1 = float(pert1.split('_')[2])
                    x2 = float(pert2.split('_')[2])
                    x = float(pert_category.split('_')[2])
                    yp_m1 = y_pred1.mean(axis=0)
                    yp_m2 = y_pred2.mean(axis=0)
                    yp_v1 = y_pred1.var(axis=0)
                    yp_v2 = y_pred2.var(axis=0)

                    yp_m = linear_interp(yp_m1, yp_m2, x1, x2, x)
                    yp_v = linear_interp(yp_v1, yp_v2, x1, x2, x)

                    # alternative baseline: plain average of the two conditions
                    # yp_m = (y_pred1.mean(axis=0) + y_pred2.mean(axis=0)) / 2
                    # yp_v = (y_pred1.var(axis=0) + y_pred2.var(axis=0)) / 2

                else:
                    # single training condition used directly as the baseline
                    idx_pred = np.where(
                        datasets['training'].pert_categories ==
                        pert_category_predict)[0]
                    print(pert_category_predict, len(idx_pred))
                    y_pred = datasets['training'].genes[idx_pred, :].numpy()
                    # predicted means and variances
                    yp_m = y_pred.mean(axis=0)
                    yp_v = y_pred.var(axis=0)

                # R2 between true and baseline means/variances over all genes
                mean_score = r2_score(yt_m, yp_m)
                var_score = r2_score(yt_v, yp_v)

                # ... and restricted to the DE genes of the target category
                mean_score_de = r2_score(yt_m[de_idx], yp_m[de_idx])
                var_score_de = r2_score(yt_v[de_idx], yp_v[de_idx])
                scores.loc[icond] = pert_category.split('_') + [
                    mean_score, mean_score_de, var_score, var_score_de,
                    len(idx), pert_category_predict, 'benchmark']
                icond += 1

        return scores
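
linear_interp itself is not shown here; it is resolved elsewhere in compert/api.py. A minimal sketch of what the two calls above assume, namely linear interpolation of the condition statistics between doses x1 and x2, evaluated at the target dose x:

def linear_interp(y1, y2, x1, x2, x):
    # straight line through (x1, y1) and (x2, y2), evaluated at x
    a = (y1 - y2) / (x1 - x2)
    b = y1 - a * x1
    return a * x + b

A usage sketch, assuming a fitted compert_api object and a datasets dict with 'training' and 'ood' splits; the category names below are illustrative, not values from the repository:

# pert_categories are assumed to follow a '<covariate>_<perturbation>_<dose>' pattern
pert_category = 'A549_drugA_1.0'
pert_category_list = ['A549_drugA_0.1', 'A549_drugA_0.1+A549_drugA_10.0']
scores = evaluate_r2_benchmark(compert_api, datasets, pert_category, pert_category_list)
print(scores[['benchmark', 'R2_mean', 'R2_mean_DE']])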