def main()

in main.py [0:0]


def main(args):
    """Train and evaluate the transformer (encoder) or the downstream classifiers.

    The run is first configured from ``args`` (config update and seeding) and
    the datasets are built. Then one of two paths is taken:

    * ``args.is_precompute_trnsf`` truthy: the transformer is fit/evaluated via
      ``fit_evaluate_trainer`` and its outputs are stored under the key
      ``"transformer"``.
    * otherwise: the pretrained transformer is loaded via ``fit_trainer`` (with
      ``is_load_criterion=False``), the datasets are prepared for
      classification, and every classifier yielded by ``gen_Classifiers_name``
      is fit/evaluated and stored under its own name.

    Args:
        args: run configuration namespace. Presumably updated in place by the
            trailing-underscore helpers ``update_config_`` /
            ``update_config_datasets_`` -- confirm against their definitions.

    Returns:
        A ``(trainers, datasets)`` pair of dicts keyed by run name when
        ``args.is_return`` is truthy, otherwise ``None``.
    """
    trainers_out, datasets_out = {}, {}

    # Configuration and reproducibility.
    update_config_(args)
    set_seed(args.seed)

    # Data: raw datasets, then the variant consumed by the transformer.
    raw_datasets = get_datasets(args)
    trnsf_datasets = prepare_transformer_datasets(args, raw_datasets)
    update_config_datasets_(args, trnsf_datasets)

    # Transformer (i.e. encoder) for this configuration.
    Transformer = get_Transformer(args, trnsf_datasets)

    if args.is_precompute_trnsf:
        run_key = "transformer"
        # NOTE(review): the positional True here (False for classifiers below)
        # is interpreted by fit_evaluate_trainer; confirm its meaning there.
        trainers_out[run_key] = fit_evaluate_trainer(
            Transformer, args, run_key, trnsf_datasets, True
        )
        datasets_out[run_key] = prepare_return_datasets(trnsf_datasets)
    else:
        # Load the pretrained transformer.
        pretrained = fit_trainer(
            Transformer,
            args,
            trnsf_datasets,
            True,
            "transformer",
            is_load_criterion=False,
        )

        clf_datasets = prepare_classification_datasets_(args, raw_datasets)
        classifiers = gen_Classifiers_name(args, pretrained, clf_datasets)

        for Clf, clf_key in classifiers:
            trainers_out[clf_key] = fit_evaluate_trainer(
                Clf,
                args,
                clf_key,
                clf_datasets,
                False,
                is_return_init=args.is_correlation_Bob,
            )
            datasets_out[clf_key] = prepare_return_datasets(clf_datasets)

    return (trainers_out, datasets_out) if args.is_return else None