def metrics_subfig()

in prototypes/orthogonal_forests/comparison_plots.py [0:0]


# Treatment-effect column names such as "TE_0", "TE_17" ("TE_hat" does not match).
_TE_COLUMN_RE = re.compile(r"TE_[0-9]+")


def _te_columns(df):
    """Return the sub-frame of ``df`` whose column names contain ``TE_<digits>``."""
    return df[[c for c in df.columns if _TE_COLUMN_RE.search(c)]]


def _abs_bias(df):
    """Per-row ``|mean of the TE_<k> columns - TE_hat|``."""
    return np.abs(np.mean(_te_columns(df), axis=1) - df["TE_hat"])


def _std_dev(df):
    """Per-row standard deviation (numpy default ddof=0) of the TE_<k> columns."""
    return np.std(_te_columns(df), axis=1)


def _rmse(df):
    """Per-row root-mean-squared deviation of the TE_<k> columns from TE_hat."""
    te = _te_columns(df)
    # Subtract TE_hat positionally from every TE_<k> column of the same row.
    squared_residuals = ((te.T - df["TE_hat"].values).T) ** 2
    # Root of the MSE. Unlike MSE, this is in the same units as the bias and
    # std panels (per row, RMSE^2 = bias^2 + std^2).
    return np.sqrt(np.mean(squared_residuals, axis=1))


# metric key -> (per-row metric function, panel title)
_PER_ROW_METRICS = {
    "bias": (_abs_bias, "Bias"),
    # NOTE(review): this panel is titled "Variance" but plots the standard
    # deviation, keeping it in the same units as "Bias" and "RMSE".
    "variance": (_std_dev, "Variance"),
    "rmse": (_rmse, "RMSE"),
}


def metrics_subfig(dfs, ax, metric, c_scheme=0):
    """Draw one violin-plot panel comparing ``metric`` across several result frames.

    Each element of ``dfs`` becomes one violin.

    Parameters
    ----------
    dfs : sequence of pandas.DataFrame
        One frame per compared setting. For "bias"/"variance"/"rmse" each
        frame needs a ``TE_hat`` column plus one or more ``TE_<digits>``
        columns. Every frame must have the same number of rows as ``dfs[0]``.
        Presumably the TE_<k> columns are per-run estimates and TE_hat is the
        reference effect; confirm against whatever produces these frames.
    ax : matplotlib Axes
        Axes to draw on. The title is set and the x ticks are cleared.
    metric : str
        One of "bias", "variance" (plotted as a standard deviation), "rmse"
        or "R2". For "R2" the values come from ``get_r2`` (defined outside
        this block).
    c_scheme : int
        0 uses the "Set1" colormap; any other value uses "tab20b". With 1,
        violin ``i`` takes the hand-picked colormap index ``cs[i]``, cycling
        when there are more than ``len(cs)`` violins.

    Returns
    -------
    list or int
        ``ax.violinplot(...)['bodies']``, or 0 after printing
        "No such metric" when ``metric`` is not recognised.
    """
    palette = plt.get_cmap('Set1' if c_scheme == 0 else 'tab20b')

    if metric in _PER_ROW_METRICS:
        metric_fn, title = _PER_ROW_METRICS[metric]
        # One column per frame, one row per evaluation point.
        values = np.zeros((len(dfs[0]), len(dfs)))
        for i, df in enumerate(dfs):
            values[:, i] = metric_fn(df)
        vparts = ax.violinplot(values, showmedians=True)
        ax.set_title(title)
    elif metric == "R2":
        r2_scores = [get_r2(df) for df in dfs]
        vparts = ax.violinplot(r2_scores, showmedians=True)
        ax.set_title("$R^2$")
    else:
        print("No such metric")
        return 0

    cs = [0, 3, 12, 14, 15, 4, 6]
    ax.set_xticks([])
    for i, pc in enumerate(vparts['bodies']):
        if c_scheme == 1:
            # Wrap around instead of raising IndexError past len(cs) violins.
            c = cs[i % len(cs)]
        else:
            # Skip colormap index 5 (presumably to avoid Set1's pale yellow).
            c = i if i < 5 else i + 1
        pc.set_facecolor(palette(c))
        pc.set_edgecolor(palette(c))
        pc.set_alpha(0.9)

    alpha = 0.7
    vparts['cbars'].set_color('black')
    vparts['cbars'].set_alpha(0.3)
    vparts['cbars'].set_linestyle('--')

    vparts['cmins'].set_color('black')
    vparts['cmins'].set_alpha(alpha)

    vparts['cmaxes'].set_color('black')
    vparts['cmaxes'].set_alpha(alpha)

    vparts['cmedians'].set_color('black')
    vparts['cmedians'].set_alpha(alpha)
    return vparts['bodies']