def do_learning()

in src/minmaxML.py


def do_learning(X, y, numsteps, grouplabels, a=1, b=0.5, equal_error=False, scale_eta_by_label_range=True,
                gamma=0.0, relaxed=False, rescale_features=True,
                model_type='LinearRegression', error_type='Total',
                extra_error_types=set(), pop_error_type='Total',
                convergence_threshold=1e-15,
                max_logi_iters=100, tol=1e-8, fit_intercept=True, logistic_solver='lbfgs', penalty='none', C=1e15,
                lr=0.01, momentum=0.9, weight_decay=0, n_epochs=10000, hidden_sizes=(2 / 3,),
                test_size=0.0, random_split_seed=0,
                group_names=(), group_types=(), data_name='',
                display_plots=True, verbose=False, use_input_commands=True,
                show_legend=True,
                save_models=False, save_plots=False, dirname=''):
    """
    :param X:  numpy matrix of features with dimensions numsamples x numdims
    :param y:  numpy array of labels with length numsamples. Should be numeric (0/1 labels for binary classification)
    :param numsteps:  number of rounds to run the game
    :param a, b:  parameters for eta = a * t ^ (-b)
    :param scale_eta_by_label_range: whether the input a value should be scaled by the max absolute label value squared
    :param equal_error: whether to push group errors toward the mean population error instead of minimizing the max
    :param rescale_features: whether feature values should be rescaled for numerical stability
    :param grouplabels:  numpy array of numsamples numbers in the range [0, numgroups) denoting group membership
    :param group_names:  list of group names in relation to underlying data (e.g. [male, female])
    :param group_types:  list of names describing each type of group partition (e.g. [sex, race])
    :param data_name:  name of the dataset being used, to make plots clear
    :param gamma: maximum allowed max group error at convergence
    :param relaxed: denotes whether we are solving the relaxed version of the problem
    :param model_type:  sklearn model type e.g. LinearRegression, LogisticRegression, etc.
    :param error_type:  for classification only! e.g. 0/1 Loss, FP, FN, Log-Loss
    :param extra_error_types: set of error types which we want to plot
    :param pop_error_type: error type to report for the population, e.g. 0/1 Loss when reweighting on FP/FN
    :param convergence_threshold: converges (early) when max change in sampleweights < convergence_threshold
    :param penalty: Regularization penalty for logistic regression
    :param C: inverse of regularization strength
    :param logistic_solver: Which underlying solver to use for logistic regression
    :param fit_intercept: Whether or not we should fit an additional intercept
    :param random_split_seed: the random state to perform the train test split on
    :param display_plots: denotes if plots should be displayed
    :param verbose: denotes if progress and parameter information should be printed
    :param use_input_commands: denotes if the program should pause for user input before closing plots
    :param show_legend: denotes if plots should have legends with group names
    :param save_models: denotes if models should be saved in each round (needed to extract mixtures)
    :param save_plots: determines if plots should be saved to a file
    :param dirname: name of directory to save plots/models in, if applicable (sub directory of s3 bucket, if applicable)
    :param test_size: if nonzero, the proportion of the data to be reserved for validation
    :param max_logi_iters: max number of logistic regression iterations
    :param tol: tolerance of convergence for logistic regression
    :param lr: learning rate of gradient descent for MLP
    :param momentum: momentum of gradient descent for MLP
    :param weight_decay: weight decay (L2 penalty) of gradient descent for MLP
    :param n_epochs: number of epochs per individual MLP model
    :param hidden_sizes: list of sizes for hidden layers of MLP - fractions (and 1) treated as proportions of numdims
    """

    if not use_input_commands and display_plots:
        warnings.warn('WARNING: use_input_commands is set to False. '
                      'This may cause plots to appear and immediately disappear when running code from the command '
                      'line.')

    if save_models:
        warnings.warn('WARNING: save_models is set to True. This will keep every intermediate model '
                      '(as a python object) and write them all to a file, which may be memory intensive, '
                      'slowing down processing.')

    if relaxed and gamma == 0.0:
        warnings.warn('WARNING: Running a relaxed simulation with gamma = 0.0 which will likely be infeasible.'
                      '\nTo run an unrelaxed simulation, please set the relaxed flag to False.')

    if not (relaxed or equal_error) and error_type in ['FP', 'FN', 'FP-Log-Loss', 'FN-Log-Loss']:
        label_type = 'negative' if error_type.startswith('FP') else 'positive'
        warnings.warn(f'WARNING: Running an unconstrained simulation with {error_type} error type. \n'
                      f'In this setting, the minimax solution w.r.t. {error_type} is to always predict '
                      f'{label_type} labels, regardless of input.')

    if error_type.endswith('Log-Loss') and model_type not in ['LogisticRegression', 'MLPClassifier']:
        raise Exception('ERROR: Log-Loss error types are only supported for LogisticRegression and MLPClassifier.')

    if equal_error and model_type != 'PairedRegressionClassifier':
        warnings.warn('WARNING: Equal error rates are only supported for PairedRegressionClassifier '
                        f'due to negative weights. When using {model_type}, sample weights '
                        f'may be shifted upwards to avoid negative weights, changing the nature of the solution.')

    if equal_error and relaxed:
        raise Exception('Equal error is not supported for the relaxed algorithm.')

    # Rescales features to be within [-100, 100] for numerical stability
    if rescale_features:
        X = rescale_feature_matrix(X)

    # Divide eta (via scaling a) by the max label value squared. Equivalent to linearly scaling labels to range [-1, 1]
    if scale_eta_by_label_range:
        a /= max(abs(y)) ** 2
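        # e.g. labels ranging over [-10, 10] divide a by 100, matching the eta that labels in [-1, 1] would induce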

    # Increment numsteps because rounds are 1-indexed: the loop below runs rounds 1..numsteps
    numsteps += 1

    X = X.astype(np.float64)

    # Put our grouplabels list into two dimensions if only provided in one dimension
    if len(grouplabels.shape) < 2:
        n_samples = len(grouplabels)
        grouplabels = np.expand_dims(grouplabels, axis=0)
        group_names = np.expand_dims(group_names, axis=0)
        assert np.size(grouplabels[0]) == n_samples

    # Number of group partition types (greater than 1 when each person belongs to multiple kinds of groups)
    num_group_types = grouplabels.shape[0]

    # Array of group counts ('numgroups'), one entry per group type
    numgroups = np.array([np.size(np.unique(grouplabels[i])) for i in range(num_group_types)])
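    # e.g. two group types with labels [[0, 1, 0, 1], [2, 0, 1, 1]] give numgroups = [2, 3]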

    # Modularizes the existing code to easily swap models with the argument "model_type"
    model_classes = {'LinearRegression': LinearRegression, 'LogisticRegression': LogisticRegression,
                     'PairedRegressionClassifier': PairedRegressionClassifier, 'Perceptron': Perceptron,
                     'MLPClassifier': MLPClassifier}

    regression_models = ['LinearRegression', 'MLPRegressor']
    classification_models = ['LogisticRegression', 'Perceptron', 'PairedRegressionClassifier',
                             'MLPClassifier']
    if model_type not in model_classes:
        raise Exception(f'Invalid model_type: {model_type}.')
    model_class = model_classes[model_type]
    if model_type in classification_models:
        if set(np.unique(y)) == {-1, 1}:  # Converts -1/1 labels into 0/1 labels
            y = (y >= 1)
        y = y.astype(int)  # Convert boolean array to 0-1 array to avoid errors
        if set(np.unique(y)) != {0, 1}:  # Ensure all labels are as expected
            raise ValueError('Binary input labels y for classification must be encoded'
                             ' as 0/1, -1/1, or True/False')

    do_validation = (test_size != 0.0)  # Stores a boolean flag on whether or not we are doing validation
    if do_validation:
        # Use our custom function to create a balanced train/test split across group membership
        # NOTE: If num_group_types > 1, this function will do a purely random split
        X_train, X_test, y_train, y_test, grouplabels_train, grouplabels_test = \
            create_validation_split(X, y, grouplabels, test_size, random_seed=random_split_seed)
    else:
        # If we aren't doing a split, all data is used as "training" data
        X_train, y_train, grouplabels_train = X, y, grouplabels
        X_test, y_test, grouplabels_test = None, None, None

    # Compute basic shape information about the data
    numsamples, numdims = X_train.shape
    if do_validation:
        val_numsamples, _ = X_test.shape

    # Set up arrays storing the indices of each individual group for ease of use later
    groupsize = [np.array([]) for _ in range(num_group_types)]
    index = [[] for _ in range(num_group_types)]  # List of lists of slices
    for i in range(num_group_types):
        # index[i] is a "list" of length numgroups[i] whose elements are slices for (np.where label == g)
        groupsize[i], index[i] = create_index_array(numgroups[i], grouplabels_train[i])
    # Repeat the above for validation
    if do_validation:
        val_groupsize = [np.array([]) for _ in range(num_group_types)]
        val_index = [[] for _ in range(num_group_types)]
        for i in range(num_group_types):
            val_groupsize[i], val_index[i] = create_index_array(numgroups[i], grouplabels_test[i])

    # Instantiate all error arrays
    errors = np.zeros((numsteps, numsamples))  # Stores error for each member of pop for each round
    # Store errors for each group over rounds, both for the individual model and the aggregate mixture
    grouperrs, agg_grouperrs = create_group_error_arrays(num_group_types, numsteps, numgroups)
    if do_validation:
        val_errors = np.zeros((numsteps, val_numsamples))
        val_grouperrs, val_agg_grouperrs = create_group_error_arrays(num_group_types, numsteps, numgroups)

    # In the case that total error is not the same as the specific error (e.g. FP, FN) for classification, we store both
    if model_type in classification_models:
        groupsize_pos = [np.array([]) for _ in range(num_group_types)]
        groupsize_neg = [np.array([]) for _ in range(num_group_types)]
        # Index is a list over group types where each element is a list of numpy slices, one per subgroup
        # Used for reweighting samples based on group weights for each group they are a member of
        index_pos = [[] for _ in range(num_group_types)]
        index_neg = [[] for _ in range(num_group_types)]
        for i in range(num_group_types):
            # Compute the subgroups for positive and negative classes
            groupsize_pos[i], groupsize_neg[i], index_pos[i], index_neg[i] = \
                setup_pos_neg_group_arrays(numgroups[i], index[i], y_train)

        if do_validation:  # Repeat the above for validation
            val_groupsize_pos = [np.array([]) for _ in range(num_group_types)]
            val_groupsize_neg = [np.array([]) for _ in range(num_group_types)]
            for i in range(num_group_types):
                val_groupsize_pos[i], val_groupsize_neg[i], _, _ = \
                    setup_pos_neg_group_arrays(numgroups[i], val_index[i], y_test)

        if error_type not in ['0/1 Loss', 'Log-Loss', 'FP', 'FN', 'FP-Log-Loss', 'FN-Log-Loss']:
            raise Exception(f"ERROR: Unsupported classification error type: {error_type}")
    else:
        # NOTE: Currently, the only supported error type for regression is MSE.
        if error_type != 'MSE':
            warnings.warn(f'WARNING: Error type {error_type} is not usable in regression settings. '
                          f'Automatically changing error type to `MSE` and continuing...')
        error_type = pop_error_type = 'MSE'  # Rename the 'total' error type to MSE in regression case

        if equal_error:
            warnings.warn('WARNING: Equal error is not supported for regression. Returning minmax solution instead.')
            equal_error = False

    # Dictionary of arrays storing the errors of each type, can use other functions to compute over rounds
    # Instantiate dictionaries with the error type we are reweighting on
    specific_errors = {error_type: errors}
    if do_validation:
        val_specific_errors = {error_type: val_errors}

    # An empty literal {} is a dict, not a set; converting it lets main_drivers pass set literals safely
    if extra_error_types == {}:
        extra_error_types = set()

    # Ensure we do not duplicate/overwrite the main error type as an extra error type
    if error_type in extra_error_types:
        extra_error_types.remove(error_type)

    if model_type in regression_models:  # If we have a regression model, no additional error types make sense
        extra_error_types = set()
    elif error_type in ['FP', 'FN']:  # If we are in FP/FN setting, then we always need to use Total error for pop
        extra_error_types.add('0/1 Loss')
        pop_error_type = '0/1 Loss'
    elif error_type in ['FP-Log-Loss', 'FN-Log-Loss']:
        # extra_error_types.add('Log-Loss')
        extra_error_types.add('0/1 Loss')
        extra_error_types.add(error_type.split('-')[0])  # Adds FP/FN for FP-Log-Loss/FN-Log-Loss
        pop_error_type = '0/1 Loss'

    # If pop_error_type is unspecified and not caught in the above cases, let it be the regular error type
    if pop_error_type == '':
        pop_error_type = error_type

    # Create a new array to store the errors of each type
    for extra_err_type in extra_error_types:
        specific_errors[extra_err_type] = np.zeros((numsteps, numsamples))
        if do_validation:
            val_specific_errors[extra_err_type] = np.zeros((numsteps, val_numsamples))

    # Assign the correct groupsize to the error type as a separate array for reweighting purposes
    if error_type in ['0/1 Loss', 'MSE', 'Log-Loss']:
        groupsize_err_type = groupsize
        if do_validation:
            val_groupsize_err_type = val_groupsize
    elif error_type.startswith('FP'):
        groupsize_err_type = groupsize_neg
        if do_validation:
            val_groupsize_err_type = val_groupsize_neg
    elif error_type.startswith('FN'):
        groupsize_err_type = groupsize_pos
        if do_validation:
            val_groupsize_err_type = val_groupsize_pos
    else:
        raise Exception(f'Invalid Error Type {error_type}')
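    # e.g. with error_type='FP', group error rates are computed over only the negatively-labeled members of each group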

    if verbose:
        # print(f'Group labels are: {grouplabels}')
        print(f'Group names are:   {group_names}')
        print(f'Group types are:   {group_types}')
        if do_validation:
            print(f'Group sizes (train): {groupsize}')
            print(f'Group sizes (val):   {val_groupsize}')
        else:
            print(f'Group sizes are: {groupsize}')

    # Initialize sample weights and group weights for the Regulator
    groupweights = [np.zeros((numsteps, numgroups[i])) for i in range(num_group_types)]
    p = [np.array([]) for _ in range(num_group_types)]
    sampleweights = [np.array([]) for _ in range(num_group_types)]
    prev_avg_sampleweights = np.zeros(numsamples)  # Will store previous round sampleweights for convergence detection
    # Fill the weight arrays as necessary
    for i in range(num_group_types):
        p[i] = groupsize[i] / numsamples  # Compute each group's proportion of the population
        groupweights[i][0] = p[i]
        sampleweights[i] = np.ones(numsamples) / numsamples  # Initialize sample weights array to uniform
    # Convert sampleweights to a numpy array since it's rectangular
    sampleweights = np.array(sampleweights)
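    # sampleweights has shape (num_group_types, numsamples); each row starts as the uniform distribution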

    # Instantiate lambdas for relaxed or equal-error runs
    if relaxed or equal_error:
        lambdas = [np.zeros((numsteps, numgroups[i])) for i in range(num_group_types)]

    # List for storing the model produced at every round if applicable
    modelhats = []

    if verbose:
        print(f'Starting simulation with the following parameters: \n' +
              f'model_type: {model_type} \n' +
              f'numsamples: {numsamples} \n' +
              f'numdims: {numdims} \n' +
              f'numgroups: {numgroups} \n' +
              f'numsteps: {numsteps - 1} \n' +
              f'a: {a} \n' +
              f'b: {b} \n')
        if model_type == 'LogisticRegression':
            print('fit_intercept:', fit_intercept)
            print('solver:', logistic_solver)
            print('max_iterations:', max_logi_iters)
            print('tol:', tol)

    # Will store the total number of steps that actually occur. Updated to less than numsteps for early convergence
    total_steps = numsteps
    # Simulating game steps
    for t in range(1, numsteps):

        # Average the per-group-type sample weights into the single weight vector the Learner trains on
        avg_sampleweights = np.squeeze(np.sum(sampleweights, axis=0) / num_group_types)

        # Converge if max change in sampleweights is less than convergence threshold
        if t > 3 and max(abs(avg_sampleweights - prev_avg_sampleweights)) < convergence_threshold:
            total_steps = t
            print(f'Converging early at round {total_steps}')
            print(f'Max change in sample weights: {max(abs(avg_sampleweights - prev_avg_sampleweights))}')
            break

        if t % max(1, (numsteps // 50)) == 0 and verbose:
            print(f'starting round {t}...')

        # Set the eta value for this round
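        # (with a=1, b=0.5 this decays as 1/sqrt(t): eta = 1.0 at t=1 and 0.1 at t=100)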
        eta = a * t ** (-b)

        # Learner best responds to current weight by training a model on weighted sample points
        if model_type == 'LogisticRegression':
            with warnings.catch_warnings():
                warnings.filterwarnings('error', category=ConvergenceWarning)  # Promote convergence warnings to errors
                try:
                    modelhat = \
                        model_class(max_iter=max_logi_iters, tol=tol, fit_intercept=fit_intercept,
                                    solver=logistic_solver, penalty=penalty, C=C,
                                    verbose=0).fit(X_train, y_train, avg_sampleweights)
                except Warning:
                    raise Exception(f'Logistic regression did not converge with {max_logi_iters} iterations.')
        elif model_type == 'PairedRegressionClassifier':
            # NOTE: This is not an sklearn model_class, but a custom class
            modelhat = model_class(regressor_class=LinearRegression).fit(X_train, y_train, avg_sampleweights)
        elif model_type == 'MLPClassifier':  # Pytorch's MLP wrapped with our custom class to work with the interface
            # Resolve fractional hidden sizes without mutating hidden_sizes, which would compound across rounds
            mlp_hidden_sizes = [numdims] + \
                               [x if np.floor(x) == x else int(np.floor(x * numdims)) for x in hidden_sizes]
            modelhat = MLPClassifier(mlp_hidden_sizes, lr=lr, momentum=momentum, weight_decay=weight_decay). \
                fit(X_train, y_train, avg_sampleweights, n_epochs=n_epochs)
        else:  # Linear Regression or Perceptron
            modelhat = model_class().fit(X_train, y_train, sample_weight=avg_sampleweights)

        # Save the current sample weights for convergence detection in the next round
        prev_avg_sampleweights = avg_sampleweights.copy()

        # Store each round's model as necessary (written to the filesystem/S3 at the end when save_models is set)
        if save_models:
            modelhats.append(modelhat)  # Append the model as a python object to the list

        # Compute the errors of the model according to the specified loss function

        if model_type in regression_models:
            # Updates errors array with the round-specific errors for each person for round t
            compute_model_errors(modelhat, X_train, y_train, t, errors, 'MSE')
            if do_validation:
                compute_model_errors(modelhat, X_test, y_test, t, val_errors, 'MSE')
            # NOTE: Currently, there are no "extra_error_types" feasible for regression

        elif model_type in classification_models:
            # Updates errors array with the round-specific errors for each person for round t
            compute_model_errors(modelhat, X_train, y_train, t, errors, error_type, penalty, C)
            # Compute the errors for all additional error types
            for err_type in extra_error_types:
                compute_model_errors(modelhat, X_train, y_train, t, specific_errors[err_type], err_type, penalty, C)
            # Repeat for validation
            if do_validation:
                compute_model_errors(modelhat, X_test, y_test, t, val_errors, error_type, penalty, C)
                for err_type in extra_error_types:
                    compute_model_errors(modelhat, X_test, y_test, t, val_specific_errors[err_type], err_type,
                                         penalty, C)
        else:
            raise Exception(f'Invalid Model Type: {model_type}')

        # Compute group error rates for each group this round, across each group type
        for i in range(num_group_types):
            update_group_errors(numgroups[i], t, errors, grouperrs[i], agg_grouperrs[i], index[i],
                                groupsize_err_type[i])
            if do_validation:
                update_group_errors(numgroups[i], t, val_errors, val_grouperrs[i], val_agg_grouperrs[i],
                                    val_index[i], val_groupsize_err_type[i])

            # Weight update type depends on relaxed or not
            if relaxed:  # Projected Gradient descent
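                # Gradient ascent on the dual variables lambda for the constraints "group error <= gamma",
                # projected onto lambda >= 0; each group's weight is its base proportion plus its lambda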
                lambdas[i][t] = np.maximum(0, lambdas[i][t - 1] + (eta * (grouperrs[i][t] - gamma)))
                groupweights[i][t] = p[i] + lambdas[i][t]
            else:  # Non-relaxed
                if equal_error:  # GD where errors are pushed to mean error
                    poperrs = errors[t]  # if model_type in regression_models else specific_errors['Total'][t]
                    mean_poperrs = sum(poperrs) / numsamples
                    lambdas[i][t] = lambdas[i][t - 1] + (eta * (grouperrs[i][t] - mean_poperrs))
                    groupweights[i][t] = p[i] * (1 - np.sum(lambdas[i][t])) + lambdas[i][t]
                else:  # Exponential Weights, Minmax algorithm
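                    # Multiplicative weights: groups with higher error this round gain exponentially
                    # more weight, and the weights are then renormalized into a distribution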
                    groupweights[i][t] = np.multiply(groupweights[i][t - 1], np.exp(eta * (grouperrs[i][t])))
                    groupweights[i][t] = groupweights[i][t] / np.sum(groupweights[i][t])  # normalize

            # Translate group weights to sample weights for the Learner
            for g in range(0, numgroups[i]):
                if (relaxed or equal_error) and error_type in ['FP', 'FN', 'FP-Log-Loss', 'FN-Log-Loss']:
                    baseline_weight = 1 / numsamples  # if not use_obj_as_constraint else lambda_obj[i] / numsamples
                    weight_per_neg_sample = baseline_weight
                    weight_per_pos_sample = baseline_weight
                    if error_type.startswith('FP'):
                        # When we make FP errors, we are making mistakes on negative sample points, so we upweight
                        if groupsize_neg[i][g] > 0:
                            weight_per_neg_sample = baseline_weight + lambdas[i][t, g] / groupsize_neg[i][g]
                    else:  # error_type starts with 'FN'
                        if groupsize_pos[i][g] > 0:
                            weight_per_pos_sample = baseline_weight + lambdas[i][t, g] / groupsize_pos[i][g]
                    # Set the sample weights of the subgroups
                    sampleweights[i, index_neg[i][g]] = weight_per_neg_sample
                    sampleweights[i, index_pos[i][g]] = weight_per_pos_sample
                # Error type = total, or not relaxed
                else:
                    weight_per_sample = groupweights[i][t, g] / groupsize[i][g]
                    sampleweights[i, index[i][g]] = weight_per_sample
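                    # Splitting a group's weight evenly among its members preserves the group's total weight mass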

            # If negative sampleweights (with equal error), shift all weights up by the min weight to ensure positivity
            if equal_error and model_type != 'PairedRegressionClassifier':
                weight_to_add = -np.min(sampleweights)
                if weight_to_add > 0:
                    sampleweights[i] = sampleweights[i] + weight_to_add  # np.maximum(0, -np.min(sampleweights))

    # Game is finished here

    # Truncate the group error arrays to have length equal to the number of rounds actually performed
    # Remove 0th position in the arrays which stored the value 0 for easy DP
    agg_grouperrs = [arr[1:total_steps, :] for arr in agg_grouperrs]
    if do_validation:
        val_agg_grouperrs = [arr[1:total_steps, :] for arr in val_agg_grouperrs]

    # Computes the expected error of the mixture with respect to the population with DP style updates at each round
    agg_poperrs = compute_mixture_pop_errors(specific_errors[pop_error_type], total_steps)
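    # agg_poperrs[t] is the population error of the (uniform) mixture of the models from rounds 1..t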
    if do_validation:
        val_agg_poperrs = compute_mixture_pop_errors(val_specific_errors[pop_error_type], total_steps)

    # Plot and save results as necessary
    if display_plots or save_plots:
        loss_string = ''
        if error_type in ['FP', 'FN']:
            loss_string = f' weighted on {error_type} Loss'
        elif error_type.endswith('Log-Loss'):
            loss_string = f' weighted on {error_type}'
        elif error_type == '0/1 Loss':
            if model_type in classification_models:
                loss_string = ' weighted on 0/1 Loss'

        model_string = f'\n {model_type}' + loss_string + (' for Equal-Error' if equal_error else '') + \
                       (f'\n Gamma={gamma}' if relaxed else '')

        # Combine groups names with their sizes for display when plotting
        group_names_and_sizes_list = get_group_names_and_sizes_list(group_names, groupsize, num_group_types)
        # Generate bonus plots as necessary
        stacked_bonus_plots = create_stacked_bonus_plots(num_group_types, extra_error_types, numgroups,
                                                         specific_errors, index, groupsize, total_steps)
        # Do a final combined plot of everything
        do_plotting(display_plots, save_plots, use_input_commands, total_steps, group_names_and_sizes_list,
                    group_types,
                    show_legend, error_type, data_name, model_string,
                    agg_poperrs, agg_grouperrs, groupweights,
                    pop_error_type, stacked_bonus_plots,
                    dirname,
                    multi_group=True)

        # Repeat for validation as necessary
        if do_validation:
            val_group_names_and_sizes_list = get_group_names_and_sizes_list(group_names, val_groupsize,
                                                                            num_group_types)
            val_stacked_bonus_plots = create_stacked_bonus_plots(num_group_types, extra_error_types, numgroups,
                                                                 val_specific_errors, val_index, val_groupsize,
                                                                 total_steps)
            do_plotting(display_plots, save_plots, use_input_commands, total_steps, val_group_names_and_sizes_list,
                        group_types,
                        show_legend, error_type, data_name, model_string + f'\n Validation with test_size={test_size}',
                        val_agg_poperrs, val_agg_grouperrs, None,
                        pop_error_type, val_stacked_bonus_plots,
                        dirname, validation=True, multi_group=True)
    else:  # Ensures that return doesn't fail when we aren't plotting
        stacked_bonus_plots = None
        if do_validation:
            val_stacked_bonus_plots = None

    # Save models as pythonic objects to either filesystem/S3 bucket
    if save_models:
        save_models_to_os(modelhats, dirname)

    final_max_group_error = [-1 for _ in range(num_group_types)]
    highest_gamma = [-1 for _ in range(num_group_types)]
    for i in range(num_group_types):
        final_max_group_error[i] = np.max(agg_grouperrs[i][-1])  # This is the minimum gamma we think is feasible
        highest_gamma[i] = compute_highest_gamma(agg_poperrs, agg_grouperrs[i], relaxed)

    if do_validation:
        val_max_grp_err = max(np.max(arr[-1]) for arr in val_agg_grouperrs)  # Minimum gamma we think is feasible
        val_pop_err = val_agg_poperrs[-1]  # Final population error of the mixture on validation data

        # For now, we are taking the max over all types of groups
        return (max(final_max_group_error), max(highest_gamma), agg_poperrs[1], agg_grouperrs, agg_poperrs,
                stacked_bonus_plots, pop_error_type, total_steps, modelhats,
                val_max_grp_err, val_pop_err, val_agg_grouperrs, val_agg_poperrs, val_stacked_bonus_plots)
    else:
        if relaxed:
            margin_of_error = 0.001  # Allows for a small margin of error for 'infeasible' gammas
            if gamma + margin_of_error < max(final_max_group_error):
                warnings.warn(f'WARNING: Desired gamma value may not be feasible with margin of error: '
                              f'{margin_of_error}. \n'
                              f'Gamma = {gamma} but the mixture\'s max group error was {final_max_group_error}')

        return (max(final_max_group_error), max(highest_gamma), agg_poperrs[1], agg_grouperrs, agg_poperrs,
                stacked_bonus_plots, pop_error_type, total_steps, modelhats,
                None, None, None, None, None)
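
# Example invocation (illustrative sketch, not part of the original module; assumes a synthetic
# dataset and that this file's dependencies are importable):
#
#   import numpy as np
#   rng = np.random.default_rng(0)
#   X = rng.normal(size=(200, 5))
#   y = (X[:, 0] + rng.normal(scale=0.1, size=200) > 0).astype(int)
#   grouplabels = rng.integers(0, 2, size=200)  # one group type with two groups
#   results = do_learning(X, y, numsteps=500, grouplabels=grouplabels,
#                         model_type='LogisticRegression', error_type='0/1 Loss',
#                         pop_error_type='0/1 Loss', group_names=('group0', 'group1'),
#                         display_plots=False, use_input_commands=False)
#   max_group_error, best_gamma = results[0], results[1]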