# preprecess_QA_generation_newsqa_squad
# from preprocess/data_prepro_clean.py


def preprecess_QA_generation_newsqa_squad(input_dir,
                                          output_dir,
                                          encoder_json="/home/ec2-user/fairseq/encoder.json",
                                          vocab_bpe="/home/ec2-user/fairseq/vocab.bpe",
                                          only_squad=False):
    """Build question-generation train/val/test files from NewsQA and SQuAD v1.1.

    Reads ``combined-newsqa-data-v1.json`` (unless *only_squad*) plus
    ``dev-v1.1.json`` / ``train-v1.1.json`` from *input_dir* and writes twelve
    files into *output_dir*: ``{train,val,test}.source``,
    ``{train,val,test}.target`` and their ``.bpe`` counterparts.  Each source
    line is a (possibly truncated) passage on a single line; each target line
    is the question/answer pair.  SQuAD has no test split, so the ``test.*``
    files receive only NewsQA examples.

    Args:
        input_dir: directory holding the raw NewsQA / SQuAD json files.
        output_dir: directory the twelve output files are written to.
        encoder_json: path to the GPT-2 BPE encoder json used by fairseq.
        vocab_bpe: path to the GPT-2 BPE merges file used by fairseq.
        only_squad: when True, skip NewsQA entirely and process only SQuAD.

    Raises:
        ValueError: if an internal data_source tag is neither 'squad' nor
            'newsqa' (should not happen in normal use).
    """
    # use '50009' for the special dictionary token to separate question and
    # answers since this token is not encountered in bpe outputs
    special_token_id = 50009

    def _write_example(bpe, truncated_source, truncated_source_bpe, question_answer_bpe,
                       source_f, source_bpe_f, target_f, target_bpe_f):
        # Write one example across the four parallel files. unicode-escape
        # collapses newlines so each source occupies exactly one line; the
        # replace() undoes the doubling of literal backslashes it introduces.
        source_f.write(truncated_source.encode('unicode-escape').decode().replace('\\\\', '\\') + '\n')
        source_bpe_f.write(' '.join(map(str, truncated_source_bpe)) + '\n')
        target_f.write(bpe.decode(question_answer_bpe) + '\n')
        target_bpe_f.write(' '.join(map(str, question_answer_bpe)) + '\n')

    def _process_data(d, data_source, bpe, source_f, source_bpe_f, target_f, target_bpe_f):
        # Extract (source, question, answer) triples from one record of either
        # dataset, BPE-encode them, and append them to the given split files.
        if data_source == 'newsqa':
            source = d['text'].strip()
            for q in d['questions']:
                # Only questions with a consensus answer span are usable.
                if 'consensus' in q and 'q' in q and 's' in q['consensus']:
                    question = q['q'].strip()
                    answer_s = q['consensus']['s']
                    answer_e = q['consensus']['e']
                    answer = source[answer_s:answer_e].strip()
                    truncated_source_bpe, truncated_source, question_answer_bpe = \
                        _format_question_answers_bpe(bpe, source, question, answer, special_token_id)
                    # skip the question if the answer span was truncated away
                    if truncated_source is None or answer_e >= len(truncated_source):
                        continue
                    _write_example(bpe, truncated_source, truncated_source_bpe, question_answer_bpe,
                                   source_f, source_bpe_f, target_f, target_bpe_f)
        elif data_source == 'squad':
            for paragraph in d['paragraphs']:
                context = paragraph['context']
                for qa in paragraph['qas']:
                    question = qa['question'].strip()
                    seen_answers = set()  # de-duplicate identical answer texts per question
                    for ans in qa['answers']:
                        if ans['text'] in seen_answers:
                            continue
                        seen_answers.add(ans['text'])
                        truncated_source_bpe, truncated_source, question_answer_bpe = \
                            _format_question_answers_bpe(bpe, context, question, ans['text'], special_token_id)
                        if truncated_source is None:  # skip the question
                            continue
                        _write_example(bpe, truncated_source, truncated_source_bpe, question_answer_bpe,
                                       source_f, source_bpe_f, target_f, target_bpe_f)
        else:
            raise ValueError("data_source must be squad or newsqa!")

    from contextlib import ExitStack
    from fairseq.data.encoders.gpt2_bpe import get_encoder
    bpe = get_encoder(encoder_json, vocab_bpe)

    newsqa = None
    if not only_squad:
        with open(os.path.join(input_dir, 'combined-newsqa-data-v1.json'), 'r') as f:
            newsqa = json.load(f)

    with ExitStack() as stack:
        # split -> (source_f, source_bpe_f, target_f, target_bpe_f), matching
        # the file-argument order of _process_data.
        split_files = {
            split: tuple(
                stack.enter_context(open(os.path.join(output_dir, name), 'w'))
                for name in (split + '.source', split + '.bpe.source',
                             split + '.target', split + '.bpe.target'))
            for split in ('train', 'val', 'test')
        }

        if not only_squad:
            # NewsQA tags each story with its split under the 'type' key.
            newsqa_split = {'train': 'train', 'dev': 'val', 'test': 'test'}
            for data in tqdm(newsqa['data']):
                split = newsqa_split.get(data['type'])
                if split is None:
                    print("data type error!")
                    print(data)
                    break
                _process_data(data, 'newsqa', bpe, *split_files[split])

            print("Done with NewsQA!")

        print("Doing Squad now!")
        # SQuAD v1.1 ships only dev and train; process dev first as before.
        for split, input_file in (('val', 'dev-v1.1.json'), ('train', 'train-v1.1.json')):
            with open(os.path.join(input_dir, input_file), 'r') as f_in:
                data_dict = json.load(f_in)
            # Only the (much larger) train split gets a progress bar.
            records = tqdm(data_dict['data']) if split == 'train' else data_dict['data']
            for data in records:
                _process_data(data, 'squad', bpe, *split_files[split])