in preprocess/unifiedqa.py [0:0]
def process_test(data_dir, test_tasks, k, seeds):
    """Build k-shot train / evaluation JSONL files for UnifiedQA tasks.

    For every task in ``test_tasks`` this reads, from the directory
    ``<data_dir>/<prefix><task>``:
      * ``train.tsv`` -- the pool of training examples;
      * ``dev.tsv`` for mctest, multirc, qasc and qasc_with_ir, ``test.tsv``
        for every other task -- the evaluation examples.
    Each non-blank ``input<TAB>output`` line is converted to a JSON record
    ``{"task", "input", "options", "output"}``. Then, for each seed, it writes
    into the same directory:
      * ``<prefix><task>_<k>_<seed>_train.jsonl``: up to ``k`` training
        records sampled without replacement after ``np.random.seed(seed)``
        (all of them if the pool has fewer than ``k`` lines);
      * ``<prefix><task>_<k>_<seed>_test.jsonl``: every evaluation record.

    ``prefix`` is not defined in this function; it is resolved from the
    enclosing module.

    Note: this reseeds NumPy's global random generator.

    Args:
        data_dir: root directory holding one sub-directory per task.
        test_tasks: task names, without ``prefix``.
        k: number of training examples to sample per seed.
        seeds: integer seeds; one train/test file pair is written per seed.

    Raises:
        ValueError: a line is not ``input<TAB>output``, or a multiple-choice
            answer is not among the options parsed from the input.
        NotImplementedError: a line is neither multiple-choice (last segment
            starts with "(A)") nor a yes/no question.
    """
    # "(A)", "(B)", ..., "(Z)": markers that introduce multiple-choice options.
    # Built once instead of once per line.
    alphabet_options = ["(" + letter + ")" for letter in string.ascii_uppercase]

    def _get_sentences(input_, options):
        """Split ``input_`` on the markers in ``options``, in order.

        Splitting stops at the first marker that does not occur in the
        remaining text. Pieces are stripped, and empty pieces (such as the
        text before "(A)") are dropped.
        """
        sentences = []
        text = input_
        for option in options:
            if option not in text:
                break
            text1, text = text.split(option, 1)
            sentences.append(text1)
        sentences.append(text)
        return [s.strip() for s in sentences if s.strip()]

    def _prepro(task, split, line):
        """Convert one ``input<TAB>output`` TSV line into a JSON string."""
        line = line.strip()
        try:
            input_, output = line.split("\t")
        except ValueError:
            # Previously this printed the line and called exit(), which ends
            # the process with status 0 and hides the failure.
            raise ValueError(
                "Malformed %s line for task %s (expected 'input<TAB>output'): %r"
                % (split, task, line)) from None

        # Segments are separated by a literal backslash-n ("\\n"), not by a
        # newline. With exactly three segments the layout is
        # "question \\n options \\n context"; reorder it so the options come
        # last. The count is taken on the input only, so a "\\n" inside the
        # answer cannot influence it.
        if input_.count("\\n") == 2:
            question, option_text, context = input_.split("\\n")
            input_ = question + "\\n " + context + " \\n " + option_text

        segments = input_.split("\\n")
        last_segment = segments[-1].strip()
        if last_segment.startswith("(A)"):
            # Multiple choice: the options live in the final segment.
            options = _get_sentences(last_segment, alphabet_options)
            if output not in options:
                raise ValueError(
                    "Answer %r for task %s (%s) is not among parsed options %r"
                    % (output, task, split, options))
            input_ = " ".join(segments[:-1])
        elif output in ["yes", "no"]:
            # Yes/no questions carry no explicit option list.
            options = []
        else:
            raise NotImplementedError(
                "Unsupported example format for task %s: %r" % (task, line))
        return json.dumps({"task": prefix + task, "input": input_,
                           "options": options, "output": output})

    for task in test_tasks:
        task_dir = os.path.join(data_dir, prefix + task)
        # These tasks are evaluated on dev.tsv; all others on test.tsv.
        eval_file = ("dev.tsv" if task in ["mctest", "multirc", "qasc", "qasc_with_ir"]
                     else "test.tsv")

        with open(os.path.join(task_dir, "train.tsv"), "r", encoding="utf-8") as f:
            data = [_prepro(task, "train", line) for line in f if line.strip()]
        with open(os.path.join(task_dir, eval_file), "r", encoding="utf-8") as f:
            test_data = [_prepro(task, "test", line) for line in f if line.strip()]

        for seed in seeds:
            np.random.seed(seed)
            # Same draw as permuting range(len(data)): a shuffled arange.
            train_data = [data[i] for i in np.random.permutation(len(data))[:k]]
            train_path = os.path.join(
                task_dir, "{}_{}_{}_train.jsonl".format(prefix + task, k, seed))
            with open(train_path, "w", encoding="utf-8") as f:
                for line in train_data:
                    f.write(line + "\n")
            test_path = os.path.join(
                task_dir, "{}_{}_{}_test.jsonl".format(prefix + task, k, seed))
            with open(test_path, "w", encoding="utf-8") as f:
                for line in test_data:
                    f.write(line + "\n")

        print("Finish saving %s\t#=%d" % (task, k))