in src/utils_fusion_in_decoder.py [0:0]
def _create_examples(self, lines, set_type, passage_type, enable_sql_supervision, cand_for_each_source):
    """Creates examples for the training and dev sets.

    Each entry of ``lines`` is read as a mapping with the keys ``'qid'``,
    ``'question'``, ``'passages'`` and/or ``'tables'`` and, optionally,
    ``'true_sql'``, ``'denotation'`` and ``'answers'``.  Passage candidates
    are read through ``'article_title'``, ``'text'`` and ``'judge'``; table
    candidates through ``'uid'``, ``'text'`` and ``'judge'``.  In the
    ``'both'`` evaluation mode every candidate must also carry
    ``'rank_score'``.

    Args:
        lines: Sized iterable of question records (see above).
        set_type: Split name; used in the guid and in the final log line.
            Any value other than ``'train'`` combined with
            ``passage_type == 'both'`` enables score-based re-ranking.
        passage_type: One of ``'textual'``, ``'tabular'``, ``'hybrid'`` or
            ``'both'``; selects which candidate lists feed the sources.
        enable_sql_supervision: When truthy, records holding ``'true_sql'``
            also yield an example whose target is the SQL string.
        cand_for_each_source: Number of candidates kept per source type.

    Returns:
        A list of ``InputExamples``.  A single record contributes one
        example per available target (``true_sql``, ``denotation``,
        ``answers``), so it may contribute zero to three examples.
    """

    def _judge_to_int(judge):
        # 'judge' is either a dict exposing 'judge_contain_all' or a bare
        # value that can be converted with int().
        if isinstance(judge, dict):
            return int(judge['judge_contain_all'])
        return int(judge)

    def _format_source(record, kind, title, text):
        # Plain concatenation (not one big f-string) keeps the original
        # behaviour of failing loudly on non-string titles/texts.
        return ("question: " + record["question"].strip()
                + f" </s> {kind} title: " + title
                + f" </s> {kind} content: " + text + " </s>")

    examples = []
    total_answerable = 0
    num_lines = len(lines)
    for (i, line) in tqdm(enumerate(lines), total=num_lines):
        guid = "%s-%s-%s" % (set_type, line['qid'], str(i))
        src = []
        has_pos = 0
        if passage_type == 'both' and set_type != 'train':
            # rank by bert score in val and test
            score_key = 'rank_score'
            top_passages = line["passages"][:2 * cand_for_each_source]
            top_tables = line["tables"][:2 * cand_for_each_source]
            both_sources = top_passages + top_tables
            # The passage/table boundary inside `both_sources` is the real
            # number of passages kept; the slice may be shorter than
            # 2 * cand_for_each_source when few passages are available.
            num_passages = len(top_passages)
            all_scores = np.array([x[score_key] for x in both_sources])
            # Highest score first, keep the best 2 * cand_for_each_source.
            sorted_index = all_scores.argsort()[::-1][:2 * cand_for_each_source]
            for idx in sorted_index:
                pas = both_sources[idx]
                if idx < num_passages:
                    pas_type = 'passage'
                    title = pas["article_title"]
                else:
                    pas_type = 'table'
                    title = "table_" + pas["uid"].split('-split')[0]
                src.append(_format_source(line, pas_type, title, pas["text"]))
                has_pos += _judge_to_int(pas['judge'])
        else:
            if passage_type in ['textual', 'hybrid', 'both']:
                # At most `cand_for_each_source` passages.
                for pas in line["passages"][:cand_for_each_source]:
                    # e.g. question: Tell me what the notes are for South Australia </s> passage title: Strictly Commercial </s> passage content: album \"ZAPPAtite\". All songs written and performed ... </s>
                    src.append(_format_source(line, 'passage', pas["article_title"], pas["text"]))
                    has_pos += _judge_to_int(pas['judge'])
            if passage_type in ['tabular', 'hybrid', 'both']:
                # At most `cand_for_each_source` tables.
                for tab in line["tables"][:cand_for_each_source]:
                    # e.g. question: Tell me what the notes are for South Australia </s> table title: From Nashville to Memphis: The Essential '60s Masters ; Disc Two ; Disc Tw </s> table content: ... </s>
                    table_title = "table_" + tab["uid"].split('-split')[0]
                    src.append(_format_source(line, 'table', table_title, tab["text"]))
                    has_pos += _judge_to_int(tab['judge'])

        # Pad with empty strings so every example has the same number of
        # sources; records that needed padding are logged for inspection.
        const = 2 if passage_type in ['hybrid', 'both'] else 1
        expected_sources = const * cand_for_each_source
        if len(src) != expected_sources:
            logger.info(line)
            src = src + [""] * (expected_sources - len(src))

        if has_pos > 0:
            total_answerable += 1

        # One example per available target; every 1000th record is logged.
        if enable_sql_supervision and 'true_sql' in line:
            tgt = "sql: " + line["true_sql"] + " </s>"  # e.g. sql: SELECT Position FROM table_1-10015132-11 WHERE School/Club Team = \"Butler CC (KS)\" </s>
            examples.append(InputExamples(guid=guid, source=src, target=tgt))
            if i % 1000 == 0:
                logger.info(src[0] + " " + tgt)
        if 'denotation' in line:
            tgt = "answer: " + str(line["denotation"][0]) + " </s>"
            examples.append(InputExamples(guid=guid, source=src, target=tgt))
            if i % 1000 == 0:
                logger.info(src[0] + " " + tgt)
        if 'answers' in line:
            tgt = "answer: " + str(line["answers"][0]) + " </s>"  # e.g. answer: no slogan on current series </s>
            examples.append(InputExamples(guid=guid, source=src, target=tgt))
            if i % 1000 == 0:
                logger.info(src[0] + " " + tgt + "\n")

    # Guard against an empty split, which would otherwise divide by zero.
    answerable_ratio = total_answerable / num_lines if num_lines else 0.0
    logger.info(f"Total answerable in {set_type} split is {answerable_ratio}")
    return examples