in src/plot_relaxed_pareto.py [0:0]
def plot_trajectories_from_bonus_plot_data(bonus_plot_list, gammas, model_type, error_type, numsteps,
total_steps_per_gamma,
use_input_commands, test_size=0.0, data_name=''):
"""
:param bonus_plot_list: List (over gammas) of lists of tuples corresponding to the bonus plots for each run
:param gammas: list of gammas corresponding to each round
:return figures, names: list of figures and their names
"""
figures = []
names = []
# Set the first letter to capital if it isn't
dataset_string = f' on {data_name[0].upper() + data_name[1:]}' if data_name != '' else ''
try:
num_bonus_plots = len(bonus_plot_list[0]) # Number of 4-tuples (bonus plots) per value of gamma
except:
print(bonus_plot_list)
warnings.warn('WARNING: Could not index into bonus plots. Skipping and continuing...')
num_bonus_plots = 0
# Iterate over the number of bonus plots per individual run
for plot_index in range(num_bonus_plots):
if use_input_commands:
input('Next bonus plot')
figures.append(plt.figure()) # One figure for 'type' of multi trajectory plot
# Keep ararys to track the endpoints of the trajectories and eventually plot pareto curve
endpoints_x = []
endpoints_y = []
# Determine values for the name, title, and axes of the multi-trajectory plot
err_type, _, _, pop_err_type = bonus_plot_list[0][plot_index]
names.append(f'Multi_Trajectory_Bonus_Plot_for_'
f'{err_type if err_type != "0/1 Loss" else "0-1 Loss"}_Group_Error')
loss_string = ''
if error_type in ['FP', 'FN']:
loss_string = f'{error_type} Loss'
elif error_type.endswith('Log-Loss'):
loss_string = error_type
elif error_type == 'Total':
loss_string = f'0/1 Loss'
# Rename 'total' error to 0/1 Loss for plotting
err_string = err_type
if err_type == 'Total':
err_string = f'0/1 Loss'
pop_err_string = pop_err_type
if pop_err_type == 'Total':
pop_err_string = f'0/1 Loss'
validation_string = '' if test_size == 0.0 else f'(Validation: {test_size})'
plt.title(f'Trajectories over {numsteps} Rounds{dataset_string}' + validation_string +
f'\n {model_type} weighted on ' + loss_string)
plt.xlabel(f'Pop Error ({pop_err_string})')
plt.ylabel(f'Max Group Error ({err_string})')
# Plot the trajectories for the 'plot_index'-th error type over all gammas
for single_run_bonus_plot_tuples, gamma, total_steps in zip(bonus_plot_list, gammas, total_steps_per_gamma):
err_type, grp_errs, pop_errs, pop_err_type = single_run_bonus_plot_tuples[plot_index]
x = pop_errs
y = np.max(grp_errs, axis=1)
plt.scatter(x, y, c=np.arange(1, total_steps), s=2) # Plot the individual trajectory
plt.scatter(x[0], y[0], c='m', s=20) # Add magenta starting point
plt.annotate(f'gamma={gamma:.5f}', xy=(x[-1], y[-1]))
# Add the endpoints for the pareto curve
endpoints_x.append(x[-1])
endpoints_y.append(y[-1])
# Compute and plot pareto curve
pareto = get_pareto(endpoints_x, endpoints_y)
if pareto is not None:
plt.plot(pareto[:, 0], pareto[:, 1], 'r--', lw=2, label='Pareto Curve', alpha=0.5)
plt.show()
return figures, names