in main.py [0:0]
def main(args):
    """Train and evaluate representations end to end, driven by ``args``.

    The configuration, seed, datasets and transformer (encoder) factory are
    always prepared first. Exactly one of two modes then runs:

    * ``args.is_precompute_trnsf`` is truthy: the transformer itself is fitted
      and evaluated, and its results are stored under the key
      ``"transformer"``.
    * otherwise: the pretrained transformer is obtained through
      ``fit_trainer`` (with ``is_load_criterion=False``). The classification
      datasets are then prepared. Every classifier yielded by
      ``gen_Classifiers_name`` is fitted and evaluated and stored under its
      own name.

    Parameters
    ----------
    args :
        Experiment configuration. It is passed to ``update_config_`` and
        ``update_config_datasets_``; the trailing underscore suggests those
        update it in place (confirm against the helpers). Attributes read
        here: ``seed``, ``is_precompute_trnsf``, ``is_correlation_Bob`` and
        ``is_return``.

    Returns
    -------
    tuple of (dict, dict) or None
        ``(trainers_return, datasets_return)``, keyed by stage/classifier
        name, when ``args.is_return`` is truthy; otherwise ``None``.
    """
    trainers_return, datasets_return = {}, {}
    trnsf_key = "transformer"

    # ARGS: finalise the configuration, then seed the run.
    update_config_(args)
    set_seed(args.seed)

    # DATASET: raw datasets, then the variant used by the transformer stage.
    datasets = get_datasets(args)
    datasets_trnsf = prepare_transformer_datasets(args, datasets)
    update_config_datasets_(args, datasets_trnsf)

    # TRANSFORMER (i.e. Encoder)
    Encoder = get_Transformer(args, datasets_trnsf)

    if not args.is_precompute_trnsf:
        # Load the pretrained transformer rather than reporting on it here.
        # NOTE(review): the positional True appears to flag the transformer
        # stage (False is passed for classifiers below) — confirm against
        # fit_trainer / fit_evaluate_trainer.
        encoder = fit_trainer(
            Encoder,
            args,
            datasets_trnsf,
            True,
            trnsf_key,
            is_load_criterion=False,
        )

        # The classification datasets replace the raw ones for this stage.
        datasets = prepare_classification_datasets_(args, datasets)
        classifiers = gen_Classifiers_name(args, encoder, datasets)
        for Clf, clf_key in classifiers:
            trainers_return[clf_key] = fit_evaluate_trainer(
                Clf,
                args,
                clf_key,
                datasets,
                False,
                is_return_init=args.is_correlation_Bob,
            )
            datasets_return[clf_key] = prepare_return_datasets(datasets)
    else:
        trainers_return[trnsf_key] = fit_evaluate_trainer(
            Encoder, args, trnsf_key, datasets_trnsf, True
        )
        datasets_return[trnsf_key] = prepare_return_datasets(datasets_trnsf)

    # Results are only handed back on request; otherwise fall through to None.
    return (trainers_return, datasets_return) if args.is_return else None