def process_test()

in preprocess/unifiedqa.py [0:0]


import json
import os
import string

import numpy as np

# `prefix` (the task-name prefix) is defined at module level in preprocess/unifiedqa.py.
def process_test(data_dir, test_tasks, k, seeds):
    def _get_sentences(input_, options):
        # Split `input_` on each option marker ("(A)", "(B)", ...) and return the
        # non-empty pieces, i.e. the option texts with the markers stripped.
        sentences = []
        text = input_
        for option in options:
            if option not in text:
                break
            text1, text = text.split(option, 1)
            sentences.append(text1)
        sentences.append(text)
        return [s.strip() for s in sentences if len(s.strip()) > 0]
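    # Illustration (hypothetical input, not from the original file):
    #   _get_sentences("(A) red (B) blue", ["(A)", "(B)", "(C)"]) -> ["red", "blue"]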

    def _prepro(task, split, line):
        # Convert one raw TSV line into a jsonl record; `split` is unused but kept
        # to match the callers below.
        line = line.strip()
        try:
            # Each UnifiedQA line is "input<TAB>output".
            input_, output = line.split("\t")
        except Exception:
            print(line)
            exit()
        if line.count("\\n") == 2:
            # The input is "question \n options \n context" (with literal "\n"
            # separators); reorder it so the options segment comes last.
            input_, options, context = input_.split("\\n")
            input_ = input_ + "\\n  " + context + " \\n " + options

        if input_.split("\\n")[-1].strip().startswith("(A)"):
            # Multiple-choice: split the final segment on "(A)", "(B)", ... to
            # recover the option texts; the gold output must be one of them.
            alphabet_options = ["(" + c + ")" for c in string.ascii_uppercase]
            option_text = input_.split("\\n")[-1].strip()
            options = _get_sentences(option_text, alphabet_options)
            assert output in options
            input_ = " ".join(input_.split("\\n")[:-1])
        elif output in ["yes", "no"]:
            # Yes/no tasks carry no explicit option list.
            options = []
        else:
            raise NotImplementedError()
        return json.dumps({"task": prefix+task, "input": input_, "options": options, "output": output})
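    # Each _prepro call returns one jsonl line of the form (illustrative values only):
    # {"task": "<prefix><task>", "input": "<question and context>",
    #  "options": ["red", "blue"], "output": "blue"}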

    for task in test_tasks:
        with open(os.path.join(data_dir, prefix+task, "train.tsv"), "r") as f:
            data = [_prepro(task, "train", line) for line in f if len(line.strip())>0]
        # For mctest, multirc, qasc, and qasc_with_ir, the dev split serves as the test set.
        with open(os.path.join(data_dir, prefix+task, "dev.tsv" if task in ["mctest", "multirc", "qasc", "qasc_with_ir"] else "test.tsv"), "r") as f:
            test_data = [_prepro(task, "test", line) for line in f if len(line.strip())>0]

        # Whitespace token counts per training example (only used by the optional
        # statistics print below, which is left commented out).
        n_lengths = [len(dp.split(" ")) for dp in data]
        #print (task, len(n_lengths), "%.1f %.1f" % (np.mean(n_lengths), np.quantile(n_lengths, 0.90)))

        for seed in seeds:
            # Sample k training examples with a fixed seed so the k-shot split is
            # reproducible, then write the k-shot train set and the full test set.
            np.random.seed(seed)
            train_data = [data[i] for i in np.random.permutation(range(len(data)))[:k]]
            with open(os.path.join(data_dir, prefix+task, "{}_{}_{}_train.jsonl".format(prefix+task, k, seed)), "w") as f:
                for line in train_data:
                    f.write(line+"\n")
            with open(os.path.join(data_dir, prefix+task, "{}_{}_{}_test.jsonl".format(prefix+task, k, seed)), "w") as f:
                for line in test_data:
                    f.write(line+"\n")
        print("Finish saving %s\t#=%d" % (task, k))