in preprocess/data_prepro_clean.py [0:0]
def preprecess_QA_generation_newsqa_squad(input_dir,
                                          output_dir,
                                          encoder_json="/home/ec2-user/fairseq/encoder.json",
                                          vocab_bpe="/home/ec2-user/fairseq/vocab.bpe",
                                          only_squad=False):
    """Convert NewsQA and/or SQuAD v1.1 into parallel source/target files for QA generation.

    For each split (train/val/test) four files are written to *output_dir*:
    ``<split>.source`` / ``<split>.target`` (plain text, one example per line)
    and ``<split>.bpe.source`` / ``<split>.bpe.target`` (space-separated GPT-2
    BPE token ids).  The target of each example is the question joined to the
    answer via a special separator token id.

    Args:
        input_dir: directory containing ``combined-newsqa-data-v1.json``
            (NewsQA) and ``train-v1.1.json`` / ``dev-v1.1.json`` (SQuAD).
        output_dir: directory receiving the twelve output files.
        encoder_json: path to the GPT-2 BPE encoder json.
        vocab_bpe: path to the GPT-2 BPE merges file.
        only_squad: when True, skip NewsQA entirely and process only SQuAD.
    """
    # Use 50009 as the question/answer separator in the target sequence:
    # this id is not encountered in normal BPE outputs.
    special_token_id = 50009

    def _write_example(bpe, truncated_source, truncated_source_bpe, question_answer_bpe,
                       source_f, source_bpe_f, target_f, target_bpe_f):
        # Write one (source, target) pair to all four files.  unicode-escape
        # keeps each source on a single line (newlines become '\n' etc.); the
        # replace collapses the backslash doubling that escaping introduces.
        source_f.write(truncated_source.encode('unicode-escape').decode().replace('\\\\', '\\') + '\n')
        source_bpe_f.write(' '.join(map(str, truncated_source_bpe)) + '\n')
        target_f.write(bpe.decode(question_answer_bpe) + '\n')
        target_bpe_f.write(' '.join(map(str, question_answer_bpe)) + '\n')

    def _process_data(d, data_source, bpe, source_f, source_bpe_f, target_f, target_bpe_f):
        # Emit every usable QA pair from one NewsQA document or SQuAD article.
        if data_source == 'newsqa':
            source = d['text'].strip()
            for q in d['questions']:
                # Only questions with an annotated consensus answer span are usable.
                if 'consensus' in q and 'q' in q and 's' in q['consensus']:
                    question = q['q'].strip()
                    answer_s = q['consensus']['s']
                    answer_e = q['consensus']['e']
                    answer = source[answer_s:answer_e].strip()
                    truncated_source_bpe, truncated_source, question_answer_bpe = \
                        _format_question_answers_bpe(bpe, source, question, answer, special_token_id)
                    # Skip the question if the answer span was truncated away
                    # when the source was shortened.
                    if truncated_source is None or answer_e >= len(truncated_source):
                        continue
                    _write_example(bpe, truncated_source, truncated_source_bpe, question_answer_bpe,
                                   source_f, source_bpe_f, target_f, target_bpe_f)
        elif data_source == 'squad':
            for paragraph in d['paragraphs']:
                context = paragraph['context']
                for qa in paragraph['qas']:
                    question = qa['question'].strip()
                    seen_answers = set()  # deduplicate identical answer strings
                    for ans in qa['answers']:
                        if ans['text'] in seen_answers:
                            continue
                        seen_answers.add(ans['text'])
                        truncated_source_bpe, truncated_source, question_answer_bpe = \
                            _format_question_answers_bpe(bpe, context, question, ans['text'], special_token_id)
                        if truncated_source is None:  # skip the question
                            continue
                        _write_example(bpe, truncated_source, truncated_source_bpe, question_answer_bpe,
                                       source_f, source_bpe_f, target_f, target_bpe_f)
        else:
            raise Exception("data_source must be squad or newsqa!")

    from fairseq.data.encoders.gpt2_bpe import get_encoder
    bpe = get_encoder(encoder_json, vocab_bpe)

    if not only_squad:
        input_json = os.path.join(input_dir, 'combined-newsqa-data-v1.json')
        with open(input_json, 'r') as f:
            newsqa = json.load(f)

    with open(os.path.join(output_dir, 'train.source'), 'w') as train_source_f, \
            open(os.path.join(output_dir, 'train.target'), 'w') as train_target_f, \
            open(os.path.join(output_dir, 'train.bpe.source'), 'w') as train_source_bpe_f, \
            open(os.path.join(output_dir, 'train.bpe.target'), 'w') as train_target_bpe_f, \
            open(os.path.join(output_dir, 'val.source'), 'w') as val_source_f, \
            open(os.path.join(output_dir, 'val.target'), 'w') as val_target_f, \
            open(os.path.join(output_dir, 'val.bpe.source'), 'w') as val_source_bpe_f, \
            open(os.path.join(output_dir, 'val.bpe.target'), 'w') as val_target_bpe_f, \
            open(os.path.join(output_dir, 'test.source'), 'w') as test_source_f, \
            open(os.path.join(output_dir, 'test.target'), 'w') as test_target_f, \
            open(os.path.join(output_dir, 'test.bpe.source'), 'w') as test_source_bpe_f, \
            open(os.path.join(output_dir, 'test.bpe.target'), 'w') as test_target_bpe_f:
        if not only_squad:
            # NewsQA carries its own split label per document.
            for data in tqdm(newsqa['data']):
                if data['type'] == 'train':
                    _process_data(data, 'newsqa', bpe, train_source_f, train_source_bpe_f,
                                  train_target_f, train_target_bpe_f)
                elif data['type'] == 'dev':
                    _process_data(data, 'newsqa', bpe, val_source_f, val_source_bpe_f,
                                  val_target_f, val_target_bpe_f)
                elif data['type'] == 'test':
                    _process_data(data, 'newsqa', bpe, test_source_f, test_source_bpe_f,
                                  test_target_f, test_target_bpe_f)
                else:
                    # NOTE(review): 'break' aborts all remaining NewsQA documents
                    # on the first unknown split label — presumably intentional
                    # fail-fast; confirm before relying on partial output.
                    print("data type error!")
                    print(data)
                    break
            print("Done with NewsQA!")
        print("Doing Squad now!")
        # SQuAD v1.1: its dev set feeds the validation split, its train set
        # feeds the train split (appended after NewsQA unless only_squad).
        split_files = {"validation": "dev-v1.1.json", "train": "train-v1.1.json"}
        for dtype, input_file in split_files.items():
            with open(os.path.join(input_dir, input_file), 'r') as f_in:
                data_dict = json.load(f_in)
            if dtype == "train":
                for data in tqdm(data_dict['data']):
                    _process_data(data, 'squad', bpe, train_source_f, train_source_bpe_f,
                                  train_target_f, train_target_bpe_f)
            elif dtype == "validation":
                for data in data_dict['data']:
                    _process_data(data, 'squad', bpe, val_source_f, val_source_bpe_f,
                                  val_target_f, val_target_bpe_f)