in core/src/autogluon/core/trainer/abstract_trainer.py [0:0]
def distill(self, X=None, y=None, X_val=None, y_val=None, X_unlabeled=None,
            time_limit=None, hyperparameters=None, holdout_frac=None, verbosity=None,
            models_name_suffix=None, teacher=None, teacher_preds='soft',
            augmentation_data=None, augment_method='spunge', augment_args=None,
            augmented_sample_weight=1.0):
    """ Train student models by distilling a teacher model (various distillation algorithms).

    Args:
        X, y: pd.DataFrame and pd.Series of training data.
            If X is None, both X and X_val are loaded via self.load_X() / self.load_X_val()
            (any X_val passed in is replaced). If y is None, both y and y_val are loaded via
            self.load_y() / self.load_y_val() (any y_val passed in is replaced).
            This data is split into train/validation if X_val is still None afterwards.
        X_val, y_val: pd.DataFrame and pd.Series of validation data.
        X_unlabeled: accepted for interface compatibility; not used by this method.
        time_limit: forwarded to self.train_multi_levels(). Data augmentation is NOT time-limited.
        hyperparameters: defined as in predictor.fit(). If None, students are trained with default
            settings for 'GBM', 'CAT', 'NN_MXNET', 'NN_TORCH' and 'RF'.
        holdout_frac: fraction of data held out for validation when X_val is None.
            If None, default_holdout_frac(len(X), False) is used.
        verbosity: verbosity used while distilling; defaults to self.verbosity.
        models_name_suffix (str): optional extra suffix. Student models are trained with name suffix
            '_DSTL', or '_DSTL_<models_name_suffix>' when this is given.
        teacher (None or str):
            If None, uses the model returned by self._get_best() as the teacher model,
            otherwise use the specified model name as the teacher.
        teacher_preds (None or str): If None, we only train with original labels (no data augmentation, overrides augment_method)
            If 'hard', labels are hard teacher predictions given by: teacher.predict()
            If 'soft', labels are soft teacher predictions given by: teacher.predict_proba()
            Note: 'hard' and 'soft' are equivalent for regression problems.
            If augment_method specified, teacher predictions are only used to label augmented data (training data keeps original labels).
            To apply label-smoothing: teacher_preds='onehot' will use original training data labels converted to one-hots for multiclass (no data augmentation).  # TODO: expose smoothing-hyperparameter.
        augmentation_data: pd.DataFrame of additional data to use as "augmented data" (does not contain labels).
            When specified, augment_method, augment_args are ignored, and this is the only augmented data that is used (teacher_preds cannot be None).
        augment_method (None or str): specifies which augmentation strategy to utilize. Options: [None, 'spunge','munge']
            If None, no augmentation gets applied.
        augment_args (None or dict): args passed into the augmentation function corresponding to augment_method.
            If None, {'size_factor': 5, 'max_size': int(1e5)} is used (a fresh dict per call).
        augmented_sample_weight (float): Nonnegative value indicating how much to weight augmented samples.
            This is only considered if sample_weight was initially specified in Predictor.

    Returns:
        list of the names of the student models that were trained.

    Raises:
        ValueError: if augmentation_data is given while teacher_preds is None; if y is given without X;
            if y_val is given without X_val; or if augmented data must be labeled with an
            unrecognised teacher_preds value.
        AssertionError: if the teacher is not among self.get_model_names(can_infer=True).

    Side effects:
        self.bagged_mode and self.verbosity are changed for the duration of the call and are always
        restored afterwards, including when an exception is raised.
        self._regress_preds_asprobas is set to True (and not reset here) when teacher_preds is neither
        None nor 'hard' and the problem is not regression.
        Each student model's stored val_score is overwritten with its score on the original
        (unformatted) validation labels, and the model is re-saved.
    """
    if verbosity is None:
        verbosity = self.verbosity
    if teacher is None:
        teacher = self._get_best()
    hyperparameter_tune = False  # TODO: add as argument with scheduler options.
    if augmentation_data is not None and teacher_preds is None:
        raise ValueError("augmentation_data must be None if teacher_preds is None")
    if augment_args is None:
        # Build the default per call: a dict literal in the signature would be a single object shared
        # by every invocation, so any mutation made downstream would leak into later calls.
        augment_args = {'size_factor': 5, 'max_size': int(1e5)}

    logger.log(20, f"Distilling with teacher='{teacher}', teacher_preds={str(teacher_preds)}, augment_method={str(augment_method)} ...")
    if teacher not in self.get_model_names(can_infer=True):
        raise AssertionError(f"Teacher model '{teacher}' is not a valid teacher model! Either it does not exist or it cannot infer on new data.\n"
                             f"Valid teacher models: {self.get_model_names(can_infer=True)}")
    if X is None:
        if y is not None:
            raise ValueError("X cannot be None when y specified.")
        X = self.load_X()
        X_val = self.load_X_val()

    if y is None:
        y = self.load_y()
        y_val = self.load_y_val()

    if X_val is None:
        if y_val is not None:
            raise ValueError("X_val cannot be None when y_val specified.")
        if holdout_frac is None:
            holdout_frac = default_holdout_frac(len(X), hyperparameter_tune)
        X, X_val, y, y_val = generate_train_test_split(X, y, problem_type=self.problem_type, test_size=holdout_frac)

    # Keep the untouched validation labels: y_val may be reformatted below, but the final
    # scores of the students are measured against the original labels.
    y_val_og = y_val.copy()
    og_bagged_mode = self.bagged_mode
    og_verbosity = self.verbosity
    self.bagged_mode = False  # turn off bagging
    self.verbosity = verbosity  # change verbosity for distillation

    try:
        if self.sample_weight is not None:
            # Pull the weight column out of X so it is not treated as a feature while labeling/augmenting.
            X, w = extract_column(X, self.sample_weight)

        if teacher_preds is None or teacher_preds == 'onehot':
            augment_method = None
            logger.log(20, "Training students without a teacher model. Set teacher_preds = 'soft' or 'hard' to distill using the best AutoGluon predictor as teacher.")

        if teacher_preds in ['onehot', 'soft']:
            y = format_distillation_labels(y, self.problem_type, self.num_classes)
            y_val = format_distillation_labels(y_val, self.problem_type, self.num_classes)

        if augment_method is None and augmentation_data is None:
            # No augmented data: the training rows themselves are (re)labeled by the teacher.
            if teacher_preds == 'hard':
                y_pred = pd.Series(self.predict(X, model=teacher))
                if (self.problem_type != REGRESSION) and (len(y_pred.unique()) < len(y.unique())):  # add missing labels
                    logger.log(15, "Adding missing labels to distillation dataset by including some real training examples")
                    predicted_classes = y_pred.unique()  # invariant across the loop below
                    indices_to_add = []
                    for clss in y.unique():
                        if clss not in predicted_classes:
                            logger.log(15, f"Fetching a row with label={clss} from training data")
                            clss_index = y[y == clss].index[0]
                            indices_to_add.append(clss_index)
                    X_extra = X.loc[indices_to_add].copy()
                    y_extra = y.loc[indices_to_add].copy()  # these are actually real training examples
                    X = pd.concat([X, X_extra])
                    y_pred = pd.concat([y_pred, y_extra])
                    if self.sample_weight is not None:
                        # Label-based lookup, consistent with the X / y lookups above.
                        w = pd.concat([w, w.loc[indices_to_add]])
                y = y_pred
            elif teacher_preds == 'soft':
                y = self.predict_proba(X, model=teacher)
                if self.problem_type == MULTICLASS:
                    y = pd.DataFrame(y)
                else:
                    y = pd.Series(y)
        else:
            # Augmented data: training rows keep their labels, only the new rows are labeled by the teacher.
            X_aug = augment_data(X=X, feature_metadata=self.feature_metadata,
                                 augmentation_data=augmentation_data, augment_method=augment_method, augment_args=augment_args)
            if len(X_aug) > 0:
                if teacher_preds == 'hard':
                    y_aug = pd.Series(self.predict(X_aug, model=teacher))
                elif teacher_preds == 'soft':
                    y_aug = self.predict_proba(X_aug, model=teacher)
                    if self.problem_type == MULTICLASS:
                        y_aug = pd.DataFrame(y_aug)
                    else:
                        y_aug = pd.Series(y_aug)
                else:
                    raise ValueError(f"Unknown teacher_preds specified: {teacher_preds}")
                X = pd.concat([X, X_aug])
                y = pd.concat([y, y_aug])
                if self.sample_weight is not None:
                    w = pd.concat([w, pd.Series([augmented_sample_weight]*len(X_aug))])

        # The concatenations above can leave duplicated / misaligned index labels; realign everything
        # positionally before the weight column is re-attached.
        X.reset_index(drop=True, inplace=True)
        y.reset_index(drop=True, inplace=True)
        if self.sample_weight is not None:
            w.reset_index(drop=True, inplace=True)
            X[self.sample_weight] = w

        name_suffix = '_DSTL'  # all student model names contain this substring
        if models_name_suffix is not None:
            name_suffix = name_suffix + "_" + models_name_suffix

        if hyperparameters is None:
            hyperparameters = {'GBM': {}, 'CAT': {}, 'NN_MXNET': {}, 'NN_TORCH': {}, 'RF': {}}
        hyperparameters = self._process_hyperparameters(hyperparameters=hyperparameters)  # TODO: consider exposing ag_args_fit, excluded_model_types as distill() arguments.
        if teacher_preds is not None and teacher_preds != 'hard' and self.problem_type != REGRESSION:
            self._regress_preds_asprobas = True

        core_kwargs = {
            'stack_name': self.distill_stackname,
            'get_models_func': self.construct_model_templates_distillation,
        }
        aux_kwargs = {
            'get_models_func': self.construct_model_templates_distillation,
            'check_if_best': False,
        }

        # self.bagged_mode = True  # TODO: Add options for bagging
        models = self.train_multi_levels(
            X=X,
            y=y,
            X_val=X_val,
            y_val=y_val,
            hyperparameters=hyperparameters,
            time_limit=time_limit,  # FIXME: Also limit augmentation time
            name_suffix=name_suffix,
            core_kwargs=core_kwargs,
            aux_kwargs=aux_kwargs,
        )

        distilled_model_names = []
        w_val = None
        if self.weight_evaluation:
            X_val, w_val = extract_column(X_val, self.sample_weight)
        for model_name in models:  # finally measure original metric on validation data and overwrite stored val_scores
            model_score = self.score(X_val, y_val_og, model=model_name, weights=w_val)
            model_obj = self.load_model(model_name)
            model_obj.val_score = model_score
            model_obj.save()  # TODO: consider omitting for sake of efficiency
            self.model_graph.nodes[model_name]['val_score'] = model_score
            distilled_model_names.append(model_name)

        leaderboard = self.leaderboard()
        logger.log(20, 'Distilled model leaderboard:')
        leaderboard_distilled = leaderboard[leaderboard['model'].isin(models)].reset_index(drop=True)
        with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', 1000):
            logger.log(20, leaderboard_distilled)
    finally:
        # reset trainer to old state before distill() was called, even if distillation failed part-way:
        self.bagged_mode = og_bagged_mode  # TODO: Confirm if safe to train future models after training models in both bagged and non-bagged modes
        self.verbosity = og_verbosity

    return distilled_model_names