in src/plot_relaxed_pareto.py [0:0]
def do_pareto_plot(gammas, total_steps_per_gamma, max_grp_errs, pop_errs, trajectories, numsteps,
                   error_type, pop_error_type,
                   save_plots, dirname,
                   model_type,
                   use_input_commands,
                   data_name='', bonus_plot_list=None, show_basic_plots=False,
                   val_max_grp_errs=None, val_pop_errs=None, val_trajectories=None, val_bonus_plot_list=None,
                   test_size=0.0):
    """
    Utility function used in main_driver to create a multi-trajectory plot over runs with a range of gamma
    values, and to trace the pareto curve of the errors of the resulting mixture models.

    Use argument `show_basic_plots` to enable scatter plots for the pairwise relationships between population
    error, max group error, and gamma, of the final mixture models.

    :param gammas: gamma values, one per run; used as scatter x-values and to annotate each trajectory
    :param total_steps_per_gamma: one step count per gamma; trajectory points are colored by
                                  `np.arange(1, total_steps)`
    :param max_grp_errs: max group error per gamma (training)
    :param pop_errs: population error per gamma (training)
    :param trajectories: iterable of `(x, y)` pairs, one per gamma; `x` is drawn on the population-error axis
                         and `y` on the max-group-error axis
    :param numsteps: number of rounds, only used in the trajectory plot title
    :param error_type: name of the group error, shown in the max-group-error axis label
    :param pop_error_type: name of the population error; the value 'Total' is displayed as '0/1 Loss'
    :param save_plots: if truthy, figures are handed to `save_plots_to_os` and then all figures are closed
    :param dirname: forwarded to `save_plots_to_os`
    :param model_type: model description appended to every title
    :param use_input_commands: if truthy, block on `input()` before each figure and before finishing
    :param data_name: optional dataset name; its first letter is upper-cased and it is added to the titles
    :param bonus_plot_list: forwarded to `plot_trajectories_from_bonus_plot_data` (training)
    :param show_basic_plots: whether to draw the three pairwise scatter plots in addition to the trajectories
    :param val_max_grp_errs: validation counterpart of `max_grp_errs`
    :param val_pop_errs: validation counterpart of `pop_errs`
    :param val_trajectories: validation counterpart of `trajectories`
    :param val_bonus_plot_list: forwarded to `plot_trajectories_from_bonus_plot_data` (validation)
    :param test_size: shown in validation titles and forwarded with the validation bonus plots

    Validation figures are only produced when both `val_max_grp_errs` and `val_pop_errs` are not None.
    Returns None; the function is used for its plotting / saving side effects.
    """
    plt.ion()

    # Strings shared by all titles and axis labels.
    # Only the first character is upper-cased (unlike str.capitalize, the rest is left as given).
    dataset_string = f' on {data_name[0].upper() + data_name[1:]}' if data_name else ''
    pop_error_string = '0/1 Loss' if pop_error_type == 'Total' else pop_error_type
    pop_label = f'Pop Error ({pop_error_string})'
    grp_label = f'Max Group Error ({error_type})'

    # Training figures: optional pairwise scatters followed by the multi-trajectory plot
    figures, figure_names = _plot_error_split(
        gammas, total_steps_per_gamma, max_grp_errs, pop_errs, trajectories,
        pareto=get_pareto(pop_errs, max_grp_errs),
        numsteps=numsteps,
        pop_label=pop_label,
        grp_label=grp_label,
        title_suffix=f'{dataset_string} \n {model_type}',
        name_suffix='',
        show_basic_plots=show_basic_plots,
        use_input_commands=use_input_commands,
        first_prompt='Press `Enter` to display first plot...',
        mark_start=True)

    # Multi-trajectory plots for the additional error types
    bonus_figures, bonus_names = \
        plot_trajectories_from_bonus_plot_data(bonus_plot_list, gammas, model_type, error_type, numsteps,
                                               total_steps_per_gamma,
                                               use_input_commands)
    figures.extend(bonus_figures)
    figure_names.extend(bonus_names)

    if val_max_grp_errs is not None and val_pop_errs is not None:
        # Validation figures mirror the training ones, except that the starting point of each
        # trajectory is not highlighted.
        val_figures, val_names = _plot_error_split(
            gammas, total_steps_per_gamma, val_max_grp_errs, val_pop_errs, val_trajectories,
            pareto=get_pareto(val_pop_errs, val_max_grp_errs),
            numsteps=numsteps,
            pop_label=pop_label,
            grp_label=grp_label,
            title_suffix=f'{dataset_string} (Validation: {test_size}) \n {model_type}',
            name_suffix='_Validation',
            show_basic_plots=show_basic_plots,
            use_input_commands=use_input_commands,
            first_prompt='Click enter to display first validation plot',
            mark_start=False)
        figures.extend(val_figures)
        figure_names.extend(val_names)

        val_bonus_figures, val_bonus_names = \
            plot_trajectories_from_bonus_plot_data(val_bonus_plot_list, gammas, model_type, error_type,
                                                   numsteps, total_steps_per_gamma, use_input_commands,
                                                   test_size=test_size)
        figures.extend(val_bonus_figures)
        figure_names.extend(name + '_Validation' for name in val_bonus_names)

    _pause_for_user(use_input_commands, 'Quit')

    if save_plots:
        # `figures` and `figure_names` are built in lockstep, so entry i of one matches entry i of the other
        save_plots_to_os(figures, figure_names, dirname, True)
        # Figures are closed only once they have been saved
        plt.close('all')


