in text/src/autogluon/text/text_prediction/mx/models.py [0:0]
def train_function(args, reporter, train_df_path, tuning_df_path,
time_limit, time_start, base_config,
problem_type, column_types,
feature_columns, label_column, output_directory,
log_metrics, eval_metric, ngpus_per_trial,
params_path, preprocessor_path, continue_training,
console_log, seed=None, verbosity=2):
"""
Parameters
----------
args
        The arguments, i.e., the hyperparameter configuration sampled for this trial
reporter
Reporter of the HPO scheduler.
If it is set to None, we won't use the reporter and will just run a single trial.
train_df_path
Path of the training dataframe
tuning_df_path
Path of the tuning dataframe
time_limit
        The time limit, in seconds, of calling this function
time_start
The starting timestamp of the experiment
base_config
Basic configuration
problem_type
Type of the problem.
column_types
        Types of the columns
feature_columns
The feature columns
label_column
Label column
output_directory
The output directory
log_metrics
Metrics for logging
eval_metric
        The evaluation metric used for model selection and early stopping
    ngpus_per_trial
        The number of GPUs to use for each trial
    params_path
        The path of the network parameters
    preprocessor_path
        The path of the saved preprocessor (loaded when continue_training is True)
    continue_training
        Whether we are loading a previously trained model and continuing to train it on a new dataset
    console_log
        Whether to also log to the console
seed
The random seed
verbosity
The verbosity
"""
set_seed(seed)
is_fake_reporter = isinstance(reporter, FakeReporter)
if time_limit is not None:
start_train_tick = time.time()
time_left = time_limit - (start_train_tick - time_start)
if time_left <= 0:
if not is_fake_reporter:
reporter.terminate()
return
search_space = args
if is_fake_reporter:
task_id = 0
else:
task_id = args.pop('task_id')
# Get the log metric scorers
if isinstance(log_metrics, str):
log_metrics = [log_metrics]
    # Load the training and tuning data (pickled dataframes)
train_data = pd.read_pickle(train_df_path)
tuning_data = pd.read_pickle(tuning_df_path)
log_metric_scorers = [get_metric(ele) for ele in log_metrics]
eval_metric_scorer = get_metric(eval_metric)
greater_is_better = eval_metric_scorer.greater_is_better
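    # Merge the hyperparameters sampled by the searcher into the base configuration.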
cfg = base_config.clone()
specified_values = []
for key in search_space.keys():
specified_values.append(key)
specified_values.append(search_space[key])
cfg.merge_from_list(specified_values)
exp_dir = os.path.join(output_directory, 'task{}'.format(task_id))
os.makedirs(exp_dir, exist_ok=True)
cfg.defrost()
cfg.misc.exp_dir = exp_dir
cfg.freeze()
logger = logging.getLogger(__name__)
set_logger_verbosity(verbosity, logger)
logging_config(folder=exp_dir, name='training', logger=logger, console=console_log,
level=logging.DEBUG,
console_level=verbosity2loglevel(verbosity))
logger.log(10, cfg)
# Load backbone model
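    # RoBERTa-style backbones are built with return_all_hiddens=True; for other backbones,
    # the pretrained backbone weights are only loaded when we are not continuing from an existing checkpoint.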
if 'roberta' in cfg.model.backbone.name:
backbone_model_cls, backbone_cfg, tokenizer, backbone_params_path, _ \
= get_backbone(cfg.model.backbone.name)
text_backbone = backbone_model_cls.from_cfg(backbone_cfg, return_all_hiddens=True)
else:
backbone_model_cls, backbone_cfg, tokenizer, backbone_params_path, _ \
= get_backbone(cfg.model.backbone.name, load_backbone=not continue_training)
text_backbone = backbone_model_cls.from_cfg(backbone_cfg)
    # Build the preprocessor and preprocess the training dataset (the label generator depends on the problem type)
# TODO Dynamically cache the preprocessor that has been fitted.
if continue_training:
with open(preprocessor_path, 'rb') as in_f:
preprocessor = pickle.load(in_f)
train_dataset = preprocessor.transform(train_data[feature_columns],
train_data[label_column])
label_generator = preprocessor._label_generator
else:
if problem_type == MULTICLASS or problem_type == BINARY:
label_generator = LabelEncoder()
label_generator.fit(pd.concat([train_data[label_column], tuning_data[label_column]]))
else:
label_generator = None
preprocessor = MultiModalTextFeatureProcessor(column_types=column_types,
label_column=label_column,
tokenizer_name=cfg.model.backbone.name,
label_generator=label_generator,
cfg=cfg.preprocessing)
logger.info('Fitting and transforming the train data...')
train_dataset = preprocessor.fit_transform(train_data[feature_columns],
train_data[label_column])
with open(os.path.join(exp_dir, 'preprocessor.pkl'), 'wb') as of:
pickle.dump(preprocessor, of)
logger.info(f'Done! Preprocessor saved to {os.path.join(exp_dir, "preprocessor.pkl")}')
logger.log(10, 'Train Data')
logger.log(10, get_stats_string(preprocessor, train_dataset, is_train=True))
logger.info('Process dev set...')
tuning_dataset = preprocessor.transform(tuning_data[feature_columns],
tuning_data[label_column])
logger.info('Done!')
# Auto Max Length
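    # auto_shrink_max_length picks a tighter maximum length from the token-length distribution
    # of the training data (bounded by the configured max_length).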
if cfg.preprocessing.text.auto_max_length:
max_length = auto_shrink_max_length(
train_dataset,
insert_sep=cfg.model.insert_sep,
num_text_features=len(preprocessor.text_feature_names),
auto_max_length_quantile=cfg.preprocessing.text.auto_max_length_quantile,
round_to=cfg.preprocessing.text.auto_max_length_round_to,
max_length=cfg.preprocessing.text.max_length)
else:
max_length = cfg.preprocessing.text.max_length
train_stochastic_chunk = cfg.model.train_stochastic_chunk
test_stochastic_chunk = cfg.model.test_stochastic_chunk
inference_num_repeat = cfg.model.inference_num_repeat
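    # If the automatically shrunk max_length already covers most samples, a single inference pass
    # is sufficient, so repeated stochastic-chunk inference is disabled.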
if max_length < cfg.preprocessing.text.max_length:
inference_num_repeat = 1
cfg.defrost()
cfg.preprocessing.text.max_length = max_length
cfg.model.inference_num_repeat = inference_num_repeat
cfg.freeze()
with open(os.path.join(exp_dir, 'cfg.yml'), 'w') as f:
f.write(cfg.dump())
logger.info(f'Max length for chunking text: {max_length}, '
f'Stochastic chunk: Train-{train_stochastic_chunk}/Test-{test_stochastic_chunk}, '
f'Test #repeat: {inference_num_repeat}.')
cls_id, sep_id = get_cls_sep_id(tokenizer)
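    # The batchify functions chunk each text field to at most max_length tokens (optionally inserting
    # SEP tokens between fields) and batch the categorical features; all numerical columns are
    # treated as a single vector input.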
train_batchify_fn = MultiModalTextBatchify(
num_text_inputs=len(preprocessor.text_feature_names),
num_categorical_inputs=len(preprocessor.categorical_feature_names),
num_numerical_inputs=len(preprocessor.numerical_feature_names) > 0,
cls_token_id=cls_id, sep_token_id=sep_id, max_length=max_length,
mode='train', stochastic_chunk=train_stochastic_chunk,
insert_sep=cfg.model.insert_sep)
test_batchify_fn = MultiModalTextBatchify(
num_text_inputs=len(preprocessor.text_feature_names),
num_categorical_inputs=len(preprocessor.categorical_feature_names),
num_numerical_inputs=len(preprocessor.numerical_feature_names) > 0,
cls_token_id=cls_id, sep_token_id=sep_id, max_length=max_length,
mode='test', stochastic_chunk=test_stochastic_chunk,
insert_sep=cfg.model.insert_sep)
# Get the ground-truth dev labels
gt_dev_labels = np.array([ele[-1] for ele in tuning_dataset])
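    # For regression, map the scaled dev labels back to the original scale before computing metrics.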
if problem_type == REGRESSION:
gt_dev_labels = preprocessor.label_scaler.inverse_transform(np.expand_dims(gt_dev_labels,
axis=-1))[:, 0]
ctx_l = get_mxnet_available_ctx()
if ngpus_per_trial == 0:
ctx_l = [mx.cpu()]
else:
ctx_l = ctx_l[:ngpus_per_trial]
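    # Number of gradient-accumulation steps needed to reach the configured global batch size.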
base_batch_size = cfg.optimization.per_device_batch_size
num_accumulated = int(np.ceil(cfg.optimization.batch_size / (base_batch_size * len(ctx_l))))
inference_base_batch_size = base_batch_size * cfg.optimization.val_batch_size_mult
train_dataloader = DataLoader(train_dataset,
batch_size=base_batch_size,
shuffle=True,
batchify_fn=train_batchify_fn)
dev_dataloader = DataLoader(tuning_dataset,
batch_size=inference_base_batch_size,
shuffle=False,
batchify_fn=test_batchify_fn)
if problem_type == REGRESSION:
out_shape = 1
elif problem_type == MULTICLASS:
out_shape = len(label_generator.classes_)
elif problem_type == BINARY:
assert len(label_generator.classes_) == 2
out_shape = 2
else:
raise NotImplementedError
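    # Build the multimodal network: a pretrained text backbone fused with the categorical and numerical features.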
net = MultiModalWithPretrainedTextNN(
text_backbone=text_backbone,
num_text_features=1,
num_categorical_features=len(preprocessor.categorical_feature_names),
num_numerical_features=len(preprocessor.numerical_feature_names) > 0,
numerical_input_units=None if len(preprocessor.numerical_feature_names) == 0 else len(
preprocessor.numerical_feature_names),
num_categories=preprocessor.categorical_num_categories,
get_embedding=False,
cfg=cfg.model.network,
out_shape=out_shape)
if continue_training:
net.load_parameters(params_path, ctx=ctx_l)
else:
net.initialize_with_pretrained_backbone(backbone_params_path, ctx=ctx_l)
net.hybridize()
num_total_params, num_total_fixed_params = count_parameters(net.collect_params())
logger.info('#Total Params/Fixed Params={}/{}'.format(num_total_params,
num_total_fixed_params))
# Initialize the optimizer
updates_per_epoch = int(np.ceil(len(train_dataloader) / (num_accumulated * len(ctx_l))))
optimizer, optimizer_params, max_update \
= get_optimizer(cfg.optimization,
updates_per_epoch=updates_per_epoch)
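    # Validation and train-logging intervals are specified in the config as fractions of an epoch.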
valid_interval = int(math.ceil(cfg.optimization.valid_frequency * updates_per_epoch))
train_log_interval = int(math.ceil(cfg.optimization.log_frequency * updates_per_epoch))
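    # Apply layer-wise learning-rate decay and freeze part of the backbone according to cfg.model.num_trainable_layers.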
if 0 < cfg.optimization.layerwise_lr_decay < 1:
apply_layerwise_decay(net.text_backbone,
cfg.optimization.layerwise_lr_decay,
backbone_name=cfg.model.backbone.name)
freeze_layers(net.text_backbone,
backbone_name=cfg.model.backbone.name,
num_trainable_layers=cfg.model.num_trainable_layers)
    # Do not apply weight decay to LayerNorm parameters and biases
for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
v.wd_mult = 0.0
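    # Only parameters that require gradients are given to the trainer; update_on_kvstore=False so that
    # gradients can be all-reduced and clipped manually before each update.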
params = [p for p in net.collect_params().values() if p.grad_req != 'null']
trainer = mx.gluon.Trainer(params,
optimizer, optimizer_params,
update_on_kvstore=False)
# Set grad_req if gradient accumulation is required
if num_accumulated > 1:
logger.log(15, 'Using gradient accumulation.'
' Global batch size = {}'.format(cfg.optimization.batch_size))
for p in params:
p.grad_req = 'add'
net.collect_params().zero_grad()
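    # Cycle through the training dataloader indefinitely, yielding one batch per device at each step.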
train_loop_dataloader = grouper(repeat(train_dataloader), len(ctx_l))
log_loss_l = [mx.np.array(0.0, dtype=np.float32, ctx=ctx) for ctx in ctx_l]
log_num_samples_l = [0 for _ in ctx_l]
logging_start_tick = time.time()
nbest = cfg.optimization.nbest
    best_performance_score = []  # Validation scores of the current top-n checkpoints
    best_performance_update_idx = []  # Update indices at which those checkpoints were saved
best_score = None
mx.npx.waitall()
no_better_rounds = 0
report_idx = 0
start_tick = time.time()
if time_limit is not None:
time_limit -= start_tick - time_start
if time_limit <= 0:
if not is_fake_reporter:
reporter.terminate()
return
best_report_items = None
report_local_jsonl_f = open(os.path.join(exp_dir, 'results_local.jsonl'), 'w')
logger.info(f'Local training results will be saved to '
f'{os.path.join(exp_dir, "results_local.jsonl")}.')
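    # Main training loop: each update consumes num_accumulated micro-batches per device,
    # then performs a single (gradient-clipped) optimizer step.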
for update_idx in range(max_update):
for accum_idx in range(num_accumulated):
sample_l = next(train_loop_dataloader)
loss_l = []
for i, (sample, ctx) in enumerate(zip(sample_l, ctx_l)):
feature_batch, label_batch = sample
feature_batch = move_to_ctx(feature_batch, ctx)
label_batch = move_to_ctx(label_batch, ctx)
with mx.autograd.record():
pred = net(feature_batch)
if problem_type == MULTICLASS or problem_type == BINARY:
logits = mx.npx.log_softmax(pred, axis=-1)
loss = - mx.npx.pick(logits,
mx.np.expand_dims(label_batch, axis=-1))
elif problem_type == REGRESSION:
loss = mx.np.square(pred - mx.np.expand_dims(label_batch, axis=-1))
loss_l.append(loss.mean() / len(ctx_l) / num_accumulated)
log_loss_l[i] += loss_l[i] * len(ctx_l) * loss.shape[0] * num_accumulated
log_num_samples_l[i] += loss.shape[0]
for loss in loss_l:
loss.backward()
# Begin to update
trainer.allreduce_grads()
total_norm, ratio, is_finite = clip_grad_global_norm(params, cfg.optimization.max_grad_norm)
if not cfg.model._disable_update:
trainer.update(1.0, ignore_stale_grad=True)
# Clear after update
if num_accumulated > 1:
net.collect_params().zero_grad()
if (update_idx + 1) % train_log_interval == 0:
log_loss = sum([ele.as_in_ctx(ctx_l[0]) for ele in log_loss_l]).asnumpy()
log_num_samples = sum(log_num_samples_l)
logger.log(15,
'[Iter {}/{}, Epoch {}] train loss={:0.2e}, gnorm={:0.2e}, lr={:0.2e}, #samples processed={},'
' #sample per second={:.2f}. ETA={:.2f}min'
.format(update_idx + 1, max_update,
int(update_idx / updates_per_epoch),
log_loss / log_num_samples, total_norm, trainer.learning_rate,
log_num_samples,
log_num_samples / (time.time() - logging_start_tick),
(time.time() - start_tick) / (update_idx + 1)
* (max_update - update_idx - 1) / 60))
logging_start_tick = time.time()
log_loss_l = [mx.np.array(0.0, dtype=np.float32, ctx=ctx) for ctx in ctx_l]
log_num_samples_l = [0 for _ in ctx_l]
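        # Periodically evaluate on the tuning (dev) set, checkpoint the best and top-n models, and report the results.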
if (update_idx + 1) % valid_interval == 0 or (update_idx + 1) == max_update:
valid_start_tick = time.time()
dev_predictions = \
_classification_regression_predict(net,
dataloader=dev_dataloader,
problem_type=problem_type,
label_scaler=preprocessor.label_scaler,
has_label=False,
num_repeat=inference_num_repeat)
log_scores = [calculate_metric(scorer, gt_dev_labels,
dev_predictions,
problem_type)
for scorer in log_metric_scorers]
dev_score = calculate_metric(eval_metric_scorer, gt_dev_labels,
dev_predictions,
problem_type)
valid_time_spent = time.time() - valid_start_tick
find_better = False
find_topn_better = False
if len(best_performance_score) < nbest:
best_performance_score.append(dev_score)
best_performance_update_idx.append(update_idx + 1)
net.save_parameters(
os.path.join(exp_dir,
f'nbest_model{len(best_performance_score) - 1}.params'))
find_topn_better = True
                if best_score is None or (greater_is_better and dev_score >= best_score) \
                        or (not greater_is_better and dev_score <= best_score):
find_better = True
net.save_parameters(os.path.join(exp_dir, f'best_model.params'))
best_score = dev_score
else:
# First try to update the top-K
if greater_is_better:
if dev_score >= min(best_performance_score):
find_topn_better = True
replace_idx = np.argmin(best_performance_score)
best_performance_score[replace_idx] = dev_score
best_performance_update_idx[replace_idx] = update_idx + 1
net.save_parameters(
os.path.join(exp_dir, f'nbest_model{replace_idx}.params'))
if dev_score >= best_score:
find_better = True
net.save_parameters(os.path.join(exp_dir, f'best_model.params'))
best_score = dev_score
else:
if dev_score <= max(best_performance_score):
find_topn_better = True
replace_idx = np.argmax(best_performance_score)
best_performance_score[replace_idx] = dev_score
best_performance_update_idx[replace_idx] = update_idx + 1
net.save_parameters(
os.path.join(exp_dir, f'nbest_model{replace_idx}.params'))
if dev_score <= best_score:
find_better = True
net.save_parameters(os.path.join(exp_dir, f'best_model.params'))
best_score = dev_score
if not find_better:
no_better_rounds += 1
else:
no_better_rounds = 0
mx.npx.waitall()
loss_string = ', '.join(['{}={:0.4e}'.format(metric.name, score)
for score, metric in zip(log_scores, log_metric_scorers)])
logger.log(25, '[Iter {}/{}, Epoch {}] Validation {}, Time computing validation-score={:.3f}s,'
' Total time spent={:.2f}min. Found improved model={}, Improved top-{} models={}'.format(
update_idx + 1, max_update, int(update_idx / updates_per_epoch),
loss_string, valid_time_spent, (time.time() - start_tick) / 60,
find_better, nbest, find_topn_better))
if reporter is not None:
report_items = [('iteration', update_idx + 1),
('report_idx', report_idx + 1),
('epoch', int(update_idx / updates_per_epoch))] + \
[(metric.name, score)
for score, metric in zip(log_scores, log_metric_scorers)] + \
[('find_better', find_better),
('find_new_topn', find_topn_better),
('nbest_stat', json.dumps([best_performance_score,
best_performance_update_idx])),
('elapsed_time', int(time.time() - start_tick))]
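                # The scheduler maximizes reward_attr, so negate the score for metrics where lower is better.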
if eval_metric_scorer._sign < 0:
report_items.append(('reward_attr', -dev_score))
else:
report_items.append(('reward_attr', dev_score))
report_items.append(('eval_metric', eval_metric_scorer.name))
report_items.append(('exp_dir', exp_dir))
if find_better:
best_report_items = report_items
reporter(**dict(report_items))
report_local_jsonl_f.write(json.dumps(dict(report_items)) + '\n')
report_local_jsonl_f.flush()
report_idx += 1
if no_better_rounds >= cfg.optimization.early_stopping_patience:
logger.info('Early stopping patience reached!')
break
total_time_spent = time.time() - start_tick
if time_limit is not None and total_time_spent > time_limit:
break
    # Re-report the best validation result so the scheduler records it as this trial's final score
best_report_items_dict = dict(best_report_items)
best_report_items_dict['report_idx'] = report_idx + 1
reporter(**best_report_items_dict)
report_local_jsonl_f.write(json.dumps(best_report_items_dict) + '\n')
report_local_jsonl_f.close()