def plot_trajectories_from_bonus_plot_data()

in src/plot_relaxed_pareto.py [0:0]


def plot_trajectories_from_bonus_plot_data(bonus_plot_list, gammas, model_type, error_type, numsteps,
                                           total_steps_per_gamma,
                                           use_input_commands, test_size=0.0, data_name=''):
    """
    For each bonus-plot type, draw every gamma run's (pop error, max group error) trajectory on one figure,
    together with the Pareto curve through the trajectory endpoints.

    :param bonus_plot_list: List (over gammas) of lists of 4-tuples ``(err_type, grp_errs, pop_errs, pop_err_type)``,
        one tuple per bonus plot of that run. ``grp_errs`` is reduced with ``np.max(..., axis=1)``, so it must be
        2-D with one row per trajectory point (presumably one column per group -- confirm against the producer).
    :param gammas: list of gammas corresponding to each round (same order as ``bonus_plot_list``)
    :param model_type: model name, shown in the plot title
    :param error_type: error the model was weighted on ('FP', 'FN', 'Total', or a '...Log-Loss' variant);
        shown in the title. Unrecognised values are shown verbatim.
    :param numsteps: number of rounds, shown in the title
    :param total_steps_per_gamma: per-gamma step counts. Only bounds the iteration over runs (via ``zip``);
        point colours are derived from each trajectory's actual length.
    :param use_input_commands: if True, wait for ``input()`` before drawing each figure
    :param test_size: validation fraction; when non-zero it is appended to the title
    :param data_name: dataset name, shown in the title with its first letter capitalised; '' omits it
    :return figures, names: list of figures and their names
    """
    figures = []
    names = []

    # Set the first letter to capital if it isn't
    dataset_string = f' on {data_name[0].upper() + data_name[1:]}' if data_name else ''

    try:
        num_bonus_plots = len(bonus_plot_list[0])  # Number of 4-tuples (bonus plots) per value of gamma
    except (IndexError, KeyError, TypeError):
        # Empty or malformed input (e.g. [] or None): warn and produce no figures instead of crashing.
        print(bonus_plot_list)
        warnings.warn('WARNING: Could not index into bonus plots. Skipping and continuing...')
        num_bonus_plots = 0

    # Label for the loss the model was weighted on (identical for every figure, so computed once)
    if error_type in ('FP', 'FN'):
        loss_string = f'{error_type} Loss'
    elif error_type == 'Total':
        loss_string = '0/1 Loss'
    else:
        # '...Log-Loss' variants are shown verbatim; unknown types fall back to their raw name
        loss_string = error_type

    # Leading space so the suffix does not run into the dataset name in the title
    validation_string = '' if test_size == 0.0 else f' (Validation: {test_size})'

    # Iterate over the number of bonus plots per individual run
    for plot_index in range(num_bonus_plots):

        if use_input_commands:
            input('Next bonus plot')

        figures.append(plt.figure())  # One figure for 'type' of multi trajectory plot

        # Name and axes are taken from the first run; every run is assumed to list the
        # same bonus-plot types in the same order.
        err_type, _, _, pop_err_type = bonus_plot_list[0][plot_index]
        names.append(f'Multi_Trajectory_Bonus_Plot_for_'
                     f'{err_type if err_type != "0/1 Loss" else "0-1 Loss"}_Group_Error')

        # Rename 'Total' error to 0/1 Loss for plotting
        err_string = '0/1 Loss' if err_type == 'Total' else err_type
        pop_err_string = '0/1 Loss' if pop_err_type == 'Total' else pop_err_type

        plt.title(f'Trajectories over {numsteps} Rounds{dataset_string}' + validation_string +
                  f'\n {model_type} weighted on ' + loss_string)
        plt.xlabel(f'Pop Error ({pop_err_string})')
        plt.ylabel(f'Max Group Error ({err_string})')

        # Endpoints of the trajectories, used for the Pareto curve
        endpoints_x = []
        endpoints_y = []

        # Plot the trajectories for the 'plot_index'-th error type over all gammas
        for single_run_bonus_plot_tuples, gamma, _total_steps in zip(bonus_plot_list, gammas,
                                                                      total_steps_per_gamma):
            _, grp_errs, pop_errs, _ = single_run_bonus_plot_tuples[plot_index]
            if len(pop_errs) == 0:
                # Nothing to draw and no endpoint to contribute; x[0] / x[-1] would fail below.
                warnings.warn(f'WARNING: Empty trajectory for gamma={gamma}. Skipping it...')
                continue
            x = pop_errs
            y = np.max(grp_errs, axis=1)
            # Colour points by step index. Sizing from the trajectory itself guarantees that
            # len(c) == len(x); the colours match the old np.arange(1, total_steps) whenever that was valid.
            plt.scatter(x, y, c=np.arange(1, len(x) + 1), s=2)  # Plot the individual trajectory
            plt.scatter(x[0], y[0], c='m', s=20)  # Add magenta starting point
            plt.annotate(f'gamma={gamma:.5f}', xy=(x[-1], y[-1]))
            # Add the endpoints for the pareto curve
            endpoints_x.append(x[-1])
            endpoints_y.append(y[-1])

        # Compute and plot pareto curve (only if at least one trajectory was drawn)
        if endpoints_x:
            pareto = get_pareto(endpoints_x, endpoints_y)
            if pareto is not None:
                plt.plot(pareto[:, 0], pareto[:, 1], 'r--', lw=2, label='Pareto Curve', alpha=0.5)
        plt.show()

    return figures, names