in compert/plotting.py [0:0]
def plot_contvar_response2D(self,
                            df_response2D,
                            df_ref=None,
                            levels=15,
                            figsize=(4, 4),
                            xlims=(0, 1.03),
                            ylims=(0, 1.03),
                            palette="coolwarm",
                            response_name='response',
                            title_name=None,
                            fontsize=None,
                            postfix='',
                            filename=None,
                            alpha=0.4,
                            sizes=(40, 160),
                            logdose=False,
                            file_format='png'):
    """
    Plot a 2D contour map of a response over a grid of two doses.

    Parameters
    ----------
    df_response2D : pd.DataFrame
        Data frame with responses of combinations with columns=(dose1, dose2,
        response). The first two columns are used as the x and y doses. The
        rows are reshaped row-major into n x n arrays, so the number of rows
        must be a perfect square and the rows must be ordered as a regular
        grid.
    df_ref : pd.DataFrame, optional (default: None)
        Reference points overlaid as a scatter plot. Must contain the same
        two dose columns as `df_response2D` plus 'split' (coloured by
        'train'/'training'/'ood') and 'num_cells' (mapped to marker size).
    levels : int, optional (default: 15)
        Number of levels for both the filled and the line contours.
    figsize : tuple, optional (default: (4, 4))
        Size of the figure in inches.
    xlims, ylims : tuple, optional (default: (0, 1.03))
        Axis limits applied to the x and y axes.
    palette : str or matplotlib colormap, optional (default: "coolwarm")
        Colormap used for the filled and line contours.
    response_name : str, optional (default: 'response')
        Name of column in df_response2D to plot as response.
    title_name : str, optional (default: None)
        Title for the plot.
    fontsize : int, optional (default: None)
        Font size for labels, ticks, contour labels and title. Defaults to
        `self.fontsize`.
    postfix : str, optional (default: '')
        Postfix inserted into the automatically generated file name.
    filename : str, optional (default: None)
        Name of the file to save the plot. If None and `self.fileprefix` is
        set, the name f'{fileprefix}_{postfix}response2D.{file_format}' is
        generated. If both are None, the figure is not saved.
    alpha : float, optional (default: 0.4)
        Transparency of the background (filled) contour.
    sizes : tuple, optional (default: (40, 160))
        (min, max) marker sizes for the `df_ref` scatter, mapped from
        'num_cells'.
    logdose : bool, optional (default: False)
        If True, dose values of `df_response2D` will be log10. As documented
        for `log10_with0`, 0 values are mapped to the minimum value -1, e.g.
        if the smallest non-zero dose was 0.001, 0 will be mapped to -4.
    file_format : str, optional (default: 'png')
        Extension used for the automatically generated file name. Ignored
        when `filename` is given explicitly.

    Raises
    ------
    ValueError
        If the number of rows in `df_response2D` is not a perfect square.
    """
    sns.set_style("white")
    if filename is None and self.fileprefix is not None:
        filename = f'{self.fileprefix}_{postfix}response2D.{file_format}'
    if fontsize is None:
        fontsize = self.fontsize

    x_name, y_name = df_response2D.columns[:2]
    x = df_response2D[x_name].values
    y = df_response2D[y_name].values
    if logdose:
        x = log10_with0(x)
        y = log10_with0(y)
    z = df_response2D[response_name].values

    # The responses are laid out as a flattened n x n grid; fail early with a
    # readable message instead of an opaque reshape error.
    n = int(np.sqrt(len(x)))
    if n * n != len(x):
        raise ValueError(
            f"df_response2D must contain a square grid of dose combinations; "
            f"got {len(x)} rows, which is not a perfect square."
        )
    X = x.reshape(n, n)
    Y = y.reshape(n, n)
    Z = z.reshape(n, n)

    fig, ax = plt.subplots(figsize=figsize)
    # Filled background contour plus labelled contour lines, both using the
    # same number of levels so the two layers agree.
    ax.contourf(X, Y, Z, cmap=palette, levels=levels, alpha=alpha)
    contour_lines = ax.contour(X, Y, Z, levels=levels, cmap=palette)
    ax.clabel(contour_lines, inline=1, fontsize=fontsize)

    # 'square' fixes an equal aspect ratio with a square axes box; the
    # explicit limits below then define the visible region.
    ax.axis("square")
    ax.yaxis.set_tick_params(labelsize=fontsize)
    ax.xaxis.set_tick_params(labelsize=fontsize)
    ax.set_xlabel(x_name, fontsize=fontsize, fontweight="bold")
    ax.set_ylabel(y_name, fontsize=fontsize, fontweight="bold")
    ax.set_xlim(xlims)
    ax.set_ylim(ylims)
    sns.despine()

    if df_ref is not None:
        # NOTE(review): df_ref doses are plotted as-is; when logdose=True they
        # are not log-transformed and may not line up with the contours —
        # confirm this is intended.
        sns.scatterplot(
            x=x_name,
            y=y_name,
            hue='split',
            size='num_cells',
            sizes=sizes,
            alpha=1.,
            palette={'train': '#000000', 'training': '#000000', 'ood': '#e41a1c'},
            data=df_ref, ax=ax)
        # The scatter legend is not wanted; it may be absent, so guard removal.
        legend = ax.get_legend()
        if legend is not None:
            legend.remove()

    ax.set_title(title_name, fontweight="bold", fontsize=fontsize)
    plt.tight_layout()
    if filename:
        save_to_file(fig, filename)