in src/plotting.py [0:0]
def do_plotting(display_plots, save_plots, use_input_commands, numsteps, group_names, group_type,
show_legend, error_type, data_name, model_string,
agg_poperrs, agg_grouperrs, groupweights, pop_error_type, bonus_plots,
dirname, multi_group=False,
validation=False, equal_error=False):
"""
Helper function for minimaxML that creates the relevant plots for a single run of the simulation.
"""
# Create a list of all figures we want to save for later which will be passed into a function
figures = []
figure_names = ['PopError_vs_Rounds', 'GroupError_vs_Rounds', 'GroupWeights_vs_Rounds', 'Trajectory_Plot']
# Combine all the existing arrays as necessary by separating all subgroups as unqiue groups
if multi_group:
num_group_types = len(agg_grouperrs) # list of numpy arrays
agg_grouperrs = np.column_stack(agg_grouperrs) # vertically stck the groups errs
if not validation:
groupweights = np.column_stack(groupweights) # vertically stack the weights for each groups
stacked_group_names = [] # stack the groups errors
for i in range(num_group_types):
g_type = group_type[i] if num_group_types > 1 else ''
stacked_group_names.extend([g_type + ': ' + name for name in group_names[i]])
group_names = stacked_group_names
# End of multi-groups adjustments
if group_type is not None:
# print(f'Here are the plots for groups based on: {group_type}')
pass
if use_input_commands and display_plots:
input("Press `Enter` to show first plot... ")
# Setup strings for graph titles
dataset_string = f' on {data_name[0].upper() + data_name[1:]}' # Set the first letter to capital if it isn't
plt.ion()
# Average Pop error vs. Rounds
figures.append(plt.figure()) # Creates figure and adds it to list of figures
plt.plot(agg_poperrs)
plt.title(f'Average Population Error ({pop_error_type})' + dataset_string + model_string)
plt.xlabel('Steps')
plt.ylabel(f'Average Population Error ({pop_error_type})')
if display_plots:
plt.show()
if use_input_commands and display_plots:
input("Next plot...")
# Group Errors vs. Rounds
figures.append(plt.figure()) # Create figure and append to list
for g in range(0, len(group_names)):
# Plots the groups with appropriate label
plt.plot(agg_grouperrs[:, g], label=group_names[g])
if show_legend:
plt.legend(loc='upper right')
plt.title(f'Group Errors ({error_type})' + dataset_string + model_string)
plt.xlabel('Steps')
plt.ylabel(f'Group Errors ({error_type})')
if display_plots:
plt.show()
if use_input_commands and display_plots:
input("Next plot...")
# Group Weights vs. Rounds
if not validation and groupweights is not None: # Groupweights aren't a part of validation
figures.append(plt.figure()) # Create figure and append to list
for g in range(0, len(group_names)):
plt.plot(groupweights[:, g], label=group_names[g])
if show_legend:
plt.legend(loc='upper right')
plt.title(f'Group Weights' + dataset_string + model_string)
plt.xlabel('Steps')
plt.ylabel('Group Weights')
if display_plots:
plt.show()
if use_input_commands and display_plots:
input("Next plot...")
# Trajectory Plot with Pareto Curve
figures.append(plt.figure())
x = agg_poperrs
y = np.max(agg_grouperrs, axis=1)
points = np.zeros((len(x), 2))
points[:, 0] = x
points[:, 1] = y
colors = np.arange(1, numsteps)
plt.scatter(x, y, c=colors, s=2, label='Trajectory of Mixtures')
plt.scatter(x[0], y[0], c='m', s=40, label='Starting point') # Make the first point big and pink
plt.title(f'Trajectory over {numsteps - 1} rounds' + dataset_string + model_string)
plt.xlabel(f'Population Error ({pop_error_type})')
plt.ylabel(f'Max Group Error ({error_type})')
if display_plots:
plt.show()
for err_type, grp_errs, pop_errs, pop_err_type in bonus_plots:
if use_input_commands and display_plots:
input(f"Next bonus plot for error type {err_type}...")
# Group Errors vs. Rounds
figures.append(plt.figure()) # Create figure and append to list
figure_names.append(f'GroupError_vs_Rounds_({err_type if err_type != "0/1 Loss" else "0-1 Loss"})')
for g in range(0, len(group_names)):
# Plots the groups with appropriate label
plt.plot(grp_errs[:, g], label=group_names[g])
if show_legend:
plt.legend(loc='upper right')
plt.title(f'Group Errors ({err_type})' + dataset_string + model_string)
plt.xlabel('Steps')
plt.ylabel(f'Group Errors ({err_type})')
if display_plots:
plt.show()
if use_input_commands and display_plots:
input("Next bonus plot (trajectory)...")
figures.append(plt.figure())
figure_names.append(f'Trajectory_({err_type if (err_type != "0/1 Loss") else "0-1 Loss"})')
x = pop_errs
y = np.max(grp_errs, axis=1)
points = np.zeros((len(x), 2))
points[:, 0] = x
points[:, 1] = y
colors = np.arange(1, numsteps)
plt.scatter(x, y, c=colors, s=2, label='Trajectory of Mixtures')
plt.scatter(x[0], y[0], c='m', s=40, label='Starting point') # Make the first point big and pink
plt.title(f'Trajectory over {numsteps - 1} rounds' + dataset_string + model_string)
plt.xlabel(f'Population Error ({err_type})')
plt.ylabel(f'Max Group Error ({pop_err_type})')
if display_plots:
plt.show()
if use_input_commands and display_plots:
input("Quit")
# Update the names if doing valiadtion
if validation:
figure_names = [name + '_Validation' for name in figure_names]
# Now we have a list of plots: `figures` we can save
if save_plots:
save_plots_to_os(figures, figure_names, dirname)
plt.close('all')