def do_pareto_plot()

in src/plot_relaxed_pareto.py [0:0]


def do_pareto_plot(gammas, total_steps_per_gamma, max_grp_errs, pop_errs, trajectories, numsteps,
                   error_type, pop_error_type,
                   save_plots, dirname,
                   model_type,
                   use_input_commands,
                   data_name='', bonus_plot_list=None, show_basic_plots=False,
                   val_max_grp_errs=None, val_pop_errs=None, val_trajectories=None, val_bonus_plot_list=None,
                   test_size=0.0):
    """
    Utility function used in main_driver to create a multi-trajectory plot over runs with a range of gamma values,
    and traces the pareto curve of errors resulting mixture models.

    Use argument `show_basic_plots` to enable scatter plots for pairwise relationships between population error,
    max groups error, and gamma, of the final mixture models.
    """

    figures = []
    plt.ion()

    # Setup strings for graph titles
    dataset_string = f' on {data_name[0].upper() + data_name[1:]}' if data_name != '' else ''

    # Get the pareto curve
    pareto = get_pareto(pop_errs, max_grp_errs)

    # Set pop_error string
    pop_error_string = pop_error_type
    if pop_error_type == 'Total':
        pop_error_string = f'0/1 Loss'

    if show_basic_plots:
        if use_input_commands:
            input('Press `Enter` to display first plot...')
        figures.append(plt.figure())
        plt.scatter(pop_errs, max_grp_errs)
        plt.title(f'Pop Error vs. Max Group Error{dataset_string} \n {model_type}')
        plt.xlabel(f'Pop Error ({pop_error_string})')
        plt.ylabel(f'Max Group Error ({error_type})')
        # Compute and plot pareto curve
        if pareto is not None:
            plt.plot(pareto[:, 0], pareto[:, 1], 'r--', lw=2, label='Pareto Curve', alpha=0.5)
        plt.show()

        if use_input_commands:
            input('Next plot...')
        figures.append(plt.figure())
        plt.scatter(gammas, max_grp_errs)
        plt.title(f'Gamma vs. Max Group Error{dataset_string}   \n {model_type}')
        plt.xlabel('Gamma')
        plt.ylabel(f'Max Group Error ({error_type})')
        plt.show()

        if use_input_commands:
            input('Next plot...')
        figures.append(plt.figure())
        plt.scatter(gammas, pop_errs)
        plt.title(f'Gamma vs. Pop Error{dataset_string}   \n {model_type}')
        plt.xlabel('Gamma')
        plt.ylabel(f'Pop Error ({pop_error_string})')
        plt.show()

    # Multi-trajectory plot
    if use_input_commands:
        input('Next plot...')
    figures.append(plt.figure())
    colors = [np.arange(1, total_steps) for total_steps in total_steps_per_gamma]
    for (x, y), gamma, color in zip(trajectories, gammas, colors):
        plt.scatter(x, y, c=color, s=2)
        plt.scatter(x[0], y[0], c='m', s=20)
        plt.annotate(f'gamma={gamma:.5f}', xy=(x[-1], y[-1]))
    plt.title(f'Trajectories over {numsteps} Rounds{dataset_string}  \n {model_type}')
    plt.xlabel(f'Pop Error ({pop_error_string})')
    plt.ylabel(f'Max Group Error ({error_type})')
    # Add the pareto plot here as well
    if pareto is not None:
        plt.plot(pareto[:, 0], pareto[:, 1], 'r--', lw=2, label='Pareto Curve', alpha=0.5)
    plt.show()

    if show_basic_plots:
        figure_names = ['PopError_vs_MaxGroupError', 'Gamma_vs_MaxGroupError', 'Gamma_vs_PopError',
                        'Trajectories_over_Gammas']
    else:
        figure_names = ['Trajectories_over_Gamma']

    # Do the multi-trajectory plots for the additional error types
    colors = [np.arange(1, total_steps) for total_steps in total_steps_per_gamma]
    bonus_figures, bonus_names = \
        plot_trajectories_from_bonus_plot_data(bonus_plot_list, gammas, model_type, error_type, numsteps,
                                               total_steps_per_gamma,
                                               use_input_commands)
    figures.extend(bonus_figures)
    figure_names.extend(bonus_names)

    if val_max_grp_errs is not None and val_pop_errs is not None:

        val_pareto = get_pareto(val_pop_errs, val_max_grp_errs)

        if show_basic_plots:
            # Validation Pop Error vs. Max Group Error
            if use_input_commands:
                input('Click enter to display first validation plot')
            figures.append(plt.figure())
            plt.scatter(val_pop_errs, val_max_grp_errs)
            plt.title(f'Pop Error vs. Max Group Error{dataset_string} (Validation: {test_size}) \n {model_type}')
            plt.xlabel(f'Pop Error ({pop_error_string})')
            plt.ylabel(f'Max Group Error ({error_type})')
            # Compute and plot pareto curve
            if val_pareto is not None:
                plt.plot(val_pareto[:, 0], val_pareto[:, 1], 'r--', lw=2, label='Pareto Curve', alpha=0.5)
            plt.show()

            # Validation Gamma vs. Max Group Error
            if use_input_commands:
                input('Next plot...')
            figures.append(plt.figure())
            plt.scatter(gammas, val_max_grp_errs)
            plt.title(f'Gamma vs. Max Group Error{dataset_string} (Validation: {test_size}) \n {model_type}')
            plt.xlabel('Gamma')
            plt.ylabel(f'Max Group Error ({error_type})')
            plt.show()

            # Validation Gamma vs. Pop Error
            if use_input_commands:
                input('Next plot...')
            figures.append(plt.figure())
            plt.scatter(gammas, val_pop_errs)
            plt.title(f'Gamma vs. Pop Error{dataset_string} (Validation: {test_size}) \n {model_type}')
            plt.xlabel('Gamma')
            plt.ylabel(f'Pop Error ({pop_error_string})')
            plt.show()

        # Validation Trajectory
        if use_input_commands:
            input('Next plot...')
        figures.append(plt.figure())
        colors = [np.arange(1, total_steps) for total_steps in total_steps_per_gamma]
        for (x, y), gamma, color in zip(val_trajectories, gammas, colors):
            plt.scatter(x, y, c=color, s=2)
            plt.annotate(f'gamma={gamma:.5f}', xy=(x[-1], y[-1]))
        plt.title(f'Trajectories over {numsteps} Rounds{dataset_string}  (Validation: {test_size}) \n {model_type}')
        plt.xlabel(f'Pop Error ({pop_error_string})')
        plt.ylabel(f'Max Group Error ({error_type})')
        if val_pareto is not None:
            plt.plot(val_pareto[:, 0], val_pareto[:, 1], 'r--', lw=2, label='Pareto Curve', alpha=0.5)
        plt.show()

        if show_basic_plots:
            figure_names.extend(['PopError_vs_MaxGroupError_Validation', 'Gamma_vs_MaxGroupError_Validation',
                                 'Gamma_vs_PopError_Validation', 'Trajectories_over_Gammas_Validation'])
        else:
            figure_names.extend(['Trajectories_over_Gamma_Validation'])

        # colors = [np.arange(1, total_steps) for total_steps in total_steps_per_gamma]
        val_bonus_figures, val_bonus_names = \
            plot_trajectories_from_bonus_plot_data(val_bonus_plot_list, gammas, model_type, error_type,
                                                   numsteps, total_steps_per_gamma, use_input_commands,
                                                   test_size=test_size)
        figures.extend(val_bonus_figures)
        val_bonus_names = [name + '_Validation' for name in val_bonus_names]
        figure_names.extend(val_bonus_names)

    if use_input_commands:
        input('Quit')

    if save_plots:
        save_plots_to_os(figures, figure_names, dirname, True)
        plt.close('all')