in expanded_checklist/checklist/graphs/graphs.py [0:0]
def _order_metrics(df):
    """Sort metrics by ``ProbBased`` and group them as BCM, PCM, MCM.

    Returns ``(ordered_df, [bcm_rows, pcm_rows, mcm_rows])``. Rows whose
    ``Type`` is none of the three are not part of the result.
    """
    # Sort BEFORE splitting so the order survives the concat below; a stable
    # sort keeps the incoming order among metrics with equal ProbBased.
    ordered = df.sort_values(by=['ProbBased'], ascending=True, kind='mergesort')
    subsets = [ordered[ordered['Type'] == t] for t in ("BCM", "PCM", "MCM")]
    return pd.concat(subsets), subsets


def _drop_marked_columns(frame):
    """Return ``frame`` without the columns whose label contains 'DROP'."""
    return frame.drop(columns=frame.filter(regex='DROP').columns)


def _scale_panel(data, scaling):
    """Compute ``(colour_data, annot_data, vmin, vmax, cbar)`` for one panel.

    ``data`` is either a DataFrame of per-group scores or, for the 1D layout,
    a one-row list of lists whose cells are annotated but drawn neutral.
    """
    if isinstance(data, list):
        # 1D layout: no heatmap colouring, only the annotated scores.
        return [[0] * len(data[0])], data, -1, 1, False

    if scaling == Scaling.NONE:
        clean = _drop_marked_columns(data)
        return clean, clean, -0.2, 0.2, True

    # round to have clean results with scaling if all measurements are
    # very close to 0
    data = data.round(3)
    accumulated_cols = [c for c in data.columns if 'accumulated' in c]
    values = data.drop(columns=accumulated_cols)
    if scaling == Scaling.UNIT_NORM:
        scaled = preprocessing.normalize(values.values, norm='l2')
    elif scaling == Scaling.MAX_ABS:
        # MaxAbsScaler scales per column; transpose so each metric row is
        # scaled by its own maximum absolute value, then transpose back.
        scaled = preprocessing.MaxAbsScaler().fit_transform(values.values.T).T
    else:
        raise ValueError(f"plot_groups: unsupported scaling {scaling!r}")
    scaled = pd.DataFrame(scaled, index=values.index, columns=values.columns)
    # Accumulated columns get the neutral colour (0) but keep their NaN gaps.
    for col in accumulated_cols:
        scaled[col] = [None if pd.isna(v) else 0 for v in data[col]]
    # seaborn pairs `annot` with the colour matrix by position, so both must
    # share the same column order.
    scaled = scaled[list(data.columns)]
    return _drop_marked_columns(scaled), _drop_marked_columns(data), -1, 1, False


def plot_groups(
        df, group_names, plot_name, out_dir,
        scores=None, last_col="accumulated", tight=False,
        scaling=Scaling.NONE):
    """Draw metric results as annotated heatmap(s) and save them to disk.

    Args:
        df: metrics table with a ``Type`` column (rows of type "BCM", "PCM"
            or "MCM" are plotted, in that order) and a ``ProbBased`` column
            used to order rows within each type. A ``Metric`` column, if
            present, becomes the index. An optional boolean-like
            ``Counterfactual`` column splits the grouped plot into two
            stacked panels (True rows first).
        group_names: score columns to plot, one heatmap column each (names
            containing "DROP" are not shown). If falsy, a single-row heatmap
            of the ``Score`` column is drawn instead, one cell per metric.
        plot_name: output file name, written inside ``out_dir``.
        out_dir: directory the figure is saved to.
        scores: unused; kept for backward compatibility.
        last_col: column moved to the far right and separated by a dashed
            line when it is among ``group_names``.
        tight: call ``tight_layout`` before saving.
        scaling: ``Scaling`` member or a value convertible to one; values
            that cannot be converted fall back to ``Scaling.NONE``.

    Raises:
        ValueError: if no panel has rows to plot, or ``scaling`` is a
            ``Scaling`` member this function does not implement.
    """
    import os  # local import keeps this block self-contained

    if 'Metric' in df.columns:
        df = df.set_index('Metric')
    if 'PosAvgEG$^P$' in df.index:
        # Keep only the positive variant; the plain one may be absent.
        df = df.drop(index=['AvgEG$^P$'], errors='ignore')
    df, type_subsets = _order_metrics(df)
    metric_counts = [[len(s) for s in type_subsets]]

    if not isinstance(scaling, Scaling):
        try:
            scaling = Scaling(scaling)
        except (ValueError, TypeError):
            scaling = Scaling.NONE

    columns = None
    if group_names:
        columns = list(group_names)
        if last_col and last_col in columns:
            columns = [c for c in columns if c != last_col] + [last_col]
        clean_columns = [c for c in columns if "DROP" not in c]
        num_columns = len(clean_columns)
        num_rows = len(df)
        xticklabels = [c.replace("_", " ") for c in clean_columns]
        yticklabels = "auto"
        if 'Counterfactual' in df.columns:
            flags = (True, False)
            dfs = [df[df['Counterfactual'] == flag][columns] for flag in flags]
            metric_counts = [
                [len(s[s['Counterfactual'] == flag]) for s in type_subsets]
                for flag in flags
            ]
        else:
            dfs = [df[columns]]
    else:
        # 1D heatmap: no group results = counterfactual metrics
        num_columns = len(df)
        num_rows = 1
        xticklabels = list(df.index)
        yticklabels = []
        dfs = [[list(df['Score'])]]

    # Skip empty panels (e.g. no counterfactual metrics) instead of drawing
    # zero-height heatmaps.
    panels = [(d, c) for d, c in zip(dfs, metric_counts) if len(d) > 0]
    if not panels:
        raise ValueError("plot_groups: no metrics left to plot")

    fig, axs = plt.subplots(
        len(panels), 1, sharex=False, sharey=False,
        figsize=(0.8 * num_columns, 0.5 * num_rows),
        gridspec_kw={'height_ratios': [len(d) for d, _ in panels]},
        squeeze=False)
    axs = axs[:, 0]
    cmap = sns.diverging_palette(220, 20, 55, l=60, as_cmap=True)

    for i, (data, mcounts) in enumerate(panels):
        ax = axs[i]
        colour_data, annot_data, vmin, vmax, cbar = _scale_panel(data, scaling)
        sns.heatmap(
            colour_data,
            xticklabels=xticklabels if i == 0 else [],  # labels on top panel only
            yticklabels=yticklabels,
            annot=annot_data,
            fmt='.3f',
            cbar=cbar,
            cmap=cmap,
            vmax=vmax,
            vmin=vmin,
            ax=ax,
        )
        ax.tick_params('x', labelrotation=20)
        ax.xaxis.tick_top()
        ax.yaxis.set_label_text('')
        # Boundaries between the BCM | PCM | MCM blocks.
        type_bounds = [mcounts[0], mcounts[0] + mcounts[1]]
        if not group_names:
            # 1D layout: metrics are columns.
            ax.vlines(type_bounds, *ax.get_ylim(),
                      colors='black', linestyles='dashed')
        else:
            if last_col and last_col in columns:
                ax.vlines([num_columns - 1], *ax.get_ylim(),
                          colors='black', linestyles='dashed')
            # Grouped layout: metrics are rows.
            ax.hlines(type_bounds, *ax.get_xlim(),
                      colors='black', linestyles='dashed')

    fname = os.path.join(out_dir, plot_name)
    if tight:
        fig.tight_layout()
    fig.savefig(fname, bbox_inches='tight', pad_inches=0)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)