scripts/adapet/ADAPET/utilcode.py (176 lines of code) (raw):
import os
import json
import argparse
GLUE_DATASETS = ['SetFit/stsb', 'SetFit/mnli_mm', 'SetFit/mnli', 'SetFit/wnli',
'SetFit/qnli', 'SetFit/mrpc', 'SetFit/rte', 'SetFit/qqp']
AMZ_MULTI_LING = ['SetFit/amazon_reviews_multi_ja','SetFit/amazon_reviews_multi_zh', 'SetFit/amazon_reviews_multi_de',
'SetFit/amazon_reviews_multi_fr', 'SetFit/amazon_reviews_multi_es', 'SetFit/amazon_reviews_multi_en']
INTENT_MULTI_LING = ['SetFit/amazon_massive_intent_ar-SA','SetFit/amazon_massive_intent_es-ES', 'SetFit/amazon_massive_intent_de-DE',
'SetFit/amazon_massive_intent_ja-JP', 'SetFit/amazon_massive_intent_zh-CN', 'SetFit/amazon_massive_intent_ru-RU']
SINGLE_SENT_DATASETS = ['SetFit/sst2', 'SetFit/sst5', 'SetFit/imdb', 'SetFit/subj', 'SetFit/ag_news', 'SetFit/bbc-news',
'SetFit/enron_spam', 'SetFit/student-question-categories', 'SetFit/TREC-QC', 'SetFit/toxic_conversations',
'SetFit/amazon_counterfactual_en', 'SetFit/CR', 'SetFit/SentEval-CR', 'SetFit/emotion', 'SetFit/amazon_polarity', 'SetFit/ade_corpus_v2_classification']
def str2bool(v):
if isinstance(v, bool):
return v
if v.lower() in ('yes', 'true', 't', 'y', '1'):
return True
elif v.lower() in ('no', 'false', 'f', 'n', '0'):
return False
else:
raise argparse.ArgumentTypeError('Boolean value expected.')
def write_seed_output(pretrained_weight, task_name, sample_size, ds_seed, metric, english, prompt):
dataset = task_name[7:] #remove setfit/
if dataset in ["toxic_conversations"]:
json_dict = {"measure": "ap", "score": metric}
elif dataset in ["amazon_counterfactual_en"]:
json_dict = {"measure": "matthews_correlation", "score": metric}
elif 'SetFit/' + dataset in AMZ_MULTI_LING:
json_dict = {"measure": "mean_absolute_error", "score": metric}
else:
json_dict = {"measure": "acc", "score": metric}
if 'microsoft/' in pretrained_weight:
pretrained_weight = pretrained_weight.replace('microsoft/', '')
if task_name in SINGLE_SENT_DATASETS:
writefile = 'seed_output/' + pretrained_weight +'/'+ dataset +'/'+ 'train-'+str(sample_size)+'-'+str(ds_seed)+ '/'
else:
if english:
lang = 'eng'
else:
lang = 'lang'
if prompt:
prompting = 'prompt'
else:
prompting = 'no-prompt'
writefile = 'seed_output/' + pretrained_weight + "__"+lang+'_'+prompting +'/'+ dataset +'/'+ 'train-'+str(sample_size)+'-'+str(ds_seed)+ '/'
if not os.path.exists(writefile):
os.makedirs(writefile)
writefile = writefile+'results.json'
with open(writefile, "a") as f:
f.write(json.dumps(json_dict))
def fix_train_amzn(dataset, lang_star_dict):
dataset = dataset.rename_column("label_text", "str_label_text")
label_text =[lang_star_dict[i] for i in dataset['label']]
dataset = dataset.add_column('label_text', label_text)
return dataset
def fix_amzn(dataset, lang_star_dict):
dataset = dataset.rename_column("label_text", "str_label_text")
for split, dset in dataset.items():
label_text =[lang_star_dict[i] for i in dset['label']]
dset = dset.add_column('label_text', label_text)
dataset[split] = dset
return dataset
def fix_intent(task_name, dataset, english):
dataset = dataset.rename_column("label_text", "str_label_text")
if english:
for split, dset in dataset.items():
label_text = []
for txt_lab in dset["str_label_text"]:
label_text.append(txt_lab.replace("_", " "))
dset = dset.add_column('label_text', label_text)
dataset[split] = dset
lang_pattern = '[TEXT1] this is [LBL]'
else:
if task_name == 'SetFit/amazon_massive_intent_zh-CN':
lang_pattern = '[TEXT1] 这是 [LBL]'
dataset = dataset.rename_column("label_text_ch", "label_text")
elif task_name == 'SetFit/amazon_massive_intent_ru-RU':
lang_pattern = '[TEXT1] это [LBL]'
dataset = dataset.rename_column("label_text_ru", "label_text")
elif task_name == 'SetFit/amazon_massive_intent_de-DE':
lang_pattern = '[TEXT1] dies ist [LBL]'
dataset = dataset.rename_column("label_text_de", "label_text")
elif task_name == 'SetFit/amazon_massive_intent_ja-JP':
lang_pattern = '[TEXT1]これは[LBL]だ'
dataset = dataset.rename_column("label_text_jp", "label_text")
elif task_name == 'SetFit/amazon_massive_intent_es-ES':
lang_pattern = '[TEXT1] esto es [LBL]'
dataset = dataset.rename_column("label_text_es", "label_text")
return dataset, lang_pattern
def multiling_verb_pattern(task_name, english, prompt):
assert task_name in AMZ_MULTI_LING
if not english:
if task_name == 'SetFit/amazon_reviews_multi_zh':
lang_star_dict = {0: '1星', 1: '2星', 2: '3星', 3: '4星', 4: '5星'}
lang_pattern = '[TEXT1] 这是 [LBL]'
elif task_name == 'SetFit/amazon_reviews_multi_de':
lang_star_dict = {0: '1 stern', 1: '2 sterne', 2: '3 sterne', 3: '4 sterne', 4: '5 sterne'}
lang_pattern = '[TEXT1] dies ist [LBL]'
elif task_name == 'SetFit/amazon_reviews_multi_fr':
lang_star_dict = {0: '1 étoile', 1: '2 étoiles', 2: '3 étoiles', 3: '4 étoiles', 4: '5 étoiles'}
lang_pattern = '[TEXT1] est noté [LBL]'
elif task_name == 'SetFit/amazon_reviews_multi_ja':
lang_star_dict = {0: '一つ星', 1: '二つ星', 2: '三つ星', 3: '四つ星', 4: '五つ星'}
lang_pattern = '[TEXT1]これは[LBL]だ'
elif task_name == 'SetFit/amazon_reviews_multi_es':
lang_star_dict = {0: '1 estrella', 1: '2 estrellas', 2: '3 estrellas', 3: '4 estrellas', 4: '5 estrellas'}
lang_pattern = '[TEXT1] esto es [LBL]'
else:
lang_star_dict = {0: '1 star', 1: '2 stars', 2: '3 stars', 3: '4 stars', 4: '5 stars'}
lang_pattern = '[TEXT1] this is [LBL]'
if prompt:
return lang_star_dict, lang_pattern
else:
lang_pattern = '[TEXT1] [LBL]'
return lang_star_dict, lang_pattern
def fix_stsb(dataset):
dataset = dataset.rename_column("label", "float_label")
dataset = dataset.rename_column("label_text", "na_label_text")
sim_dict = {0: 'very different', 1: 'different', 2: 'dissimilar', 3: 'somewhat similar', 4: 'similar', 5: 'very similar'}
for split, dset in dataset.items():
if split == 'test':
continue
else:
label = [round(i) for i in dset['float_label']]
dset = dset.add_column("label", label)
label_text = [sim_dict[i] for i in label]
dset = dset.add_column('label_text', label_text)
dataset[split] = dset
return dataset
def write_evaluation_json(accs, mics, macs, avg_pres, logit_aps, num_labs, sample_size, task_name, configs, english, prompt):
if sample_size in ["full", 500]:
assert len(accs) == len(mics) == len(macs) == len(avg_pres) == len(logit_aps) == len(ADAPET_SEEDS)
else:
assert len(accs) == len(mics) == len(macs) == len(avg_pres) == len(logit_aps) == len(SEEDS)
round_to = 10
mean_acc = round(np.mean(accs), round_to)
acc_std = round(np.std(accs), round_to)
mean_micro = round(np.mean(mics), round_to)
micro_std = round(np.std(mics), round_to)
mean_macro = round(np.mean(macs), round_to)
macro_std = round(np.std(macs), round_to)
mean_avg_pre = round(np.mean(avg_pres), round_to)
avg_pre_std = round(np.std(avg_pres), round_to)
mean_logit_ap = round(np.mean(logit_aps), round_to)
logit_ap_std = round(np.std(logit_aps), round_to)
#in the multiclass scenario, average precision is not defined
if num_labs > 2:
mean_avg_pre = 'NA'
avg_pre_std = 'NA'
mean_logit_ap = 'NA'
logit_ap_std = 'NA'
json_dict = {
"mean_acc": mean_acc,
"acc_std": acc_std,
"mean_f1_mic": mean_micro,
"f1_mic_std": micro_std,
"mean_f1_mac": mean_macro,
"f1_mac_std": macro_std,
"mean_avg_pre": mean_avg_pre,
"avg_pre_std": avg_pre_std,
"mean_logit_ap": mean_logit_ap,
"logit_ap_std": logit_ap_std,
}
write_dir = 'results/'+ configs["pretrained_weight"] + '/' + task_name.lower()[7:]
if english:
write_dir = write_dir + '_eng'
else:
write_dir = write_dir + '_lang'
if prompt:
write_dir = write_dir + '_prompt'
else:
write_dir = write_dir + '_no_prompt'
if not os.path.exists(write_dir):
os.makedirs(write_dir)
writefile = write_dir + "/" + str(sample_size) + "_split_results.json"
with open(writefile, "w") as f:
f.write(json.dumps(json_dict) + "\n")