def _pause_for_user(use_input_commands, prompt):
    """Block on `input(prompt)` when interactive stepping is enabled."""
    if use_input_commands:
        input(prompt)


def _draw_pareto_curve(pareto):
    """Overlay the pareto curve (columns 0 and 1 of `pareto`) on the current figure, if there is one."""
    if pareto is not None:
        plt.plot(pareto[:, 0], pareto[:, 1], 'r--', lw=2, label='Pareto Curve', alpha=0.5)


def _scatter_figure(x, y, title, xlabel, ylabel, pareto=None):
    """Create, show and return a new figure with a scatter of `y` against `x` and an optional pareto curve."""
    fig = plt.figure()
    plt.scatter(x, y)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    _draw_pareto_curve(pareto)
    plt.show()
    return fig


def _trajectory_figure(trajectories, gammas, total_steps_per_gamma, title, xlabel, ylabel, pareto, mark_start):
    """Create, show and return a new figure with one step-colored, gamma-annotated trajectory per gamma."""
    fig = plt.figure()
    for (x, y), gamma, total_steps in zip(trajectories, gammas, total_steps_per_gamma):
        # Color each point of the trajectory by its step index
        plt.scatter(x, y, c=np.arange(1, total_steps), s=2)
        if mark_start:
            # Highlight the starting point of the trajectory with a larger magenta marker
            plt.scatter(x[0], y[0], c='m', s=20)
        # Label the end point of the trajectory with its gamma
        plt.annotate(f'gamma={gamma:.5f}', xy=(x[-1], y[-1]))
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    _draw_pareto_curve(pareto)
    plt.show()
    return fig


def _plot_error_split(gammas, total_steps_per_gamma, max_grp_errs, pop_errs, trajectories, *,
                      pareto, numsteps, pop_label, grp_label, title_suffix, name_suffix,
                      show_basic_plots, use_input_commands, first_prompt, mark_start):
    """
    Draw the figures for one data split (training or validation) and return `(figures, names)`.

    The two returned lists have equal length and matching order: the three pairwise scatter plots
    (only when `show_basic_plots` is set) followed by the multi-trajectory plot.
    """
    figures = []
    if show_basic_plots:
        _pause_for_user(use_input_commands, first_prompt)
        figures.append(_scatter_figure(pop_errs, max_grp_errs,
                                       f'Pop Error vs. Max Group Error{title_suffix}',
                                       pop_label, grp_label, pareto))
        _pause_for_user(use_input_commands, 'Next plot...')
        figures.append(_scatter_figure(gammas, max_grp_errs,
                                       f'Gamma vs. Max Group Error{title_suffix}',
                                       'Gamma', grp_label))
        _pause_for_user(use_input_commands, 'Next plot...')
        figures.append(_scatter_figure(gammas, pop_errs,
                                       f'Gamma vs. Pop Error{title_suffix}',
                                       'Gamma', pop_label))
        # NOTE: the trajectory figure is named '..._Gammas' here but '..._Gamma' below; both spellings
        # are kept verbatim because the names are passed on to `save_plots_to_os`.
        names = ['PopError_vs_MaxGroupError', 'Gamma_vs_MaxGroupError', 'Gamma_vs_PopError',
                 'Trajectories_over_Gammas']
    else:
        names = ['Trajectories_over_Gamma']

    # The multi-trajectory plot is always drawn, after any basic plots
    _pause_for_user(use_input_commands, 'Next plot...')
    figures.append(_trajectory_figure(trajectories, gammas, total_steps_per_gamma,
                                      f'Trajectories over {numsteps} Rounds{title_suffix}',
                                      pop_label, grp_label, pareto, mark_start))
    return figures, [name + name_suffix for name in names]