in preprocess/unifiedqa.py [0:0]
# Tasks whose examples are kept only if the normalized answer occurs in the
# input, and whose over-long inputs ("question \\n context") get their context
# truncated to a random window that still contains the answer.
_ANSWER_IN_INPUT_TASKS = frozenset([
    "natural_questions_with_dpr_para", "race_string", "drop", "newsqa",
    "narrativeqa", "quoref", "ropes",
])

# Tasks whose input has exactly three literal-"\\n"-separated parts; the middle
# part is dropped (presumably the answer options -- confirm against the data).
_DROP_MIDDLE_PART_TASKS = frozenset(["race_string", "social_iqa"])


def _truncate_context(question, context, norm_answer, max_length, max_tries):
    """Shrink ``context`` to a random word window that still contains the answer.

    Windows are drawn with the global ``np.random`` state until one whose
    ``normalize_answer`` form contains ``norm_answer`` is found.

    Returns:
        ``question + " \\n " + window``, or None when the question alone leaves
        no room for context or no window matched within ``max_tries`` draws.
    """
    n_words_question = len(question.split(" "))
    context_words = context.split(" ")
    n_words_context = len(context_words)
    # The -1 reserves one word for the separator between question and context.
    n_words = max_length - n_words_question - 1
    if n_words <= 0:
        # A non-positive window would make the slice below use negative
        # indices and return (almost) the whole context; skip instead.
        return None
    assert n_words_context > n_words
    n_windows = n_words_context - n_words + 1
    for _ in range(max_tries):
        # choice(n) samples like choice(range(n)) without materializing it.
        start = np.random.choice(n_windows)
        new_context = " ".join(context_words[start:start + n_words])
        if norm_answer in normalize_answer(new_context):
            # Same literal "\\n" separator used by the rest of the pipeline.
            return question + " \\n " + new_context
    return None


def _prepare_example(task, input_, output_, max_length, max_tries):
    """Filter and rewrite one ``(input_, output_)`` pair for ``task``.

    Returns:
        The (possibly rewritten) ``(input_, output_)`` tuple, or None if the
        example should be dropped.
    """
    answer_filtered = task in _ANSWER_IN_INPUT_TASKS
    if answer_filtered:
        norm_answer = normalize_answer(output_)
        if norm_answer not in normalize_answer(input_):
            return None
    if task in _DROP_MIDDLE_PART_TASKS:
        first, _, last = input_.split("\\n")
        input_ = first + "\\n" + last
    else:
        assert input_.count("\\n") < 3
    if task == "natural_questions_with_dpr_para":
        # Collapse spaces around punctuation, e.g. " , " -> ", ".
        input_ = (input_.replace(" , ", ", ").replace("( ", "(")
                  .replace(" )", ")").replace(" - - ", " - ")
                  .replace(" . ", ". "))
    if answer_filtered and len(input_.split(" ")) > max_length:
        question, context = input_.split("\\n")
        if norm_answer not in normalize_answer(context):
            return None
        input_ = _truncate_context(question, context, norm_answer,
                                   max_length, max_tries)
        if input_ is None:
            return None
    # Checked after truncation so random draws happen in the original order.
    if len(output_.split(" ")) > 100:
        return None
    return input_, output_


def process_train(data_dir, train_tasks, max_length, k, seed, *, max_tries=10000):
    """Build a k-shot training jsonl file for each UnifiedQA task.

    For each task, reads ``<data_dir>/<prefix+task>/train.tsv`` (one
    ``input<TAB>output`` pair per line), filters and rewrites the examples,
    keeps a random subset of at most ``k``, and writes it to
    ``<data_dir>/<prefix+task>/<prefix+task>_<k>_<seed>_train.jsonl``.
    ``prefix`` is a module-level name defined elsewhere in this file.

    Args:
        data_dir: Root directory containing one sub-directory per task.
        train_tasks: Iterable of task names (without ``prefix``).
        max_length: Maximum number of space-separated words in an input;
            longer inputs of answer-filtered tasks are truncated.
        k: Maximum number of examples kept per task.
        seed: Seed for ``np.random``, reset at the start of every task; also
            part of the output file name.
        max_tries: Maximum random windows tried when truncating one input
            before the example is dropped (prevents an endless search).
    """
    for task in train_tasks:
        # Reseed per task so each task's sampling is reproducible on its own.
        np.random.seed(seed)
        task_dir = os.path.join(data_dir, prefix + task)
        with open(os.path.join(task_dir, "train.tsv"), "r", encoding="utf-8") as f:
            lines = f.readlines()
        data = []
        for line in tqdm(lines):
            try:
                input_, output_ = line.strip().split("\t")
            except ValueError:
                # Not exactly one tab: malformed row, skip it.
                continue
            example = _prepare_example(task, input_, output_, max_length, max_tries)
            if example is not None:
                data.append(example)
        # Random subset of at most k examples.
        data = [data[i] for i in np.random.permutation(range(len(data)))[:k]]
        out_name = "{}_{}_{}_train.jsonl".format(prefix + task, k, seed)
        with open(os.path.join(task_dir, out_name), "w", encoding="utf-8") as f:
            for input_, output in data:
                f.write(json.dumps({"task": prefix + task, "options": [],
                                    "input": input_, "output": output}) + "\n")
        print("Finish saving %s\t#=%d" % (task, len(data)))