def do_plotting()

in src/plotting.py [0:0]


def _file_safe(label):
    """Return *label* with '/' replaced by '-' so it can be embedded in a file name."""
    return label.replace('/', '-')


def _plot_group_series(series, group_names, show_legend, title, ylabel):
    """Create and return a figure with one line per column of the 2-D array *series*, labelled by *group_names*."""
    fig = plt.figure()
    for col, name in enumerate(group_names):
        plt.plot(series[:, col], label=name)
    if show_legend:
        plt.legend(loc='upper right')
    plt.title(title)
    plt.xlabel('Steps')
    plt.ylabel(ylabel)
    return fig


def _plot_trajectory(pop_errs, grp_errs, title, xlabel, ylabel):
    """Create and return a scatter figure of population error (x) vs. max group error (y), one point per step."""
    fig = plt.figure()
    x = pop_errs
    y = np.max(grp_errs, axis=1)
    # Colour points by (1-based) step index; sized from the data so `c` always matches `x`/`y`.
    colors = np.arange(1, len(x) + 1)
    plt.scatter(x, y, c=colors, s=2, label='Trajectory of Mixtures')
    if len(x) > 0:
        plt.scatter(x[0], y[0], c='m', s=40, label='Starting point')  # Make the first point big and pink
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    return fig


def do_plotting(display_plots, save_plots, use_input_commands, numsteps, group_names, group_type,
                show_legend, error_type, data_name, model_string,
                agg_poperrs, agg_grouperrs, groupweights, pop_error_type, bonus_plots,
                dirname, multi_group=False,
                validation=False, equal_error=False):
    """
    Helper function for minimaxML that creates the relevant plots for a single run of the simulation.

    Plots, in order: population error vs. steps, group errors vs. steps, group weights vs. steps
    (skipped when `validation` is true or `groupweights` is None), and a trajectory scatter of
    population error against max group error. For every entry of `bonus_plots` it then adds
    another group-error plot and trajectory plot.

    :param display_plots: if true, `plt.show()` is called after each figure is built.
    :param save_plots: if true, all figures are passed to `save_plots_to_os` together with their
        names and `dirname`, and then every figure is closed.
    :param use_input_commands: together with `display_plots`, pause for `Enter` between plots.
    :param numsteps: used in the trajectory title as "over numsteps - 1 rounds".
    :param group_names: one label per column of `agg_grouperrs`; with `multi_group`, one list of
        labels per group type.
    :param group_type: with `multi_group` and more than one group type, indexable labels used to
        prefix each group's name.
    :param show_legend: add a legend to the per-group line plots.
    :param error_type: error label for the group errors.
    :param data_name: dataset name, shown in titles with its first letter upper-cased.
    :param model_string: appended to every title.
    :param agg_poperrs: population error per step (one value per step).
    :param agg_grouperrs: 2-D array of group errors (steps x groups); with `multi_group`, a list of
        such arrays that are column-stacked.
    :param groupweights: group weights shaped like `agg_grouperrs`; may be None; unused in validation.
    :param pop_error_type: error label for the population error.
    :param bonus_plots: iterable of `(err_type, grp_errs, pop_errs, pop_err_type)` tuples.
    :param dirname: destination passed to `save_plots_to_os`.
    :param multi_group: whether `agg_grouperrs`/`groupweights`/`group_names` hold several group types.
    :param validation: if true, skip the weights plot and append '_Validation' to every figure name.
    :param equal_error: unused in this function; kept for interface compatibility.
    """
    # Combine all the existing arrays as necessary by treating every subgroup as a unique group
    if multi_group:
        num_group_types = len(agg_grouperrs)  # list of numpy arrays
        agg_grouperrs = np.column_stack(agg_grouperrs)  # horizontally stack the group errors
        if not validation and groupweights is not None:
            groupweights = np.column_stack(groupweights)  # horizontally stack the weights of each group
        stacked_group_names = []
        for i in range(num_group_types):
            # Only prefix with the group type when several types must be told apart.
            prefix = f'{group_type[i]}: ' if num_group_types > 1 else ''
            stacked_group_names.extend(prefix + name for name in group_names[i])
        group_names = stacked_group_names
    # End of multi-groups adjustments

    pause = use_input_commands and display_plots
    figures = []
    figure_names = []

    def register(fig, name):
        # Keep `figures` and `figure_names` index-aligned so each saved file gets the right name,
        # even when optional plots (e.g. group weights) are skipped.
        figures.append(fig)
        figure_names.append(name)
        if display_plots:
            plt.show()

    if pause:
        input("Press `Enter` to show first plot... ")

    # Setup strings for graph titles (slicing keeps an empty `data_name` from raising)
    dataset_string = f' on {data_name[:1].upper()}{data_name[1:]}'
    suffix = dataset_string + model_string

    plt.ion()

    # Average Pop error vs. Rounds
    fig = plt.figure()
    plt.plot(agg_poperrs)
    plt.title(f'Average Population Error ({pop_error_type})' + suffix)
    plt.xlabel('Steps')
    plt.ylabel(f'Average Population Error ({pop_error_type})')
    register(fig, 'PopError_vs_Rounds')

    if pause:
        input("Next plot...")

    # Group Errors vs. Rounds
    register(_plot_group_series(agg_grouperrs, group_names, show_legend,
                                f'Group Errors ({error_type})' + suffix,
                                f'Group Errors ({error_type})'),
             'GroupError_vs_Rounds')

    if pause:
        input("Next plot...")

    # Group Weights vs. Rounds (group weights aren't part of validation)
    if not validation and groupweights is not None:
        register(_plot_group_series(groupweights, group_names, show_legend,
                                    'Group Weights' + suffix, 'Group Weights'),
                 'GroupWeights_vs_Rounds')
        if pause:
            input("Next plot...")

    # Trajectory Plot
    trajectory_title = f'Trajectory over {numsteps - 1} rounds' + suffix
    register(_plot_trajectory(agg_poperrs, agg_grouperrs, trajectory_title,
                              f'Population Error ({pop_error_type})',
                              f'Max Group Error ({error_type})'),
             'Trajectory_Plot')

    for err_type, grp_errs, pop_errs, pop_err_type in bonus_plots:
        if pause:
            input(f"Next bonus plot for error type {err_type}...")

        safe_err_type = _file_safe(err_type)  # e.g. "0/1 Loss" -> "0-1 Loss" for file names
        register(_plot_group_series(grp_errs, group_names, show_legend,
                                    f'Group Errors ({err_type})' + suffix,
                                    f'Group Errors ({err_type})'),
                 f'GroupError_vs_Rounds_({safe_err_type})')

        if pause:
            input("Next bonus plot (trajectory)...")

        # x is population error (pop_err_type), y is max group error (err_type)
        register(_plot_trajectory(pop_errs, grp_errs, trajectory_title,
                                  f'Population Error ({pop_err_type})',
                                  f'Max Group Error ({err_type})'),
                 f'Trajectory_({safe_err_type})')

    if pause:
        input("Quit")

    # Update the names if doing validation
    if validation:
        figure_names = [name + '_Validation' for name in figure_names]

    # Now we have a list of plots: `figures` we can save
    if save_plots:
        save_plots_to_os(figures, figure_names, dirname)
        plt.close('all')