def plot_dose_response()

in compert/plotting.py [0:0]


def plot_dose_response(df,
                       contvar_key,
                       perturbation_key,
                       df_ref=None,
                       response_name='response',
                       use_ref_response=False,
                       palette=None,
                       col_dict=None,
                       fontsize=8,
                       measured_points=None,
                       interpolate=True,
                       f1=7,
                       f2=3.,
                       bbox=(1.35, 1.),
                       ref_name='origin',
                       title_name='None',
                       plot_vertical=True,
                       fname=None,
                       logscale=None,
                       xlabelname=None,
                       format='png'):

    """Plot the decoded response as a function of dose.

    Draws one line per value of ``perturbation_key`` (``response_name``
    against ``contvar_key``). When ``use_ref_response`` is set and ``df_ref``
    is given, predictions and reference observations are drawn with different
    line styles and the reference points are overlaid as a scatter plot.
    Neither input DataFrame is modified.

    Params
    ------
    df : `DataFrame`
        Table with columns=[perturbation_key, contvar_key, response_name].
    contvar_key : str
        Name of the column in df for values to use for x axis.
    perturbation_key : str
        Name of the column in df for the perturbation or covariate to plot.
    df_ref : `DataFrame` (default: None)
        Table with the same columns as in df to plot ground truth or another
        condition for comparison. Could also be used to just extract
        reference values for the x axis (see ``plot_vertical``). When
        ``use_ref_response`` is True it must also contain the columns
        'split' and 'num_cells', which drive the scatter overlay.
    response_name : str (default: 'response')
        Name of the column in df for values to use for y axis.
    use_ref_response : bool (default: False)
        If True, also plot the y values from df_ref; otherwise df_ref is only
        used to extract reference values for the x axis. Ignored when
        df_ref is None.
    palette : sequence of colors (default: None)
        Colors assigned, in sorted perturbation order, when ``col_dict`` is
        None. If None, ``get_palette`` is used.
    col_dict : dictionary (default: None)
        Dictionary with colors for each value in perturbation_key. Takes
        precedence over ``palette``.
    fontsize : int (default: 8)
        Font size for the title, axis labels, tick labels and legend.
    measured_points : (default: None)
        Currently unused; kept for backward compatibility.
    interpolate : bool (default: True)
        With ``use_ref_response``, whether the reference observations are
        also included in the line plot (True) or only shown as scatter
        points (False).
    f1 : float (default: 7)
        Width in inches for the plot.
    f2 : float (default: 3.0)
        Height in inches for the plot.
    bbox : tuple (default: (1.35, 1.))
        Coordinates to adjust the legend.
    ref_name : str (default: 'origin')
        Name of the helper column labelling rows as 'predictions' or
        'observations'; used as the line-style key.
    title_name : str or None (default: 'None')
        Plot title. ``None`` or the string ``'None'`` means no title.
    plot_vertical : bool (default: True)
        Whether to draw dotted vertical lines at the x values found in
        df_ref, spanning the response range of the same perturbation in df.
    fname : str (default: None)
        Name of the file to export the plot. The name comes without format
        extension. If falsy, nothing is saved.
    logscale : sequence of float (default: None)
        Tick values for the x axis. Ticks are placed at ``np.log10(value)``
        and labelled with the raw values, so the x column is presumably
        already log10-transformed.
    xlabelname : str (default: None)
        Label for the x axis; defaults to ``contvar_key``.
    format : str (default: 'png')
        Format for the file to export the plot.

    Returns
    -------
    fig : matplotlib Figure
        The created figure.
    """
    sns.set_style("white")

    # A reference comparison needs both the flag and the reference table;
    # without df_ref fall back to the plain plot instead of failing on the
    # missing style column.
    with_ref = use_ref_response and df_ref is not None

    if with_ref:
        # Label rows on copies so the caller's DataFrames are left untouched.
        df = df.copy()
        df_ref = df_ref.copy()
        df[ref_name] = 'predictions'
        df_ref[ref_name] = 'observations'
        df_plt = pd.concat([df, df_ref]) if interpolate else df
    else:
        df_plt = df

    if col_dict is None:
        # Color every perturbation that appears in the line plot, so that
        # perturbations present only in df_ref still get a palette entry.
        atomic_drugs = np.unique(df_plt[perturbation_key].values)
        if palette is None:
            palette = get_palette(len(atomic_drugs))
        col_dict = dict(zip(list(atomic_drugs), palette))

    fig, ax = plt.subplots(figsize=(f1, f2))

    if with_ref:
        sns.lineplot(
            x=contvar_key,
            y=response_name,
            palette=col_dict,
            hue=perturbation_key,
            style=ref_name,
            dashes=[(1, 0), (2, 1)],
            legend='full',
            style_order=['predictions', 'observations'],
            data=df_plt, ax=ax)

        df_ref = df_ref.replace('training_treated', 'train')
        sns.scatterplot(
            x=contvar_key,
            y=response_name,
            hue='split',
            size='num_cells',
            sizes=(10, 100),
            alpha=1.,
            palette={'train': '#000000', 'training': '#000000', 'ood': '#e41a1c'},
            data=df_ref, ax=ax)

        legend = ax.get_legend()
        if legend is not None:
            legend.remove()
    else:
        sns.lineplot(x=contvar_key, y=response_name,
                     palette=col_dict,
                     hue=perturbation_key,
                     data=df_plt, ax=ax)
        ax.legend(
            loc='upper right',
            bbox_to_anchor=bbox,
            fontsize=fontsize)

    if title_name is not None and title_name != 'None':
        ax.set_title(title_name, fontsize=fontsize, fontweight='bold')
    # Pass a real boolean: a string such as 'off' is truthy for matplotlib
    # and would switch the grid on.
    ax.grid(False)

    ax.set_xlabel(contvar_key if xlabelname is None else xlabelname,
                  fontsize=fontsize)
    ax.set_ylabel(f"{response_name}", fontsize=fontsize)

    ax.xaxis.set_tick_params(labelsize=fontsize)
    ax.yaxis.set_tick_params(labelsize=fontsize)

    if logscale is not None:
        ax.set_xticks(np.log10(logscale))
        ax.set_xticklabels(logscale, rotation=90)

    if df_ref is not None and plot_vertical:
        for drug in np.unique(df_ref[perturbation_key].values):
            responses = df.loc[df[perturbation_key] == drug, response_name].values
            if responses.size == 0:
                # No predicted response for this perturbation, so there is
                # no y range for the vertical marker to span.
                continue
            m1, m2 = np.min(responses), np.max(responses)
            x = df_ref.loc[df_ref[perturbation_key] == drug, contvar_key].values
            for x_dot in x:
                ax.plot([x_dot, x_dot], [m1, m2], ':', color='black',
                        linewidth=.5, alpha=0.5)

    fig.tight_layout()
    if fname:
        fig.savefig(f'{fname}.{format}', format=format)

    return fig