in text/src/autogluon/text/text_prediction/predictor/predictor.py [0:0]
def fit(self,
train_data,
tuning_data=None,
time_limit=None,
presets=None,
hyperparameters=None,
column_types=None,
num_cpus=None,
num_gpus=None,
num_trials=None,
plot_results=None,
holdout_frac=None,
save_path=None,
seed=0):
"""
Fit Transformer models to predict the label column of a data table based on the other columns (which may contain text or numeric/categorical features).
Parameters
----------
train_data : str or :class:`TabularDataset` or :class:`pd.DataFrame`
Table of the training data, which is similar to a pandas DataFrame.
If str is passed, `train_data` will be loaded using the str value as the file path.
tuning_data : str or :class:`TabularDataset` or :class:`pd.DataFrame`, default = None
Another dataset containing validation data reserved for tuning processes such as early stopping and hyperparameter tuning.
This dataset should be in the same format as `train_data`.
If str is passed, `tuning_data` will be loaded using the str value as the file path.
Note: final model returned may be fit on `tuning_data` as well as `train_data`. Do not provide your evaluation test data here!
If `tuning_data = None`, `fit()` will automatically hold out some random validation examples from `train_data`.
time_limit : int, default = None
Approximately how long `fit()` should run for (wallclock time in seconds).
If not specified, `fit()` will run until the model has completed training.
presets : str, default = None
Presets are pre-registered configurations that control training (hyperparameters and other aspects).
Until you are familiar with AutoGluon, it is recommended to only specify presets and avoid specifying most other `fit()` arguments or model hyperparameters.
Print all available presets via `autogluon.text.list_presets()`.
Some notable presets include:
- "best_quality": produce the most accurate overall predictor (regardless of its efficiency).
- "medium_quality_faster_train": produce an accurate predictor but take efficiency into account (this is the default preset).
- "lower_quality_fast_train": produce a predict that is quick to train and make predictions with, even if its accuracy is worse.
hyperparameters : dict, default = None
The hyperparameters of the `fit()` function, which affect the resulting accuracy of the trained predictor.
Experienced AutoGluon users can use this argument to specify neural network hyperparameter values/search-spaces as well as which hyperparameter-tuning strategy should be employed. See the "Text Prediction" tutorials for examples.
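As accessed by this method, the dict is expected to nest like the registered presets, e.g. (an illustrative sketch; the values shown are placeholders, not defaults): `hyperparameters = {'models': {'MultimodalTextModel': {'search_space': {...}, 'backend': 'gluonnlp_v0'}}, 'tune_kwargs': {'num_trials': 3}}`.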
column_types : dict, default = None
The type of data in each table column can be specified via a dictionary that maps the column name to its data type.
For example: `column_types = {"item_name": "text", "brand": "text", "product_description": "text", "height": "numerical"}` may be used for a table with columns: "item_name", "brand", "product_description", and "height".
If None, column_types will be automatically inferred from the data.
The current supported types are:
- "text": each row in this column contains text (sentence, paragraph, etc.).
- "numerical": each row in this column contains a number.
- "categorical": each row in this column belongs to one of K categories.
num_cpus : int, default = None
The number of CPUs to use for each training run (i.e. one hyperparameter-tuning trial).
num_gpus : int, default = None
The number of GPUs to use for each training run (i.e. one hyperparameter-tuning trial). We recommend at least 1 GPU for TextPredictor as its neural network models are computationally intensive.
num_trials : int, default = None
If hyperparameter-tuning is used, specifies how many HPO trials should be run (assuming `time_limit` has not been exceeded).
By default, this is the number of trials specified in `hyperparameters` or `presets`.
If specified here, this value will overwrite the value in `hyperparameters['tune_kwargs']['num_trials']`.
plot_results : bool, default = None
Whether to plot intermediate results from training. If None, will be decided based on the environment in which `fit()` is run.
holdout_frac : float, default = None
Fraction of `train_data` to hold out as tuning data for optimizing hyperparameters (ignored unless `tuning_data = None`).
Default value (if None) is selected based on the number of rows in the training data and whether hyperparameter-tuning is utilized.
save_path : str, default = None
The path for auto-saving the models' weights.
seed : int, default = 0
The random seed to use for this training run. If None, no seed will be specified and repeated runs will produce different results.
Returns
-------
:class:`TextPredictor` object. Returns self.
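Examples
--------
A minimal illustrative workflow (the file paths and the label column name below are placeholders, not part of the API):

>>> from autogluon.text import TextPredictor
>>> predictor = TextPredictor(label='label')  # doctest: +SKIP
>>> predictor.fit('train.csv', time_limit=600)  # doctest: +SKIP
>>> predictions = predictor.predict('test.csv')  # doctest: +SKIP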
"""
assert self._fit_called is False
is_continue_training = self._model is not None
verbosity = self.verbosity
if verbosity is None:
verbosity = 3
if save_path is not None:
self._path = setup_outputdir(save_path, warn_if_exist=True)
if is_continue_training:
# We have entered the continue training / transfer learning setting because the model is not None.
logger.info('Continue training the existing model...')
assert presets is None, 'presets is not supported in the continue training setting.'
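# Flatten the existing model's config into a search space so the same configuration is reused;
# the learning rate is re-wrapped as a one-element Categorical space.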
flat_dict = self._model.config.to_flat_dict()
flat_dict['optimization.lr'] = space.Categorical(flat_dict['optimization.lr'])
existing_hparams = {'models': {'MultimodalTextModel': {'search_space': flat_dict}}}
existing_hparams = merge_params(ag_text_presets.create('default'), existing_hparams)
hyperparameters = merge_params(existing_hparams, hyperparameters)
# Check that the merged hyperparameters match the existing hyperparameters.
# Here, we ensure that the model configuration remains the same.
for key in hyperparameters['models']['MultimodalTextModel']['search_space']:
if key in existing_hparams and (key.startswith('model.') or key.startswith('preprocessing.')):
new_value = hyperparameters['models']['MultimodalTextModel']['search_space'][key]
old_value = existing_hparams['models']['MultimodalTextModel']['search_space'][key]
assert new_value == old_value,\
f'The model architecture / preprocessing logic is not allowed to change in the ' \
f'continue training mode. ' \
f'"{key}" is changed to be "{new_value}" from "{old_value}". ' \
f'Please check the specified hyperparameters = {hyperparameters}'
else:
if presets is not None:
preset_hparams = ag_text_presets.create(presets)
else:
preset_hparams = ag_text_presets.create('default')
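# Merge any user-specified hyperparameters on top of the preset defaults.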
hyperparameters = merge_params(preset_hparams, hyperparameters)
if num_trials is not None:
hyperparameters['tune_kwargs']['num_trials'] = num_trials
if isinstance(self._label, str):
label_columns = [self._label]
else:
label_columns = list(self._label)
# Load the training and tuning data as pandas DataFrames
if isinstance(train_data, str):
train_data = load_pd.load(train_data)
if not isinstance(train_data, pd.DataFrame):
raise AssertionError(f'train_data is required to be a pandas DataFrame, but was instead: {type(train_data)}')
all_columns = list(train_data.columns)
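# All columns other than the label column(s) are treated as feature columns.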
feature_columns = [ele for ele in all_columns if ele not in label_columns]
train_data = train_data[all_columns]
# Get tuning data
if tuning_data is not None:
if isinstance(tuning_data, str):
tuning_data = load_pd.load(tuning_data)
if not isinstance(tuning_data, pd.DataFrame):
raise AssertionError(f'tuning_data is required to be a pandas DataFrame, but was instead: {type(tuning_data)}')
tuning_data = tuning_data[all_columns]
else:
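# No tuning_data was provided: hold out a random fraction of train_data for validation.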
if holdout_frac is None:
num_trials = hyperparameters['tune_kwargs']['num_trials']
if num_trials == 1:
holdout_frac = default_holdout_frac(len(train_data), False)
else:
# For HPO, we will need to use a larger held-out ratio
holdout_frac = default_holdout_frac(len(train_data), True)
train_data, tuning_data = train_test_split(train_data,
test_size=holdout_frac,
random_state=np.random.RandomState(seed))
if is_continue_training:
assert set(label_columns) == set(self._model.label_columns),\
f'Label columns do not match. Inferred label column from data = {set(label_columns)}.' \
f' Label column in model = {set(self._model.label_columns)}'
for col_name in self._model.feature_columns:
assert col_name in feature_columns, f'In the loaded model, "{col_name}" is a feature column,' \
f' but there is ' \
f'no such column in the DataFrame.'
model_hparams = hyperparameters['models']['MultimodalTextModel']
if plot_results is None:
plot_results = in_ipynb()
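# Resume training of the loaded model, reusing its architecture with the merged search space.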
self._model.train(train_data=train_data,
tuning_data=tuning_data,
num_cpus=num_cpus,
num_gpus=num_gpus,
search_space=model_hparams['search_space'],
tune_kwargs=hyperparameters['tune_kwargs'],
time_limit=time_limit,
continue_training=True,
seed=seed,
plot_results=plot_results,
verbosity=verbosity)
else:
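# Fresh training: infer per-column data types and the problem type (e.g. classification vs. regression) from the data.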
column_types, problem_type = infer_column_problem_types(train_data, tuning_data,
label_columns=label_columns,
problem_type=self._problem_type,
provided_column_types=column_types)
self._eval_metric, log_metrics = infer_eval_log_metrics(problem_type=problem_type,
eval_metric=self._eval_metric)
has_text_column = False
for k, v in column_types.items():
if v == _C.TEXT:
has_text_column = True
break
if not has_text_column:
raise AssertionError('No text column is found! This is currently not supported by '
'the TextPredictor. You may try to use '
'autogluon.tabular.TabularPredictor.\n'
'The inferred column properties of the training data are {}'
.format(column_types))
logger.info('Problem Type="{}"'.format(problem_type))
logger.info(printable_column_type_string(column_types))
self._problem_type = problem_type
if 'models' not in hyperparameters or 'MultimodalTextModel' not in hyperparameters['models']:
raise ValueError('The current TextPredictor only supports "MultimodalTextModel" '
'and you must ensure that '
'hyperparameters["models"]["MultimodalTextModel"] can be accessed.')
model_hparams = hyperparameters['models']['MultimodalTextModel']
self._backend = model_hparams['backend']
if plot_results is None:
plot_results = in_ipynb()
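# Dispatch on the backend registered in the hyperparameters; 'gluonnlp_v0'
# (MXNet via autogluon-contrib-nlp) is currently the only supported backend.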
if self._backend == 'gluonnlp_v0':
import warnings
warnings.filterwarnings('ignore', module='mxnet')
from ..mx.models import MultiModalTextModel
self._model = MultiModalTextModel(column_types=column_types,
feature_columns=feature_columns,
label_columns=label_columns,
problem_type=self._problem_type,
eval_metric=self._eval_metric,
log_metrics=log_metrics,
output_directory=self._path)
self._model.train(train_data=train_data,
tuning_data=tuning_data,
num_cpus=num_cpus,
num_gpus=num_gpus,
search_space=model_hparams['search_space'],
tune_kwargs=hyperparameters['tune_kwargs'],
time_limit=time_limit,
seed=seed,
plot_results=plot_results,
verbosity=verbosity)
else:
raise NotImplementedError("Currently, we only support using "
"the autogluon-contrib-nlp and MXNet "
"as the backend of AutoGluon-Text. In the future, "
"we will support other models.")
logger.info(f'Training completed. Auto-saving to "{self.path}". '
f'For loading the model, you can use'
f' `predictor = TextPredictor.load("{self.path}")`')
self.save(self.path)
return self