in scripts/adapet/ADAPET/setfit_adapet.py [0:0]
def main(parser):
    """Run the ADAPET few-shot benchmark for a single SetFit task.

    Parses CLI arguments from *parser*, loads the task's train/test splits
    (with special multilingual handling for the Amazon review tasks), then
    for every configured sample size trains ADAPET on each few-shot split
    and writes the task-appropriate metric per seed.

    Args:
        parser: a configured ``argparse.ArgumentParser``; expected options are
            ``english``, ``prompt``, ``task_name``, ``pretrained_weight``,
            ``seed`` and ``multilingual``.
    """
    args = parser.parse_args()
    english = args.english
    prompt = args.prompt
    task_name = args.task_name
    pretrained_weight = args.pretrained_weight
    adapet_seed = args.seed  # NOTE(review): parsed but never used in this function — confirm intent
    multilingual = args.multilingual
    print("starting work on {}".format(task_name))

    if task_name in AMZ_MULTI_LING:
        print('loading multilingual dataset')
        dataset = load_dataset(task_name)
        # Language-specific verbalizer (star labels) and PET pattern for this task.
        lang_star_dict, lang_pattern = multiling_verb_pattern(task_name, english, prompt)
        dataset = fix_amzn(dataset, lang_star_dict)
        if multilingual == 'all':
            # Train on the union of every Amazon-multilingual task's train split.
            dsets = []
            for task in AMZ_MULTI_LING:
                ds = load_dataset(task, split="train")
                dsets.append(ds)
            # Create training set and sample for fewshot splits
            train_ds = concatenate_datasets(dsets).shuffle(seed=42)
            train_ds = fix_train_amzn(train_ds, lang_star_dict)
        else:
            train_ds = dataset['train']
        # Evaluation always uses this task's own test split, regardless of
        # whether training pooled all languages.
        test_ds = dataset['test']
    else:
        print('loading single sentence dataset')
        train_ds = load_dataset(task_name, split="train")
        test_ds = load_dataset(task_name, split="test")

    # Determine the maximum number of tokens in the label text; multilingual
    # tasks need the per-language verbalizer to measure their labels.
    if task_name in AMZ_MULTI_LING:
        max_tokens = get_max_num_lbl_tok(task_name, train_ds, pretrained_weight, lang_star_dict)
    else:
        max_tokens = get_max_num_lbl_tok(task_name, train_ds, pretrained_weight, lang_star_dict=None)
    num_labs = len(set(train_ds["label"]))  # NOTE(review): computed but unused below — confirm intent
    if task_name not in AMZ_MULTI_LING:
        # Non-multilingual tasks have no language-specific pattern.
        lang_pattern = None

    for sample_size in SAMPLE_SIZES:
        print("begun work on {} sample size : {}".format(task_name, sample_size))
        fewshot_ds = create_fewshot_splits(sample_size, train_ds)
        for ds_seed, ds in enumerate(fewshot_ds):
            current_split_ds = fewshot_ds[ds]
            updated_args = json_file_setup(task_name, current_split_ds, lang_pattern, max_tokens, parser)
            # call ADAPET
            exp_dir = call_adapet(updated_args)
            # rewrite the existing "test" dataset with the true test data
            jsonl_from_dataset(test_ds, task_name, updated_args, "test")
            pred_labels, pred_logits = do_test(exp_dir)
            y_true = test_ds["label"]

            # Each task family reports a different metric (all scaled to 0-100).
            if task_name in ['SetFit/toxic_conversations']:
                # Average precision over the positive-class logits.
                if len(pred_logits.shape) == 2:
                    y_pred = pred_logits[:, 1]
                # NOTE(review): if pred_logits is not 2-D, y_pred is unbound
                # here and the next line raises NameError — confirm logits
                # shape is always (n_samples, 2) for this task.
                logit_ap = average_precision_score(y_true, y_pred)*100
                write_seed_output(pretrained_weight, task_name, sample_size, ds_seed, logit_ap, english, prompt)
            elif task_name in ['SetFit/amazon_counterfactual_en']:
                mcc = matthews_corrcoef(y_true, pred_labels)*100
                write_seed_output(pretrained_weight, task_name, sample_size, ds_seed, mcc, english, prompt)
            elif task_name in AMZ_MULTI_LING and multilingual in ['each', 'all']:
                mae = mean_absolute_error(y_true, pred_labels)*100
                write_seed_output(pretrained_weight+'_'+multilingual, task_name, sample_size, ds_seed, mae, english, prompt)
            elif task_name == 'SetFit/amazon_reviews_multi_en' and multilingual == 'en':
                # English-trained model evaluated across all languages.
                multilingual_en(exp_dir, updated_args, pretrained_weight, sample_size, ds_seed)
            else:
                acc = accuracy_score(y_true, pred_labels)*100
                write_seed_output(pretrained_weight, task_name, sample_size, ds_seed, acc, english, prompt)
        print("no more work on {} sample size : {}".format(task_name, sample_size))
    print("finished work on {}".format(task_name))
    print()