def main()

in scripts/adapet/ADAPET/setfit_adapet.py [0:0]
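
This excerpt omits the module-level imports. The block below is a plausible reconstruction of the external-library imports implied by the calls in main() (not copied from the source file); the remaining names — AMZ_MULTI_LING, SAMPLE_SIZES, multiling_verb_pattern, fix_amzn, fix_train_amzn, get_max_num_lbl_tok, create_fewshot_splits, json_file_setup, call_adapet, jsonl_from_dataset, do_test, write_seed_output, and multilingual_en — are assumed to be helpers and constants defined elsewhere in setfit_adapet.py or the surrounding ADAPET scripts.

# Plausible module-level imports for the library calls used in main();
# reconstructed, not copied from the source file.
from datasets import load_dataset, concatenate_datasets
from sklearn.metrics import (
    accuracy_score,
    average_precision_score,
    matthews_corrcoef,
    mean_absolute_error,
)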


def main(parser):
    args = parser.parse_args()
    english = args.english
    prompt = args.prompt  
    task_name = args.task_name
    pretrained_weight = args.pretrained_weight
    adapet_seed = args.seed
    multilingual = args.multilingual
    print("starting work on {}".format(task_name))        
    
    if task_name in AMZ_MULTI_LING:
        print('loading multilingual dataset')
        dataset = load_dataset(task_name)
        lang_star_dict, lang_pattern = multiling_verb_pattern(task_name, english, prompt)
        dataset = fix_amzn(dataset, lang_star_dict)
        if multilingual == 'all':
            dsets = []
            for task in AMZ_MULTI_LING:
                ds = load_dataset(task, split="train")
                dsets.append(ds)
            # Create training set and sample for fewshot splits
            train_ds = concatenate_datasets(dsets).shuffle(seed=42)
            train_ds = fix_train_amzn(train_ds, lang_star_dict)
        else:
            train_ds = dataset['train']
        
        test_ds = dataset['test']
    
    else:
        print('loading single sentence dataset')      
        train_ds = load_dataset(task_name, split="train")
        test_ds = load_dataset(task_name, split="test")  
    
    # determine the maximum number of tokens in the label text
    if task_name in AMZ_MULTI_LING:
        max_tokens = get_max_num_lbl_tok(task_name, train_ds, pretrained_weight, lang_star_dict)
    else: 
        max_tokens = get_max_num_lbl_tok(task_name, train_ds, pretrained_weight, lang_star_dict=None)

    num_labs = len(set(train_ds["label"]))
    
    if task_name not in AMZ_MULTI_LING:
        lang_pattern = None
    
    for sample_size in SAMPLE_SIZES:
        print("begun work on {} sample size : {}".format(task_name, sample_size))
        fewshot_ds = create_fewshot_splits(sample_size, train_ds)
        for ds_seed, (_split_name, current_split_ds) in enumerate(fewshot_ds.items()):
        
            updated_args = json_file_setup(task_name, current_split_ds, lang_pattern, max_tokens, parser)

            # call ADAPET
            exp_dir = call_adapet(updated_args)
            
            # rewrite the existing "test" dataset with the true test data
            jsonl_from_dataset(test_ds, task_name, updated_args, "test")
            
            pred_labels, pred_logits = do_test(exp_dir)
            
            y_true = test_ds["label"]

            # choose the evaluation metric based on the task
            if task_name in ['SetFit/toxic_conversations']:
                # average precision is computed from the positive-class score,
                # taken from the second column of the logits
                if len(pred_logits.shape) == 2:
                    y_pred = pred_logits[:, 1]
                    logit_ap = average_precision_score(y_true, y_pred)*100
                    write_seed_output(pretrained_weight, task_name, sample_size, ds_seed, logit_ap, english, prompt)
            
            elif task_name in ['SetFit/amazon_counterfactual_en']:
                mcc = matthews_corrcoef(y_true, pred_labels)*100
                write_seed_output(pretrained_weight, task_name, sample_size, ds_seed, mcc, english, prompt)
            
            elif task_name in AMZ_MULTI_LING and multilingual in ['each', 'all']:
                mae = mean_absolute_error(y_true, pred_labels)*100
                write_seed_output(pretrained_weight+'_'+multilingual, task_name, sample_size, ds_seed, mae, english, prompt)
            
            elif task_name == 'SetFit/amazon_reviews_multi_en' and multilingual == 'en':
                multilingual_en(exp_dir, updated_args, pretrained_weight, sample_size, ds_seed)
            
            else:
                acc = accuracy_score(y_true, pred_labels)*100
                write_seed_output(pretrained_weight, task_name, sample_size, ds_seed, acc, english, prompt)
                    
        print("no more work on {} sample size : {}".format(task_name, sample_size))
    print("finished work on {}".format(task_name))
    print()
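
As shown above, main() reads english, prompt, task_name, pretrained_weight, seed, and multilingual from the parsed arguments. A minimal, hypothetical entry point compatible with that usage is sketched below; the flag names mirror the attributes accessed in main(), but the types, defaults, and required flags are assumptions, and the real script's parser (including any ADAPET-specific options) may differ.

if __name__ == "__main__":
    import argparse

    # Hypothetical parser sketch: only the arguments main() reads are shown;
    # types and defaults are assumptions, not taken from the source script.
    parser = argparse.ArgumentParser()
    parser.add_argument("--task_name", type=str, required=True)
    parser.add_argument("--pretrained_weight", type=str, required=True)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--multilingual", type=str, default="each", choices=["each", "all", "en"])
    parser.add_argument("--english", action="store_true")
    parser.add_argument("--prompt", action="store_true")
    main(parser)

Note that main() receives the parser object itself rather than parsed arguments: it calls parser.parse_args() internally and also forwards the parser to json_file_setup, presumably so ADAPET-specific arguments can be layered on top before parsing again.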