def train_function()

in text/src/autogluon/text/text_prediction/mx/models.py


def train_function(args, reporter, train_df_path, tuning_df_path,
                   time_limit, time_start, base_config,
                   problem_type, column_types,
                   feature_columns, label_column, output_directory,
                   log_metrics, eval_metric, ngpus_per_trial,
                   params_path, preprocessor_path, continue_training,
                   console_log, seed=None, verbosity=2):
    """

    Parameters
    ----------
    args
        The sampled hyperparameter search space of this trial. When a real
        reporter is used, it also carries the trial's 'task_id'.
    reporter
        Reporter of the HPO scheduler.
        If it is set to None, we won't use the reporter and will just run a single trial.
    train_df_path
        Path of the training dataframe
    tuning_df_path
        Path of the tuning dataframe
    time_limit
        The time limit (in seconds) for this function call
    time_start
        The starting timestamp of the experiment
    base_config
        Basic configuration
    problem_type
        Type of the problem.
    column_types
        The types of the columns
    feature_columns
        The feature columns
    label_column
        Label column
    output_directory
        The output directory
    log_metrics
        Metrics for logging
    eval_metric
        The metric used for early stopping and selecting the best model
    ngpus_per_trial
        The number of GPUs to use per trial
    params_path
        The path of the network parameters (loaded when continue_training is True)
    preprocessor_path
        The path of the saved preprocessor (loaded when continue_training is True)
    continue_training
        Whether to load an existing model and continue training it on a new dataset
    console_log
        Whether to log to the console
    seed
        The random seed
    verbosity
        The verbosity level

    """
    set_seed(seed)
    is_fake_reporter = isinstance(reporter, FakeReporter)
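    # Abort the trial right away if the time budget is already exhausted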
    if time_limit is not None:
        start_train_tick = time.time()
        time_left = time_limit - (start_train_tick - time_start)
        if time_left <= 0:
            if not is_fake_reporter:
                reporter.terminate()
            return
    search_space = args
    if is_fake_reporter:
        task_id = 0
    else:
        task_id = args.pop('task_id')
    # Get the log metric scorers
    if isinstance(log_metrics, str):
        log_metrics = [log_metrics]
    # Load the training and tuning data from the pickled dataframes
    train_data = pd.read_pickle(train_df_path)
    tuning_data = pd.read_pickle(tuning_df_path)
    log_metric_scorers = [get_metric(ele) for ele in log_metrics]
    eval_metric_scorer = get_metric(eval_metric)
    greater_is_better = eval_metric_scorer.greater_is_better
    cfg = base_config.clone()
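    # Flatten the sampled search space into an alternating
    # [key1, value1, key2, value2, ...] list for cfg.merge_from_list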
    specified_values = []
    for key in search_space.keys():
        specified_values.append(key)
        specified_values.append(search_space[key])
    cfg.merge_from_list(specified_values)
    exp_dir = os.path.join(output_directory, 'task{}'.format(task_id))
    os.makedirs(exp_dir, exist_ok=True)
    cfg.defrost()
    cfg.misc.exp_dir = exp_dir
    cfg.freeze()
    logger = logging.getLogger(__name__)
    set_logger_verbosity(verbosity, logger)
    logging_config(folder=exp_dir, name='training', logger=logger, console=console_log,
                   level=logging.DEBUG,
                   console_level=verbosity2loglevel(verbosity))
    logger.log(10, cfg)

    # Load backbone model
    if 'roberta' in cfg.model.backbone.name:
        backbone_model_cls, backbone_cfg, tokenizer, backbone_params_path, _ \
            = get_backbone(cfg.model.backbone.name)
        text_backbone = backbone_model_cls.from_cfg(backbone_cfg, return_all_hiddens=True)
    else:
        backbone_model_cls, backbone_cfg, tokenizer, backbone_params_path, _ \
            = get_backbone(cfg.model.backbone.name, load_backbone=not continue_training)
        text_backbone = backbone_model_cls.from_cfg(backbone_cfg)

    # Build the preprocessor and preprocess the training dataset
    # TODO Dynamically cache the preprocessor that has been fitted.
    if continue_training:
        with open(preprocessor_path, 'rb') as in_f:
            preprocessor = pickle.load(in_f)
        train_dataset = preprocessor.transform(train_data[feature_columns],
                                               train_data[label_column])
        label_generator = preprocessor._label_generator
    else:
        if problem_type == MULTICLASS or problem_type == BINARY:
            label_generator = LabelEncoder()
            label_generator.fit(pd.concat([train_data[label_column], tuning_data[label_column]]))
        else:
            label_generator = None
        preprocessor = MultiModalTextFeatureProcessor(column_types=column_types,
                                                      label_column=label_column,
                                                      tokenizer_name=cfg.model.backbone.name,
                                                      label_generator=label_generator,
                                                      cfg=cfg.preprocessing)
        logger.info('Fitting and transforming the train data...')
        train_dataset = preprocessor.fit_transform(train_data[feature_columns],
                                                   train_data[label_column])
    with open(os.path.join(exp_dir, 'preprocessor.pkl'), 'wb') as of:
        pickle.dump(preprocessor, of)
    logger.info(f'Done! Preprocessor saved to {os.path.join(exp_dir, "preprocessor.pkl")}')
    logger.log(10, 'Train Data')
    logger.log(10, get_stats_string(preprocessor, train_dataset, is_train=True))
    logger.info('Process dev set...')
    tuning_dataset = preprocessor.transform(tuning_data[feature_columns],
                                            tuning_data[label_column])
    logger.info('Done!')
    # Auto Max Length
    if cfg.preprocessing.text.auto_max_length:
        max_length = auto_shrink_max_length(
            train_dataset,
            insert_sep=cfg.model.insert_sep,
            num_text_features=len(preprocessor.text_feature_names),
            auto_max_length_quantile=cfg.preprocessing.text.auto_max_length_quantile,
            round_to=cfg.preprocessing.text.auto_max_length_round_to,
            max_length=cfg.preprocessing.text.max_length)
    else:
        max_length = cfg.preprocessing.text.max_length
    train_stochastic_chunk = cfg.model.train_stochastic_chunk
    test_stochastic_chunk = cfg.model.test_stochastic_chunk
    inference_num_repeat = cfg.model.inference_num_repeat
    if max_length < cfg.preprocessing.text.max_length:
        inference_num_repeat = 1
    cfg.defrost()
    cfg.preprocessing.text.max_length = max_length
    cfg.model.inference_num_repeat = inference_num_repeat
    cfg.freeze()
    with open(os.path.join(exp_dir, 'cfg.yml'), 'w') as f:
        f.write(cfg.dump())
    logger.info(f'Max length for chunking text: {max_length}, '
                f'Stochastic chunk: Train-{train_stochastic_chunk}/Test-{test_stochastic_chunk}, '
                f'Test #repeat: {inference_num_repeat}.')
    cls_id, sep_id = get_cls_sep_id(tokenizer)
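    # The batchify functions chunk long text fields to max_length and collate
    # the text/categorical/numerical features into device batches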
    train_batchify_fn = MultiModalTextBatchify(
        num_text_inputs=len(preprocessor.text_feature_names),
        num_categorical_inputs=len(preprocessor.categorical_feature_names),
        num_numerical_inputs=len(preprocessor.numerical_feature_names) > 0,
        cls_token_id=cls_id, sep_token_id=sep_id, max_length=max_length,
        mode='train', stochastic_chunk=train_stochastic_chunk,
        insert_sep=cfg.model.insert_sep)
    test_batchify_fn = MultiModalTextBatchify(
        num_text_inputs=len(preprocessor.text_feature_names),
        num_categorical_inputs=len(preprocessor.categorical_feature_names),
        num_numerical_inputs=len(preprocessor.numerical_feature_names) > 0,
        cls_token_id=cls_id, sep_token_id=sep_id, max_length=max_length,
        mode='test', stochastic_chunk=test_stochastic_chunk,
        insert_sep=cfg.model.insert_sep)

    # Get the ground-truth dev labels
    gt_dev_labels = np.array([ele[-1] for ele in tuning_dataset])
    if problem_type == REGRESSION:
        gt_dev_labels = preprocessor.label_scaler.inverse_transform(np.expand_dims(gt_dev_labels,
                                                                                   axis=-1))[:, 0]
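    # Select the compute contexts: fall back to CPU when no GPUs are requested,
    # otherwise use the first ngpus_per_trial available devices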
    ctx_l = get_mxnet_available_ctx()
    if ngpus_per_trial == 0:
        ctx_l = [mx.cpu()]
    else:
        ctx_l = ctx_l[:ngpus_per_trial]
    base_batch_size = cfg.optimization.per_device_batch_size
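    # Number of gradient-accumulation steps needed to reach the global batch size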
    num_accumulated = int(np.ceil(cfg.optimization.batch_size / (base_batch_size * len(ctx_l))))
    inference_base_batch_size = base_batch_size * cfg.optimization.val_batch_size_mult
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=base_batch_size,
                                  shuffle=True,
                                  batchify_fn=train_batchify_fn)
    dev_dataloader = DataLoader(tuning_dataset,
                                batch_size=inference_base_batch_size,
                                shuffle=False,
                                batchify_fn=test_batchify_fn)
    if problem_type == REGRESSION:
        out_shape = 1
    elif problem_type == MULTICLASS:
        out_shape = len(label_generator.classes_)
    elif problem_type == BINARY:
        assert len(label_generator.classes_) == 2
        out_shape = 2
    else:
        raise NotImplementedError
    net = MultiModalWithPretrainedTextNN(
        text_backbone=text_backbone,
        num_text_features=1,
        num_categorical_features=len(preprocessor.categorical_feature_names),
        num_numerical_features=len(preprocessor.numerical_feature_names) > 0,
        numerical_input_units=None if len(preprocessor.numerical_feature_names) == 0 else len(
            preprocessor.numerical_feature_names),
        num_categories=preprocessor.categorical_num_categories,
        get_embedding=False,
        cfg=cfg.model.network,
        out_shape=out_shape)
    if continue_training:
        net.load_parameters(params_path, ctx=ctx_l)
    else:
        net.initialize_with_pretrained_backbone(backbone_params_path, ctx=ctx_l)
    net.hybridize()
    num_total_params, num_total_fixed_params = count_parameters(net.collect_params())
    logger.info('#Total Params/Fixed Params={}/{}'.format(num_total_params,
                                                          num_total_fixed_params))
    # Initialize the optimizer
    updates_per_epoch = int(np.ceil(len(train_dataloader) / (num_accumulated * len(ctx_l))))
    optimizer, optimizer_params, max_update \
        = get_optimizer(cfg.optimization,
                        updates_per_epoch=updates_per_epoch)
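    # Convert the fractional per-epoch validation/log frequencies into update counts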
    valid_interval = int(math.ceil(cfg.optimization.valid_frequency * updates_per_epoch))
    train_log_interval = int(math.ceil(cfg.optimization.log_frequency * updates_per_epoch))

    if 0 < cfg.optimization.layerwise_lr_decay < 1:
        apply_layerwise_decay(net.text_backbone,
                              cfg.optimization.layerwise_lr_decay,
                              backbone_name=cfg.model.backbone.name)
    freeze_layers(net.text_backbone,
                  backbone_name=cfg.model.backbone.name,
                  num_trainable_layers=cfg.model.num_trainable_layers)

    # Do not apply weight decay to LayerNorm parameters and biases
    for _, v in net.collect_params('.*beta|.*gamma|.*bias').items():
        v.wd_mult = 0.0
    params = [p for p in net.collect_params().values() if p.grad_req != 'null']
    trainer = mx.gluon.Trainer(params,
                               optimizer, optimizer_params,
                               update_on_kvstore=False)
    # Set grad_req if gradient accumulation is required
    if num_accumulated > 1:
        logger.log(15, 'Using gradient accumulation.'
                   ' Global batch size = {}'.format(cfg.optimization.batch_size))
        for p in params:
            p.grad_req = 'add'
        net.collect_params().zero_grad()
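    # Cycle over the dataloader indefinitely, grouping one batch per device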
    train_loop_dataloader = grouper(repeat(train_dataloader), len(ctx_l))
    log_loss_l = [mx.np.array(0.0, dtype=np.float32, ctx=ctx) for ctx in ctx_l]
    log_num_samples_l = [0 for _ in ctx_l]
    logging_start_tick = time.time()
    nbest = cfg.optimization.nbest
    best_performance_score = []   # Stores the best performing checkpoints
    best_performance_update_idx = []   # Stores the update index that reached the best validation performance
    best_score = None
    mx.npx.waitall()
    no_better_rounds = 0
    report_idx = 0
    start_tick = time.time()
    if time_limit is not None:
        time_limit -= start_tick - time_start
        if time_limit <= 0:
            if not is_fake_reporter:
                reporter.terminate()
            return
    best_report_items = None
    report_local_jsonl_f = open(os.path.join(exp_dir, 'results_local.jsonl'), 'w')
    logger.info(f'Local training results will be saved to '
                f'{os.path.join(exp_dir, "results_local.jsonl")}.')
    for update_idx in range(max_update):
        for accum_idx in range(num_accumulated):
            sample_l = next(train_loop_dataloader)
            loss_l = []
            for i, (sample, ctx) in enumerate(zip(sample_l, ctx_l)):
                feature_batch, label_batch = sample
                feature_batch = move_to_ctx(feature_batch, ctx)
                label_batch = move_to_ctx(label_batch, ctx)
                with mx.autograd.record():
                    pred = net(feature_batch)
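                    # Negative log-likelihood for classification: log_softmax + pick
                    # selects the log-probability of the ground-truth class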
                    if problem_type == MULTICLASS or problem_type == BINARY:
                        logits = mx.npx.log_softmax(pred, axis=-1)
                        loss = - mx.npx.pick(logits,
                                             mx.np.expand_dims(label_batch, axis=-1))
                    elif problem_type == REGRESSION:
                        loss = mx.np.square(pred - mx.np.expand_dims(label_batch, axis=-1))
                    loss_l.append(loss.mean() / len(ctx_l) / num_accumulated)
                log_loss_l[i] += loss_l[i] * len(ctx_l) * loss.shape[0] * num_accumulated
                log_num_samples_l[i] += loss.shape[0]
            for loss in loss_l:
                loss.backward()
        # Begin to update
        trainer.allreduce_grads()
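        # Rescale gradients so their global L2 norm does not exceed max_grad_norm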
        total_norm, ratio, is_finite = clip_grad_global_norm(params, cfg.optimization.max_grad_norm)
        if not cfg.model._disable_update:
            trainer.update(1.0, ignore_stale_grad=True)

        # Clear after update
        if num_accumulated > 1:
            net.collect_params().zero_grad()
        if (update_idx + 1) % train_log_interval == 0:
            log_loss = sum([ele.as_in_ctx(ctx_l[0]) for ele in log_loss_l]).asnumpy()
            log_num_samples = sum(log_num_samples_l)
            logger.log(15,
                       '[Iter {}/{}, Epoch {}] train loss={:0.2e}, gnorm={:0.2e}, lr={:0.2e}, #samples processed={},'
                       ' #samples per second={:.2f}. ETA={:.2f}min'
                       .format(update_idx + 1, max_update,
                               int(update_idx / updates_per_epoch),
                               log_loss / log_num_samples, total_norm, trainer.learning_rate,
                               log_num_samples,
                               log_num_samples / (time.time() - logging_start_tick),
                               (time.time() - start_tick) / (update_idx + 1)
                               * (max_update - update_idx - 1) / 60))
            logging_start_tick = time.time()
            log_loss_l = [mx.np.array(0.0, dtype=np.float32, ctx=ctx) for ctx in ctx_l]
            log_num_samples_l = [0 for _ in ctx_l]
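        # Periodically evaluate on the tuning set (and always at the final update)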
        if (update_idx + 1) % valid_interval == 0 or (update_idx + 1) == max_update:
            valid_start_tick = time.time()
            dev_predictions = \
                _classification_regression_predict(net,
                                                   dataloader=dev_dataloader,
                                                   problem_type=problem_type,
                                                   label_scaler=preprocessor.label_scaler,
                                                   has_label=False,
                                                   num_repeat=inference_num_repeat)
            log_scores = [calculate_metric(scorer, gt_dev_labels,
                                           dev_predictions,
                                           problem_type)
                          for scorer in log_metric_scorers]
            dev_score = calculate_metric(eval_metric_scorer, gt_dev_labels,
                                         dev_predictions,
                                         problem_type)
            valid_time_spent = time.time() - valid_start_tick
            find_better = False
            find_topn_better = False
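            # Keep the nbest best-scoring checkpoints; until nbest checkpoints
            # exist, every validation result earns a slot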
            if len(best_performance_score) < nbest:
                best_performance_score.append(dev_score)
                best_performance_update_idx.append(update_idx + 1)
                net.save_parameters(
                    os.path.join(exp_dir,
                                 f'nbest_model{len(best_performance_score) - 1}.params'))
                find_topn_better = True
                if best_score is None \
                        or (greater_is_better and dev_score >= best_score) \
                        or (not greater_is_better and dev_score <= best_score):
                    find_better = True
                    net.save_parameters(os.path.join(exp_dir, 'best_model.params'))
                    best_score = dev_score
            else:
                # First try to update the top-K
                if greater_is_better:
                    if dev_score >= min(best_performance_score):
                        find_topn_better = True
                        replace_idx = np.argmin(best_performance_score)
                        best_performance_score[replace_idx] = dev_score
                        best_performance_update_idx[replace_idx] = update_idx + 1
                        net.save_parameters(
                            os.path.join(exp_dir, f'nbest_model{replace_idx}.params'))
                        if dev_score >= best_score:
                            find_better = True
                            net.save_parameters(os.path.join(exp_dir, 'best_model.params'))
                            best_score = dev_score

                else:
                    if dev_score <= max(best_performance_score):
                        find_topn_better = True
                        replace_idx = np.argmax(best_performance_score)
                        best_performance_score[replace_idx] = dev_score
                        best_performance_update_idx[replace_idx] = update_idx + 1
                        net.save_parameters(
                            os.path.join(exp_dir, f'nbest_model{replace_idx}.params'))
                        if dev_score <= best_score:
                            find_better = True
                            net.save_parameters(os.path.join(exp_dir, 'best_model.params'))
                            best_score = dev_score
            if not find_better:
                no_better_rounds += 1
            else:
                no_better_rounds = 0
            mx.npx.waitall()
            loss_string = ', '.join(['{}={:0.4e}'.format(metric.name, score)
                                     for score, metric in zip(log_scores, log_metric_scorers)])
            logger.log(25, '[Iter {}/{}, Epoch {}] Validation {}, Time computing validation-score={:.3f}s,'
                       ' Total time spent={:.2f}min. Found improved model={}, Improved top-{} models={}'.format(
                           update_idx + 1, max_update, int(update_idx / updates_per_epoch),
                           loss_string, valid_time_spent, (time.time() - start_tick) / 60,
                           find_better, nbest, find_topn_better))
            if reporter is not None:
                report_items = [('iteration', update_idx + 1),
                                ('report_idx', report_idx + 1),
                                ('epoch', int(update_idx / updates_per_epoch))] + \
                    [(metric.name, score)
                     for score, metric in zip(log_scores, log_metric_scorers)] + \
                    [('find_better', find_better),
                     ('find_new_topn', find_topn_better),
                     ('nbest_stat', json.dumps([best_performance_score,
                                                best_performance_update_idx])),
                     ('elapsed_time', int(time.time() - start_tick))]
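                # The scheduler maximizes reward_attr, so flip the sign for
                # metrics where lower is better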
                if eval_metric_scorer._sign < 0:
                    report_items.append(('reward_attr', -dev_score))
                else:
                    report_items.append(('reward_attr', dev_score))
                report_items.append(('eval_metric', eval_metric_scorer.name))
                report_items.append(('exp_dir', exp_dir))
                if find_better:
                    best_report_items = report_items
                reporter(**dict(report_items))
                report_local_jsonl_f.write(json.dumps(dict(report_items)) + '\n')
                report_local_jsonl_f.flush()
                report_idx += 1
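            # Stop when validation has not improved for early_stopping_patience
            # consecutive evaluations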
            if no_better_rounds >= cfg.optimization.early_stopping_patience:
                logger.info('Early stopping patience reached!')
                break
            total_time_spent = time.time() - start_tick
            if time_limit is not None and total_time_spent > time_limit:
                break
    # Report the best validation result once more as the final summary entry
    best_report_items_dict = dict(best_report_items)
    best_report_items_dict['report_idx'] = report_idx + 1
    reporter(**best_report_items_dict)
    report_local_jsonl_f.write(json.dumps(best_report_items_dict) + '\n')
    report_local_jsonl_f.close()