in compert/plotting.py [0:0]
def plot_dose_response(df,
                       contvar_key,
                       perturbation_key,
                       df_ref=None,
                       response_name='response',
                       use_ref_response=False,
                       palette=None,
                       col_dict=None,
                       fontsize=8,
                       measured_points=None,
                       interpolate=True,
                       f1=7,
                       f2=3.,
                       bbox=(1.35, 1.),
                       ref_name='origin',
                       title_name='None',
                       plot_vertical=True,
                       fname=None,
                       logscale=None,
                       xlabelname=None,
                       format='png'):
    """Plot the decoded response as a function of dose (or another
    continuous variable), one line per perturbation.

    Params
    ------
    df : `DataFrame`
        Table with columns=[perturbation_key, contvar_key, response_name]
        holding the predicted responses. It is not modified.
    contvar_key : str
        Name of the column in df for values to use for x axis.
    perturbation_key : str
        Name of the column in df for the perturbation or covariate to plot
        (one hue per unique value).
    df_ref : `DataFrame` (default: None)
        Table with the same columns as df. It is used either to plot ground
        truth (or another condition) for comparison, or only to extract
        reference x-axis values for the vertical markers. It is not modified.
        When ``use_ref_response`` is True it must also contain the columns
        'split' and 'num_cells', which drive the scatter overlay.
    response_name : str (default: 'response')
        Name of the column in df (and df_ref) for values to use for y axis.
    use_ref_response : bool (default: False)
        If True and df_ref is given, also plot the y values of df_ref. They
        are drawn as dashed 'observations' lines when ``interpolate`` is True,
        and always as a scatter overlay coloured by 'split' and sized by
        'num_cells'. If False, df_ref is only used for the x-axis reference
        values.
    palette : str or sequence of colours (default: None)
        Used only when ``col_dict`` is None. It may be a seaborn palette name
        or a sequence of colours, assigned in order to the sorted unique
        values of ``perturbation_key``. When None, ``get_palette`` is used.
    col_dict : dict (default: None)
        Mapping from each value in perturbation_key to its colour. Takes
        precedence over ``palette``.
    fontsize : int (default: 8)
        Font size for the title, axis labels, tick labels and legend.
    measured_points : unused
        Accepted for backward compatibility; currently ignored.
    interpolate : bool (default: True)
        With ``use_ref_response``, also draw df_ref as (dashed) lines.
    f1 : float (default: 7)
        Width in inches for the plot.
    f2 : float (default: 3.0)
        Height in inches for the plot.
    bbox : tuple (default: (1.35, 1.))
        ``bbox_to_anchor`` coordinates used to position the legend.
    ref_name : str (default: 'origin')
        Name of the helper column that distinguishes 'predictions' from
        'observations' in the combined line plot.
    title_name : str (default: 'None')
        Plot title. ``None`` or the string 'None' means no title.
    plot_vertical : bool (default: True)
        If True and df_ref is given, draw dotted vertical lines at the x
        values of df_ref. Each line spans the y range of that perturbation's
        responses in df.
    fname : str (default: None)
        Name of the file to export the plot, without format extension.
        If falsy, the figure is not saved.
    logscale : sequence of numbers (default: None)
        If given, x ticks are placed at ``np.log10(logscale)`` and labelled
        with the raw values. This presumably assumes the x values are already
        on a log10 scale; confirm with the caller.
    xlabelname : str (default: None)
        X-axis label; defaults to ``contvar_key``.
    format : str (default: 'png')
        Format for the file to export the plot.

    Returns
    -------
    fig : `matplotlib.figure.Figure`
        The created figure.
    """
    sns.set_style("white")

    # Overlay reference responses only when a reference is actually given.
    # Otherwise fall back to the plain line plot instead of crashing on
    # df_ref.replace / the missing ref_name style column.
    plot_ref = use_ref_response and df_ref is not None

    if plot_ref:
        # assign() returns copies, so the caller's DataFrames are not mutated.
        df = df.assign(**{ref_name: 'predictions'})
        df_ref = df_ref.assign(**{ref_name: 'observations'})
        df_plt = pd.concat([df, df_ref]) if interpolate else df
    else:
        df_plt = df

    atomic_drugs = np.unique(df[perturbation_key].values)
    if col_dict is None:
        if palette is None:
            colours = get_palette(len(atomic_drugs))
        else:
            # Honour a user palette (name or colour list). Previously it was
            # ignored and left the colour list undefined (NameError).
            colours = sns.color_palette(palette, n_colors=len(atomic_drugs))
        col_dict = dict(zip(list(atomic_drugs), colours))

    fig = plt.figure(figsize=(f1, f2))
    ax = plt.gca()
    if plot_ref:
        sns.lineplot(
            x=contvar_key,
            y=response_name,
            palette=col_dict,
            hue=perturbation_key,
            style=ref_name,
            dashes=[(1, 0), (2, 1)],
            legend='full',
            style_order=['predictions', 'observations'],
            data=df_plt, ax=ax)
        df_ref = df_ref.replace('training_treated', 'train')
        sns.scatterplot(
            x=contvar_key,
            y=response_name,
            hue='split',
            size='num_cells',
            sizes=(10, 100),
            alpha=1.,
            palette={'train': '#000000', 'training': '#000000', 'ood': '#e41a1c'},
            data=df_ref, ax=ax)
        # Drop the auto-generated legend; a single legend is rebuilt below.
        if ax.legend_ is not None:
            ax.legend_.remove()
    else:
        sns.lineplot(x=contvar_key, y=response_name,
                     palette=col_dict,
                     hue=perturbation_key,
                     data=df_plt, ax=ax)

    ax.legend(
        loc='upper right',
        bbox_to_anchor=bbox,
        fontsize=fontsize)

    # The historical default is the *string* 'None'. Treat it like None so
    # a default call does not render a literal "None" title.
    if title_name is not None and title_name != 'None':
        ax.set_title(title_name, fontsize=fontsize, fontweight='bold')

    # ax.grid('off') passes a truthy string and would switch the grid ON.
    ax.grid(False)
    ax.set_xlabel(contvar_key if xlabelname is None else xlabelname,
                  fontsize=fontsize)
    ax.set_ylabel(f"{response_name}", fontsize=fontsize)
    ax.xaxis.set_tick_params(labelsize=fontsize)
    ax.yaxis.set_tick_params(labelsize=fontsize)
    if logscale is not None:
        ax.set_xticks(np.log10(logscale))
        ax.set_xticklabels(logscale, rotation=90)

    if df_ref is not None and plot_vertical:
        for drug in np.unique(df_ref[perturbation_key].values):
            y_vals = df.loc[df[perturbation_key] == drug, response_name].values
            if y_vals.size == 0:
                # No predictions for this drug, so there is no y range to span.
                continue
            y_lo, y_hi = np.min(y_vals), np.max(y_vals)
            x_vals = df_ref.loc[df_ref[perturbation_key] == drug,
                                contvar_key].values
            for x_dot in x_vals:
                ax.plot([x_dot, x_dot], [y_lo, y_hi], ':', color='black',
                        linewidth=.5, alpha=0.5)

    fig.tight_layout()
    if fname:
        fig.savefig(f'{fname}.{format}', format=format)
    return fig