in src/utils_fusion_in_decoder.py [0:0]
def get_data_examples(self, data_dir, mode, question_type, passage_type, enable_sql_supervision, cand_for_each_source):
    """See base class.

    Load retrieval-augmented records for one or more question sources and
    convert them with ``self._create_examples``.

    Args:
        data_dir: Directory holding the ``*.jsonl`` files. For any mode other
            than ``'train'`` the files are read from its ``scores``
            subdirectory.
        mode: ``'train'`` reads the ``train`` split and shuffles the combined
            records in place (using the global ``random`` state). Any other
            value reads the ``dev`` split, in source order.
        question_type: Selects which sources to load; must be one of the keys
            of ``source_prefixes`` below.
        passage_type, enable_sql_supervision, cand_for_each_source: Passed
            through unchanged to ``self._create_examples``.

    Returns:
        Whatever ``self._create_examples`` returns for the loaded records.

    Raises:
        NotImplementedError: If ``question_type`` is not a supported value.
    """
    # question_type -> file-name prefixes of the sources to concatenate (in order).
    source_prefixes = {
        'wikisql_question': ['wikisql_denotation'],
        'opensquad_question': ['opensquad'],
        'mixed': ['wikisql_denotation', 'opensquad'],
        'nq': ['NQ-open'],
        'nq_wikisql': ['NQ-open', 'wikisql_denotation'],
        'ott-qa': ['ott-qa'],
        'ottqa_wikisql': ['ott-qa', 'wikisql_denotation'],
        'all': ['ott-qa', 'opensquad', 'NQ-open', 'wikisql_denotation'],
    }
    try:
        prefix = source_prefixes[question_type]
    except KeyError:
        raise NotImplementedError(
            f"Unsupported question_type {question_type!r}; "
            f"expected one of {sorted(source_prefixes)}"
        ) from None

    is_train = mode == 'train'
    data = []
    for q_type in prefix:
        # WikiSQL uses separate ``*.true_sql.jsonl`` files (presumably carrying
        # gold SQL annotations); every other source uses the plain file name.
        is_wikisql = q_type == 'wikisql_denotation'
        if is_train:
            if is_wikisql:
                file_name = f"{q_type}.train.es_retrieved.true_sql.jsonl"
            else:
                file_name = f"{q_type}.train.es_retrieved.jsonl"
            path = os.path.join(data_dir, file_name)
        else:
            if is_wikisql:
                file_name = f"{q_type}.dev.es_retrieved.scores.sorted.true_sql.jsonl"
            else:
                file_name = f"{q_type}.dev.es_retrieved.processed.scores.sorted.jsonl"
            path = os.path.join(data_dir, "scores", file_name)
        data.extend(self._read_jsonl(path))

    if is_train:
        # Mix the concatenated sources so training batches are not grouped by source.
        random.shuffle(data)
    return self._create_examples(data, mode, passage_type, enable_sql_supervision, cand_for_each_source)