def plot_groups()

in expanded_checklist/checklist/graphs/graphs.py [0:0]


def _without_dropped(frame):
    """Return ``frame`` without the columns whose name contains ``DROP``."""
    keep = ['DROP' not in str(col) for col in frame.columns]
    return frame.loc[:, keep]


def _scale_for_heatmap(data, scaling):
    """Return a rescaled copy of ``data`` used only to colour the heatmap.

    Columns whose name contains ``accumulated`` are excluded from the scaling
    and re-inserted as 0 (NaN where the original value is NaN). The result
    has the same index and the same column order as ``data`` so that it lines
    up cell-by-cell with the annotation values.

    Raises:
        ValueError: if ``scaling`` is neither ``Scaling.UNIT_NORM`` nor
            ``Scaling.MAX_ABS``.
    """
    accumulated_cols = [col for col in data.columns if 'accumulated' in col]
    per_group = data.drop(columns=accumulated_cols)

    if scaling == Scaling.UNIT_NORM:
        # normalize() works row-wise: every metric row gets unit L2 norm.
        scaled_values = preprocessing.normalize(per_group.values, norm='l2')
    elif scaling == Scaling.MAX_ABS:
        # MaxAbsScaler scales per column, so transpose to scale each metric
        # (row) by its own maximum absolute value, then transpose back.
        scaled_values = preprocessing.MaxAbsScaler().fit_transform(
            per_group.values.T).T
    else:
        raise ValueError(f"Unsupported scaling for plot_groups: {scaling!r}")

    scaled = pd.DataFrame(
        scaled_values, index=per_group.index, columns=per_group.columns)

    for col in accumulated_cols:
        scaled[col] = np.where(data[col].isna(), np.nan, 0.0)

    # The accumulated columns were appended at the end above; restore the
    # original column order, because annotations and tick labels are matched
    # to the colour matrix by position.
    return scaled[list(data.columns)]


