def distill()

in core/src/autogluon/core/trainer/abstract_trainer.py


    def distill(self, X=None, y=None, X_val=None, y_val=None, X_unlabeled=None,
                time_limit=None, hyperparameters=None, holdout_frac=None, verbosity=None,
                models_name_suffix=None, teacher=None, teacher_preds='soft',
                augmentation_data=None, augment_method='spunge', augment_args={'size_factor': 5, 'max_size': int(1e5)},
                augmented_sample_weight=1.0):
        """ Various distillation algorithms.
            Args:
                X, y: pd.DataFrame and pd.Series of training data.
                    If None, original training data used during predictor.fit() will be loaded.
                    This data is split into train/validation if X_val, y_val are None.
                X_val, y_val: pd.DataFrame and pd.Series of validation data.
                time_limit, hyperparameters, holdout_frac: defined as in predictor.fit()
                teacher (None or str):
                    If None, the model with the highest validation score is used as the teacher; otherwise, the named model is used.
                teacher_preds (None or str): If None, students are trained only with the original labels (no data augmentation; overrides augment_method).
                    If 'hard', labels are hard teacher predictions given by: teacher.predict()
                    If 'soft', labels are soft teacher predictions given by: teacher.predict_proba()
                    Note: 'hard' and 'soft' are equivalent for regression problems.
                    If augment_method is specified, teacher predictions are only used to label the augmented data (training data keeps its original labels).
                    To apply label-smoothing: teacher_preds='onehot' will use original training data labels converted to one-hots for multiclass (no data augmentation).  # TODO: expose smoothing-hyperparameter.
                models_name_suffix (str): Suffix to append to each student model's name; new names will look like: 'MODELNAME_DSTL_SUFFIX'
                augmentation_data: pd.DataFrame of additional data to use as "augmented data" (does not contain labels).
                    When specified, augment_method, augment_args are ignored, and this is the only augmented data that is used (teacher_preds cannot be None).
                augment_method (None or str): specifies which augmentation strategy to utilize. Options: [None, 'spunge', 'munge']
                    If None, no augmentation is applied.
                augment_args (dict): args passed into the augmentation function corresponding to augment_method.
                augmented_sample_weight (float): Nonnegative value indicating how much to weight augmented samples. This is only considered if sample_weight was initially specified in Predictor.
        """
        if verbosity is None:
            verbosity = self.verbosity

        if teacher is None:
            teacher = self._get_best()

        hyperparameter_tune = False  # TODO: add as argument with scheduler options.
        if augmentation_data is not None and teacher_preds is None:
            raise ValueError("augmentation_data must be None if teacher_preds is None")

        logger.log(20, f"Distilling with teacher='{teacher}', teacher_preds={str(teacher_preds)}, augment_method={str(augment_method)} ...")
        if teacher not in self.get_model_names(can_infer=True):
            raise AssertionError(f"Teacher model '{teacher}' is not a valid teacher model! Either it does not exist or it cannot infer on new data.\n"
                                 f"Valid teacher models: {self.get_model_names(can_infer=True)}")
        if X is None:
            if y is not None:
                raise ValueError("X cannot be None when y specified.")
            X = self.load_X()
            X_val = self.load_X_val()

        if y is None:
            y = self.load_y()
            y_val = self.load_y_val()

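        # If no validation set was provided, carve one out of the training data using
        # the same default holdout fraction as predictor.fit().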
        if X_val is None:
            if y_val is not None:
                raise ValueError("X_val cannot be None when y_val specified.")
            if holdout_frac is None:
                holdout_frac = default_holdout_frac(len(X), hyperparameter_tune)
            X, X_val, y, y_val = generate_train_test_split(X, y, problem_type=self.problem_type, test_size=holdout_frac)

        y_val_og = y_val.copy()
        og_bagged_mode = self.bagged_mode
        og_verbosity = self.verbosity
        self.bagged_mode = False  # turn off bagging
        self.verbosity = verbosity  # change verbosity for distillation

        if self.sample_weight is not None:
            X, w = extract_column(X, self.sample_weight)

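        # teacher_preds of None or 'onehot' means no teacher labeling, so augmentation
        # is disabled; for 'onehot' and 'soft', labels are reformatted into the
        # representation used for distillation (e.g. one-hot / probability vectors).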
        if teacher_preds is None or teacher_preds == 'onehot':
            augment_method = None
            logger.log(20, "Training students without a teacher model. Set teacher_preds = 'soft' or 'hard' to distill using the best AutoGluon predictor as teacher.")

        if teacher_preds in ['onehot','soft']:
            y = format_distillation_labels(y, self.problem_type, self.num_classes)
            y_val = format_distillation_labels(y_val, self.problem_type, self.num_classes)

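        # No augmentation: relabel the original training rows with the teacher's
        # predictions. For 'hard' labels, any class the teacher never predicts is
        # patched back in with a real training example so no class goes missing.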
        if augment_method is None and augmentation_data is None:
            if teacher_preds == 'hard':
                y_pred = pd.Series(self.predict(X, model=teacher))
                if (self.problem_type != REGRESSION) and (len(y_pred.unique()) < len(y.unique())):  # add missing labels
                    logger.log(15, "Adding missing labels to distillation dataset by including some real training examples")
                    indices_to_add = []
                    for clss in y.unique():
                        if clss not in y_pred.unique():
                            logger.log(15, f"Fetching a row with label={clss} from training data")
                            clss_index = y[y == clss].index[0]
                            indices_to_add.append(clss_index)
                    X_extra = X.loc[indices_to_add].copy()
                    y_extra = y.loc[indices_to_add].copy()  # these are actually real training examples
                    X = pd.concat([X, X_extra])
                    y_pred = pd.concat([y_pred, y_extra])
                    if self.sample_weight is not None:
                        w = pd.concat([w, w.loc[indices_to_add]])
                y = y_pred
            elif teacher_preds == 'soft':
                y = self.predict_proba(X, model=teacher)
                if self.problem_type == MULTICLASS:
                    y = pd.DataFrame(y)
                else:
                    y = pd.Series(y)
        else:
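            # Augmentation path: synthesize extra rows (or take the user-supplied
            # augmentation_data) and label them with teacher predictions, keeping the
            # original training labels untouched.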
            X_aug = augment_data(X=X, feature_metadata=self.feature_metadata,
                                 augmentation_data=augmentation_data, augment_method=augment_method, augment_args=augment_args)
            if len(X_aug) > 0:
                if teacher_preds == 'hard':
                    y_aug = pd.Series(self.predict(X_aug, model=teacher))
                elif teacher_preds == 'soft':
                    y_aug = self.predict_proba(X_aug, model=teacher)
                    if self.problem_type == MULTICLASS:
                        y_aug = pd.DataFrame(y_aug)
                    else:
                        y_aug = pd.Series(y_aug)
                else:
                    raise ValueError(f"Unknown teacher_preds specified: {teacher_preds}")

                X = pd.concat([X, X_aug])
                y = pd.concat([y, y_aug])
                if self.sample_weight is not None:
                    w = pd.concat([w, pd.Series([augmented_sample_weight] * len(X_aug))])

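        # Concatenation left duplicate indices behind; reset them and re-attach the
        # sample weights as a column so downstream fitting sees one aligned frame.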
        X.reset_index(drop=True, inplace=True)
        y.reset_index(drop=True, inplace=True)
        if self.sample_weight is not None:
            w.reset_index(drop=True, inplace=True)
            X[self.sample_weight] = w

        name_suffix = '_DSTL'  # all student model names contain this substring
        if models_name_suffix is not None:
            name_suffix = name_suffix + "_" + models_name_suffix

        if hyperparameters is None:
            hyperparameters = {'GBM': {}, 'CAT': {}, 'NN_MXNET': {}, 'NN_TORCH': {}, 'RF': {}}
        hyperparameters = self._process_hyperparameters(hyperparameters=hyperparameters)  # TODO: consider exposing ag_args_fit, excluded_model_types as distill() arguments.
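        # Students trained on soft labels are fit as regressors over class
        # probabilities; flag the trainer to re-interpret their regression outputs
        # as probabilities at prediction time.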
        if teacher_preds is not None and teacher_preds != 'hard' and self.problem_type != REGRESSION:
            self._regress_preds_asprobas = True

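        # Student models are built from distillation-specific templates and trained
        # under a dedicated stack name (distill_stackname) rather than the core stack.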
        core_kwargs = {
            'stack_name': self.distill_stackname,
            'get_models_func': self.construct_model_templates_distillation,
        }
        aux_kwargs = {
            'get_models_func': self.construct_model_templates_distillation,
            'check_if_best': False,
        }

        # self.bagged_mode = True  # TODO: Add options for bagging
        models = self.train_multi_levels(
            X=X,
            y=y,
            X_val=X_val,
            y_val=y_val,
            hyperparameters=hyperparameters,
            time_limit=time_limit,  # FIXME: Also limit augmentation time
            name_suffix=name_suffix,
            core_kwargs=core_kwargs,
            aux_kwargs=aux_kwargs,
        )

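        # Students were trained against teacher labels, so their stored validation
        # scores reflect the distillation objective; rescore each one against the
        # original ground-truth labels before reporting.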
        distilled_model_names = []
        w_val = None
        if self.weight_evaluation:
            X_val, w_val = extract_column(X_val, self.sample_weight)
        for model_name in models:  # finally measure original metric on validation data and overwrite stored val_scores
            model_score = self.score(X_val, y_val_og, model=model_name, weights=w_val)
            model_obj = self.load_model(model_name)
            model_obj.val_score = model_score
            model_obj.save()  # TODO: consider omitting for sake of efficiency
            self.model_graph.nodes[model_name]['val_score'] = model_score
            distilled_model_names.append(model_name)
        leaderboard = self.leaderboard()
        logger.log(20, 'Distilled model leaderboard:')
        leaderboard_distilled = leaderboard[leaderboard['model'].isin(models)].reset_index(drop=True)
        with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', 1000):
            logger.log(20, leaderboard_distilled)

        # reset trainer to old state before distill() was called:
        self.bagged_mode = og_bagged_mode  # TODO: Confirm if safe to train future models after training models in both bagged and non-bagged modes
        self.verbosity = og_verbosity
        return distilled_model_names
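
For reference, a minimal usage sketch through the public API. It assumes a fitted
TabularPredictor whose distill() forwards these arguments down to this trainer
method; the file name and label column below are placeholders, not part of the
source.

    from autogluon.tabular import TabularDataset, TabularPredictor

    # 'train.csv' and 'label' are placeholder names for your own dataset.
    train_data = TabularDataset('train.csv')
    predictor = TabularPredictor(label='label').fit(train_data, time_limit=600)

    # Distill the fitted ensemble into standalone student models using soft teacher
    # labels and 'spunge' augmentation (mirroring this method's defaults).
    student_models = predictor.distill(
        time_limit=300,
        teacher_preds='soft',
        augment_method='spunge',
        augment_args={'size_factor': 5, 'max_size': int(1e5)},
        models_name_suffix='v1',
    )
    print(predictor.leaderboard())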