def mean_plot()

in compert/plotting.py [0:0]


def _mean_plot_column_means(matrix):
    """Return the column-wise (per-gene) mean of ``matrix`` as a 1-D ndarray.

    Accepts dense arrays and scipy sparse matrices. Sparse input is averaged
    directly instead of being densified, and the input is never modified.
    """
    return np.asarray(matrix.mean(axis=0)).ravel()


def _mean_plot_r2(true_mean, pred_mean, R2_type):
    """Return the R2 between real (``true_mean``) and predicted (``pred_mean``) means.

    ``R2_type`` is expected to be already validated:
    - "R2": sklearn coefficient of determination, ``r2_score(y_true, y_pred)``.
    - "Pearson R2": squared Pearson correlation coefficient.
    """
    if R2_type == "R2":
        return r2_score(true_mean, pred_mean)
    # Element 2 of the linregress result is the Pearson r value.
    return stats.linregress(true_mean, pred_mean)[2] ** 2


def mean_plot(
    adata,
    pred,
    condition_key,
    exp_key,
    path_to_save="./reg_mean.pdf",
    gene_list=None,
    deg_list=None,
    show=False,
    title=None,
    verbose=False,
    x_coeff=0.30,
    y_coeff=0.8,
    fontsize=11,
    R2_type="R2",
    figsize=(3.5, 3.5),
    **kwargs
    ):
    """
    Plots mean matching: per-gene mean of real (x-axis) vs. predicted
    (y-axis) expression for one condition, with a regression fit and R2.

    Neither ``adata`` nor ``pred`` is modified.

    # Parameters
    adata: `~anndata.AnnData`
        Contains real values.
    pred: `~anndata.AnnData`
        Contains predicted values.
    condition_key: Str
        adata.obs / pred.obs key used to select the condition.
    exp_key: Str
        Condition in adata.obs[condition_key] to be plotted.
    path_to_save: basestring
        Path to save the plot.
    gene_list: list
        List of gene names to be annotated on the plot.
    deg_list: list
        List of DEGs to compute a separate R2 on.
    show: bool
        If True, shows the plot after saving it.
    title: Str
        Title of the plot (empty when None).
    verbose: bool
        If True, prints the R2 values.
    x_coeff: float
        Shifts R2 text horizontally by x_coeff.
    y_coeff: float
        Shifts R2 text vertically by y_coeff.
    fontsize: int
        Font size for tick labels, axis labels, title and R2 texts.
    R2_type: Str
        How to compute R2: either "R2" (sklearn) or "Pearson R2".
    **kwargs:
        ``range=(start, stop, step)`` sets identical x/y ticks.

    Returns:
    The R2 over all genes, or a tuple ``(r2_all_genes, r2_degs)`` when
    ``deg_list`` is given.

    Raises:
    ValueError if ``R2_type`` is not one of the supported values.
    """
    r2_types = ['R2', 'Pearson R2']
    if R2_type not in r2_types:
        raise ValueError("R2 calculation should be one of " + str(r2_types))

    # Restrict both objects to the evaluated condition.
    real_cond = adata[adata.obs[condition_key] == exp_key]
    pred_cond = pred[pred.obs[condition_key] == exp_key]

    diff_genes = deg_list
    if diff_genes is not None:
        if hasattr(diff_genes, "tolist"):
            diff_genes = diff_genes.tolist()
        true_diff = _mean_plot_column_means(real_cond[:, diff_genes].X)
        pred_diff = _mean_plot_column_means(pred_cond[:, diff_genes].X)
        r2_diff = _mean_plot_r2(true_diff, pred_diff, R2_type)
        if verbose:
            print(f'Top {len(diff_genes)} DEGs var: ', r2_diff)

    true_mean = _mean_plot_column_means(real_cond.X)
    pred_mean = _mean_plot_column_means(pred_cond.X)
    r2 = _mean_plot_r2(true_mean, pred_mean, R2_type)
    if verbose:
        print('All genes var: ', r2)

    # Real means go in the "_true" column (x-axis), predictions in "_pred" (y-axis).
    df = pd.DataFrame({f'{exp_key}_true': true_mean, f'{exp_key}_pred': pred_mean})

    fig = plt.figure(figsize=figsize)
    try:
        ax = sns.regplot(x=f'{exp_key}_true', y=f'{exp_key}_pred', data=df)
        ax.tick_params(labelsize=fontsize)
        if "range" in kwargs:
            start, stop, step = kwargs.get("range")
            ax.set_xticks(np.arange(start, stop, step))
            ax.set_yticks(np.arange(start, stop, step))
        ax.set_xlabel('true', fontsize=fontsize)
        ax.set_ylabel('pred', fontsize=fontsize)

        if gene_list is not None:
            # Build the name list once; indices are taken from adata's var order.
            # NOTE(review): assumes pred shares adata's var ordering — confirm.
            var_names = adata.var_names.tolist()
            for gene in gene_list:
                j = var_names.index(gene)
                ax.text(true_mean[j], pred_mean[j], gene, fontsize=fontsize, color="black")
                ax.plot(true_mean[j], pred_mean[j], 'o', color="red", markersize=5)

        ax.set_title("" if title is None else title, fontsize=fontsize, fontweight="bold")

        # R2 labels are placed relative to the largest value on each axis.
        x_max = true_mean.max()
        y_max = pred_mean.max()
        ax.text(x_max - x_max * x_coeff, y_max - y_coeff * y_max,
                r'$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= ' + f"{r2:.2f}",
                fontsize=fontsize)
        if diff_genes is not None:
            ax.text(x_max - x_max * x_coeff, y_max - (y_coeff + 0.15) * y_max,
                    r'$\mathrm{R^2_{\mathrm{\mathsf{DEGs}}}}$= ' + f"{r2_diff:.2f}",
                    fontsize=fontsize)

        fig.savefig(f"{path_to_save}", bbox_inches='tight', dpi=100)
        if show:
            plt.show()
    finally:
        # Always release the figure, even if saving/showing fails.
        plt.close(fig)

    if diff_genes is not None:
        return r2, r2_diff
    return r2