def load_prompts()

in preprocess/utils.py [0:0]


def load_prompts(do_train):
    import json
    from collections import defaultdict
    from promptsource.templates import TemplateCollection

    # Datasets with multiple promptsource subsets; pin each to a single subset.
    subtask_dict = {
        "ai2_arc": "ARC-Challenge",
        "codah": "fold_0",
        "hotpot_qa": "distractor",
        "openbookqa": "main",
        "paws": "labeled_final",
        "scitail": "snli_format"
    }
    # Task split for the no-instruction setting: "train" and "test" are lists of dataset names.
    with open("../config/hr_to_lr_noinst.json", "r") as f:
        config = json.load(f)
    train_tasks = set(config["train"])
    test_tasks = set(config["test"])

    # Keep only the split we are building prompts for.
    if do_train:
        test_tasks = set()
    else:
        train_tasks = set()

    collection = TemplateCollection()
    available_tasks = defaultdict(list)
    prompt_names_per_task = {}
    prompt_dict = {}

    # Collect every promptsource subset available for the tasks in our split.
    for task, subtask in collection.keys:
        if task in train_tasks or task in test_tasks:
            available_tasks[task].append(subtask)

    # Resolve each task to exactly one subset, using subtask_dict to break ties.
    for task, subtasks in available_tasks.items():
        if len(subtasks) > 1:
            subtasks = [subtask_dict[task]]
        assert len(subtasks) == 1, (task, subtasks)
        available_tasks[task] = subtasks[0]

    def normalize_name(name):
        # Make template names uniform: spaces, slashes, and underscores become hyphens.
        return name.replace(" ", "-").replace("/", "-").replace("_", "-")

    for task, subtask in available_tasks.items():
        prompts = collection.get_dataset(task, subtask)

        if do_train:
            # Training split: keep every template, keyed as "task:template-name".
            prompt_names_per_task[task] = []
            for name in prompts.all_template_names:
                if task == "circa" and name in ["possible_qn", "question_declarative"]:
                    # These two circa templates always produce an empty output, so skip them.
                    print("Skipping", task, name)
                    continue
                prompt_names_per_task[task].append(normalize_name(name))
                prompt_dict[task + ":" + normalize_name(name)] = prompts[name]
        else:
            # Evaluation split: select a single template per task.
            # Drop templates that omit the answer options.
            all_template_names = [name for name in prompts.all_template_names if "no_option" not in name]

            if task == "dream":
                all_template_names = ["read_the_following_conversation_and_answer_the_question"]

            # Prefer multiple-choice style templates when any are available.
            for keyword in ["multiple_choice", "most_correct", "most_suitable"]:
                _all_template_names = [name for name in all_template_names if keyword in name]
                if len(_all_template_names) > 0:
                    all_template_names = _all_template_names

            if len(all_template_names) < 1:
                continue

            prompt = prompts[all_template_names[0]]
            prompt_names_per_task[task] = [all_template_names[0]]
            prompt_dict[task] = prompt

    # Sanity check: the instruction config (prefixed task names) should cover
    # exactly the tasks we collected prompts for.
    with open("../config/hr_to_lr_inst_all.json", "r") as f:
        config = json.load(f)
    datasets = [t[5:] for t in config["train" if do_train else "test"]]

    assert set(datasets) == set(prompt_dict.keys()), (len(datasets), len(prompt_dict))

    return prompt_names_per_task, prompt_dict
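
A minimal usage sketch, assuming the promptsource package is installed and the script runs from a directory where the relative ../config paths resolve; the __main__ guard and variable names below are illustrative and not part of the repository:

if __name__ == "__main__":
    # Build prompts for the training split; pass do_train=False for the test split.
    prompt_names_per_task, prompt_dict = load_prompts(do_train=True)

    # prompt_names_per_task maps each task to its normalized template names;
    # prompt_dict maps "task:template-name" (train) or "task" (test) to a
    # promptsource Template, whose apply(example) renders a dataset example
    # into its input and target strings.
    for task, names in prompt_names_per_task.items():
        print(task, "->", len(names), "template(s)")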