in pytext/data/sources/squad.py [0:0]
def process_squad_tsv(self, fname):
    """Yield SQuAD examples augmented with teacher outputs for knowledge distillation.

    Reads a TSV file in which every column is JSON-encoded (the KD dump
    format), splits each row's document with ``_split_document`` and attaches
    that row's teacher logits/masks to every piece produced from it.

    Args:
        fname: Path to the KD TSV file, resolved with ``get_absolute_path``.
            When empty or None, a message is printed and nothing is yielded.

    Yields:
        The dicts produced by ``_split_document``, each updated with the keys
        ``start_logits``, ``end_logits``, ``has_answer_logits``, ``pad_mask``
        and ``segment_labels`` taken from the same TSV row.
    """
    if not fname:
        print("Empty file name!")
        return
    field_names = [
        "id1",
        "doc",
        "question",
        "answers",
        "answer_starts",
        "has_answer",
        "id2",
        "start_logits",
        "end_logits",
        "has_answer_logits",
        "pad_mask",
        "segment_labels",
    ]
    # Teacher-model columns copied verbatim onto every document piece.
    kd_fields = (
        "start_logits",
        "end_logits",
        "has_answer_logits",
        "pad_mask",
        "segment_labels",
    )
    tsv_file = SafeFileWrapper(
        get_absolute_path(fname), encoding="utf-8", errors="replace"
    )
    tsv = TSV(
        tsv_file,
        field_names=field_names,
        delimiter=self.delimiter,
        quoted=self.quoted,
        drop_incomplete_rows=True,
    )
    for row_index, row in enumerate(tsv):
        # All model output for KD are dumped using json serialization, so
        # every column (text columns included) is decoded with json.loads.
        values = {name: json.loads(row[name]) for name in field_names}
        question = values["question"]
        if isinstance(question, list):
            # If we have paraphrases for the question, sample one per row.
            question = choice(question)
        # json.loads yields the string "True" for a dumped string but the
        # bool True for a dumped bool (JSON `true`); treat both as "yes".
        raw_has_answer = values["has_answer"]
        has_answer = raw_has_answer is True or raw_has_answer == "True"
        kd_values = {name: values[name] for name in kd_fields}
        for piece_dict in _split_document(
            row_index,
            values["doc"],
            question,
            values["answers"],
            values["answer_starts"],
            has_answer,
            self.ignore_impossible,
            self.max_character_length,
            self.min_overlap,
        ):
            piece_dict.update(kd_values)
            yield piece_dict