in compert/plotting.py [0:0]
def _mean_plot_dense(matrix):
    """Return ``matrix`` as a dense array; dense input is passed through as is."""
    # ``.toarray()`` is used instead of the ``.A`` shorthand, which newer SciPy
    # releases no longer provide on sparse containers.
    return matrix.toarray() if sparse.issparse(matrix) else matrix


def _mean_plot_r2(true_mean, pred_mean, R2_type):
    """Score predicted per-gene means against the real ones with ``R2_type``."""
    if R2_type == "R2":
        return r2_score(true_mean, pred_mean)
    # "Pearson R2": squared correlation coefficient of the linear regression.
    _, _, pearson_r, _, _ = stats.linregress(true_mean, pred_mean)
    return pearson_r ** 2


def mean_plot(
    adata,
    pred,
    condition_key,
    exp_key,
    path_to_save="./reg_mean.pdf",
    gene_list=None,
    deg_list=None,
    show=False,
    title=None,
    verbose=False,
    x_coeff=0.30,
    y_coeff=0.8,
    fontsize=11,
    R2_type="R2",
    figsize=(3.5, 3.5),
    **kwargs
):
    """
    Plots mean matching between real and predicted expression.

    For the cells whose ``obs[condition_key]`` equals ``exp_key``, the
    per-gene mean of ``adata`` (real) is drawn on the x-axis against the
    per-gene mean of ``pred`` (predicted) on the y-axis, together with the
    R2 between the two. The figure is saved to ``path_to_save``. Neither
    ``adata`` nor ``pred`` is modified.

    # Parameters
        adata: `~anndata.AnnData`
            Contains real values.
        pred: `~anndata.AnnData`
            Contains predicted values.
        condition_key: Str
            adata.obs / pred.obs key used to select the cells to plot.
        exp_key: Str
            Condition in adata.obs[condition_key] to be plotted.
        path_to_save: basestring
            Path to save the plot.
        gene_list: list
            List of gene names to be highlighted and labelled on the plot.
            Positions are looked up in ``adata.var_names``.
        deg_list: list
            List of DEGs to compute an additional R2 on.
        show: boolean
            If True, shows the figure after saving it.
        title: Str
            Title of the plot (empty when None).
        verbose: boolean
            If True, prints the computed R2 values.
        x_coeff: float
            Shifts R2 text horizontally by x_coeff.
        y_coeff: float
            Shifts R2 text vertically by y_coeff.
        fontsize: int
            Font size for labels, title and R2 texts.
        R2_type: Str
            How to compute R2; either "R2" (``r2_score``) or "Pearson R2"
            (squared Pearson r from ``stats.linregress``).
        figsize: tuple
            Figure size passed to ``plt.figure``.
        **kwargs:
            Only ``range`` is read: a ``(start, stop, step)`` triple used for
            the tick positions of both axes.

    # Returns
        The R2 over all genes, or the tuple ``(r2, r2_diff)`` when
        ``deg_list`` is given, ``r2_diff`` being the R2 over the DEGs only.

    # Raises
        ValueError: if ``R2_type`` is not one of the supported values, or if
            a name in ``gene_list`` is not found in ``adata.var_names``.
    """
    r2_types = ['R2', 'Pearson R2']
    if R2_type not in r2_types:
        raise ValueError("R2 calculation should be one of " + str(r2_types))

    # Restrict both datasets to the requested condition. Densification happens
    # on these local subsets only, so the caller's objects are left untouched.
    real = adata[adata.obs[condition_key] == exp_key]
    pred = pred[pred.obs[condition_key] == exp_key]

    diff_genes = deg_list
    r2_diff = None
    if diff_genes is not None:
        if hasattr(diff_genes, "tolist"):
            diff_genes = diff_genes.tolist()
        real_diff_mean = np.average(
            _mean_plot_dense(real[:, diff_genes].X), axis=0)
        pred_diff_mean = np.average(
            _mean_plot_dense(pred[:, diff_genes].X), axis=0)
        r2_diff = _mean_plot_r2(real_diff_mean, pred_diff_mean, R2_type)
        if verbose:
            print(f'Top {len(diff_genes)} DEGs var: ', r2_diff)

    real_mean = np.average(_mean_plot_dense(real.X), axis=0)
    pred_mean = np.average(_mean_plot_dense(pred.X), axis=0)
    r2 = _mean_plot_r2(real_mean, pred_mean, R2_type)
    if verbose:
        print('All genes var: ', r2)

    # Real means go under the "true" column (x-axis) and predicted means under
    # the "pred" column (y-axis) so the data matches the axis labels.
    true_col = f'{exp_key}_true'
    pred_col = f'{exp_key}_pred'
    df = pd.DataFrame({true_col: real_mean, pred_col: pred_mean})

    fig = plt.figure(figsize=figsize)
    try:
        ax = sns.regplot(x=true_col, y=pred_col, data=df)
        ax.tick_params(labelsize=fontsize)
        if "range" in kwargs:
            start, stop, step = kwargs["range"]
            ticks = np.arange(start, stop, step)
            ax.set_xticks(ticks)
            ax.set_yticks(ticks)
        ax.set_xlabel('true', fontsize=fontsize)
        ax.set_ylabel('pred', fontsize=fontsize)

        if gene_list is not None:
            # Built once instead of per gene.
            # NOTE(review): indices come from adata.var_names and are applied
            # to the predicted means too -- assumes `pred` shares the gene
            # order of `adata`; confirm against callers.
            var_names = adata.var_names.tolist()
            for gene in gene_list:
                j = var_names.index(gene)
                x_bar = real_mean[j]
                y_bar = pred_mean[j]
                plt.text(x_bar, y_bar, gene, fontsize=fontsize, color="black")
                plt.plot(x_bar, y_bar, 'o', color="red", markersize=5)

        plt.title("" if title is None else title,
                  fontsize=fontsize, fontweight="bold")

        # Anchor the R2 annotations relative to the largest plotted values.
        x_max = np.max(real_mean)
        y_max = np.max(pred_mean)
        text_x = x_max - x_max * x_coeff
        ax.text(text_x, y_max - y_coeff * y_max,
                r'$\mathrm{R^2_{\mathrm{\mathsf{all\ genes}}}}$= ' + f"{r2:.2f}",
                fontsize=fontsize)
        if diff_genes is not None:
            ax.text(text_x, y_max - (y_coeff + 0.15) * y_max,
                    r'$\mathrm{R^2_{\mathrm{\mathsf{DEGs}}}}$= ' + f"{r2_diff:.2f}",
                    fontsize=fontsize)

        plt.savefig(f"{path_to_save}", bbox_inches='tight', dpi=100)
        if show:
            plt.show()
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)

    if diff_genes is not None:
        return r2, r2_diff
    return r2