def plot_groups(
        df, group_names, plot_name, out_dir,
        scores=None, last_col="accumulated", tight=False,
        scaling=Scaling.NONE):
    """Draw metric results as annotated heatmap(s), save and show the figure.

    Rows of ``df`` are regrouped by their ``Type`` column in the order BCM,
    PCM, MCM (rows with any other type are left out); within a type the input
    row order is kept. Dashed lines mark the boundaries between the types.

    With ``group_names`` the plot has one row per metric and one column per
    group; if ``df`` has a ``Counterfactual`` column, the rows where it is
    True and the rows where it is False are drawn as two stacked panels.
    Without ``group_names`` a single-row strip is drawn with one cell per
    metric, annotated with the ``Score`` column and coloured uniformly.

    Args:
        df: DataFrame with a ``Type`` column and either the ``group_names``
            columns or a ``Score`` column. If it has a ``Metric`` column,
            that column becomes the index.
        group_names: column names to plot; falsy selects the single-row strip.
            Columns whose name contains ``DROP`` are not drawn.
        plot_name: file name of the saved figure.
        out_dir: directory the figure is saved to (``{out_dir}/{plot_name}``).
        scores: unused.
        last_col: column moved to the right-most position and separated by a
            dashed vertical line, when present in ``group_names``.
        tight: call ``plt.tight_layout()`` before saving.
        scaling: a ``Scaling`` member or a value accepted by ``Scaling(...)``;
            anything that cannot be converted falls back to ``Scaling.NONE``.
            With ``Scaling.NONE`` colours use the raw values on a fixed
            [-0.2, 0.2] range with a colour bar; otherwise colours use the
            scaled values on [-1, 1] while annotations show the values
            rounded to 3 decimals.

    Raises:
        ValueError: if ``scaling`` is a ``Scaling`` member this function does
            not implement.
    """
    if 'Metric' in df.columns:
        df = df.set_index('Metric')

    if 'PosAvgEG$^P$' in df.index:
        # errors='ignore': nothing to drop when the plain variant is absent.
        df = df.drop(['AvgEG$^P$'], errors='ignore')

    if not isinstance(scaling, Scaling):
        try:
            scaling = Scaling(scaling)
        except (ValueError, TypeError, KeyError):
            # Best effort: an unrecognised value means "no scaling".
            scaling = Scaling.NONE

    metrics_by_type = [df[df['Type'] == mtype]
                       for mtype in ("BCM", "PCM", "MCM")]
    df = pd.concat(metrics_by_type)

    has_last_col = False
    if group_names:
        columns = list(group_names)
        has_last_col = bool(last_col) and last_col in columns
        if has_last_col:
            columns = [c for c in columns if c != last_col] + [last_col]

        clean_columns = [c for c in columns if "DROP" not in c]
        num_columns = len(clean_columns)
        num_rows = len(df)
        xticklabels = [c.replace("_", " ") for c in clean_columns]
        yticklabels = "auto"

        if 'Counterfactual' in df:
            panels = []
            metric_counts = []
            for flag in (True, False):
                panels.append(df[df['Counterfactual'] == flag][columns])
                metric_counts.append([
                    len(metrics[metrics['Counterfactual'] == flag])
                    for metrics in metrics_by_type])
        else:
            panels = [df[columns]]
            metric_counts = [[len(metrics) for metrics in metrics_by_type]]
    else:
        # 1D heatmap: no group results = counterfactual metrics
        num_columns = len(df)
        num_rows = 1
        xticklabels = list(df.index)
        yticklabels = []
        # A plain nested list (one row) marks this panel as the 1D strip.
        panels = [[list(df['Score'])]]
        metric_counts = [[len(metrics) for metrics in metrics_by_type]]

    # squeeze=False always yields a 2D array of axes, even for one panel.
    fig, axs = plt.subplots(
        len(panels), 1, sharex=False, sharey=False, squeeze=False,
        figsize=(0.8 * num_columns, 0.5 * num_rows),
        gridspec_kw={'height_ratios': [len(p) for p in panels]})
    axs = axs.ravel()

    cmap = sns.diverging_palette(220, 20, 55, l=60, as_cmap=True)

    for i, (data, mcounts) in enumerate(zip(panels, metric_counts)):
        ax = axs[i]
        # Only the top panel carries the column labels.
        xticks = xticklabels if i == 0 else []

        cbar = False
        vmin, vmax = -1, 1

        if isinstance(data, list):
            # 1D strip: every cell gets the same colour, the scores are only
            # shown as annotations.
            colour_data = [[0] * len(data[0])]
        else:
            if scaling == Scaling.NONE:
                colour_data = data
                vmin, vmax = -0.2, 0.2
                cbar = True
            else:
                # round to have clean results with scaling if all
                # measurements are very close to 0
                data = data.round(3)
                colour_data = _scale_for_heatmap(data, scaling)

            colour_data = _without_dropped(colour_data)
            data = _without_dropped(data)

        sns.heatmap(
            colour_data,
            xticklabels=xticks,
            yticklabels=yticklabels,
            annot=data,
            fmt='.3f',
            cbar=cbar,
            cmap=cmap,
            vmax=vmax,
            vmin=vmin,
            ax=ax
        )
        ax.tick_params('x', labelrotation=20)
        ax.xaxis.tick_top()
        ax.yaxis.set_label_text('')

        # Positions where the BCM block ends and where the PCM block ends.
        type_bounds = [mcounts[0], mcounts[0] + mcounts[1]]
        if not group_names:
            # Metrics run along the x axis in the 1D strip.
            ax.vlines(
                type_bounds, *ax.get_ylim(),
                colors='black', linestyles='dashed')
        else:
            if has_last_col:
                ax.vlines(
                    [num_columns - 1], *ax.get_ylim(),
                    colors='black', linestyles='dashed')
            # Metrics run along the y axis in the grouped layout.
            ax.hlines(
                type_bounds, *ax.get_xlim(),
                colors='black', linestyles='dashed')

    fname = f'{out_dir}/{plot_name}'

    if tight:
        plt.tight_layout()
    plt.savefig(fname, bbox_inches='tight', pad_inches=0)
    plt.show()