in preprocess/utils.py [0:0]
# Imports this function relies on.
import json
from collections import defaultdict


def load_prompts(do_train):
    from promptsource.templates import TemplateCollection

    # When a dataset has more than one subtask in promptsource, pick one canonical subtask.
    subtask_dict = {
        "ai2_arc": "ARC-Challenge",
        "codah": "fold_0",
        "hotpot_qa": "distractor",
        "openbookqa": "main",
        "paws": "labeled_final",
        "scitail": "snli_format",
    }

    # Load the train/test task split and keep only the side we need.
    with open("../config/hr_to_lr_noinst.json", "r") as f:
        config = json.load(f)
    train_tasks = set(config["train"])
    test_tasks = set(config["test"])
    if do_train:
        test_tasks = set()
    else:
        train_tasks = set()

    collection = TemplateCollection()
    available_tasks = defaultdict(list)
    prompt_names_per_task = {}
    prompt_dict = {}

    # Collect the subtasks that promptsource provides for each task in the split.
    for task, subtask in collection.keys:
        if task in train_tasks or task in test_tasks:
            available_tasks[task].append(subtask)

    # Reduce each task to exactly one subtask.
    for task, subtasks in available_tasks.items():
        if len(subtasks) > 1:
            subtasks = [subtask_dict[task]]
        assert len(subtasks) == 1, (task, subtasks)
        available_tasks[task] = subtasks[0]

    def normalize_name(name):
        return name.replace(" ", "-").replace("/", "-").replace("_", "-")

    for task, subtask in available_tasks.items():
        prompts = collection.get_dataset(task, subtask)
        if do_train:
            # Training: keep every template, except known-broken ones.
            prompt_names_per_task[task] = []
            for name in prompts.all_template_names:
                if task == "circa" and name in ["possible_qn", "question_declarative"]:
                    # These templates always give empty output for some reason.
                    print("Skipping", task, name)
                    continue
                prompt_names_per_task[task].append(normalize_name(name))
                prompt_dict[task + ":" + normalize_name(name)] = prompts[name]
        else:
            # Evaluation: choose a single template per task, skipping "no_option"
            # variants and preferring multiple-choice style templates.
            all_template_names = [name for name in prompts.all_template_names if "no_option" not in name]
            if task == "dream":
                all_template_names = ["read_the_following_conversation_and_answer_the_question"]
            for keyword in ["multiple_choice", "most_correct", "most_suitable"]:
                _all_template_names = [name for name in all_template_names if keyword in name]
                if len(_all_template_names) > 0:
                    all_template_names = _all_template_names
            if len(all_template_names) < 1:
                continue
            prompt = prompts[all_template_names[0]]
            prompt_names_per_task[task] = [all_template_names[0]]
            prompt_dict[task] = prompt

    # Sanity check: the loaded tasks must match the instruction config
    # (entries in hr_to_lr_inst_all.json carry a 5-character prefix that is stripped here).
    with open("../config/hr_to_lr_inst_all.json", "r") as f:
        config = json.load(f)
    datasets = [t[5:] for t in config["train" if do_train else "test"]]
    assert set(datasets) == set(prompt_dict.keys()), (len(datasets), len(prompt_dict))

    return prompt_names_per_task, prompt_dict
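
A minimal usage sketch, not part of the original file: it assumes promptsource is installed and that the process runs from a directory where the relative ../config/*.json paths above resolve.

if __name__ == "__main__":
    # Load one evaluation template per test task and print what was selected.
    prompt_names_per_task, prompt_dict = load_prompts(do_train=False)
    for task, names in prompt_names_per_task.items():
        print(task, "->", names)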