in mdr/qa/qa_dataset.py [0:0]
def __getitem__(self, index):
    item = prepare(self.data[index], self.tokenizer)
    context_ann = item["context_processed"]
    q_toks = self.tokenizer.tokenize(item["question"])[:self.max_q_len]
    para_offset = len(q_toks) + 2  # account for [CLS] and [SEP]
    item["wp_tokens"] = context_ann["all_doc_tokens"]
    assert item["wp_tokens"][0] == "yes" and item["wp_tokens"][1] == "no"
    item["para_offset"] = para_offset
    max_toks_for_doc = self.max_seq_len - para_offset - 1
    if len(item["wp_tokens"]) > max_toks_for_doc:
        item["wp_tokens"] = item["wp_tokens"][:max_toks_for_doc]
    item["encodings"] = self.tokenizer.encode_plus(
        q_toks,
        text_pair=item["wp_tokens"],
        max_length=self.max_seq_len,
        return_tensors="pt",
        is_pretokenized=True)
    item["paragraph_mask"] = torch.zeros(item["encodings"]["input_ids"].size()).view(-1)
    item["paragraph_mask"][para_offset:-1] = 1
    if self.train:
        # if item["label"] == 1:
        if item["ans_covered"]:
            if item["gold_answer"][0] == "yes":
                # ans_type = 0
                starts, ends = [para_offset], [para_offset]
            elif item["gold_answer"][0] == "no":
                # ans_type = 1
                starts, ends = [para_offset + 1], [para_offset + 1]
            else:
                # ans_type = 2
                matched_spans = match_answer_span(context_ann["context"], item["gold_answer"], self.simple_tok)
                ans_starts, ans_ends = [], []
                for span in matched_spans:
                    char_starts = [i for i in range(len(context_ann["context"])) if context_ann["context"].startswith(span, i)]
                    if len(char_starts) > 0:
                        char_ends = [start + len(span) - 1 for start in char_starts]
                        answer = {"text": span, "char_spans": list(zip(char_starts, char_ends))}
                        ans_spans = find_ans_span_with_char_offsets(
                            answer, context_ann["char_to_word_offset"], context_ann["doc_tokens"],
                            context_ann["all_doc_tokens"], context_ann["orig_to_tok_index"], self.tokenizer)
                        for s, e in ans_spans:
                            ans_starts.append(s)
                            ans_ends.append(e)
                # keep only spans that survive truncation, shifted by the question offset
                starts, ends = [], []
                for s, e in zip(ans_starts, ans_ends):
                    if s >= len(item["wp_tokens"]):
                        continue
                    starts.append(min(s, len(item["wp_tokens"]) - 1) + para_offset)
                    ends.append(min(e, len(item["wp_tokens"]) - 1) + para_offset)
                if len(starts) == 0:
                    starts, ends = [-1], [-1]
        else:
            starts, ends = [-1], [-1]
            # ans_type = -1
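        # Span supervision: for yes/no questions the target is the position of the
        # prepended "yes"/"no" token; for extractive answers it is a wordpiece span;
        # [-1] marks examples where the gold answer is not covered by this chain.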
item["starts"] = torch.LongTensor(starts)
item["ends"] = torch.LongTensor(ends)
# item["ans_type"] = torch.LongTensor([ans_type])
if item["label"]:
assert len(item["sp_sent_labels"]) == len(item["context_processed"]["sent_starts"])
else:
# # for answer extraction
item["doc_tokens"] = context_ann["doc_tokens"]
item["tok_to_orig_index"] = context_ann["tok_to_orig_index"]
    # drop sentence offsets that exceed the truncated sequence length
    sent_labels, sent_offsets = [], []
    for idx, s in enumerate(item["context_processed"]["sent_starts"]):
        if s >= len(item["wp_tokens"]):
            break
        if "sp_sent_labels" in item:
            sent_labels.append(item["sp_sent_labels"][idx])
        sent_offsets.append(s + para_offset)
        # each sentence should begin at a "[unused1]" marker token
        assert item["encodings"]["input_ids"].view(-1)[s + para_offset] == self.tokenizer.convert_tokens_to_ids("[unused1]")

    # supporting-fact sentence offsets and labels
    item["sent_offsets"] = torch.LongTensor(sent_offsets)
    if self.train:
        item["sent_labels"] = torch.LongTensor(sent_labels if len(sent_labels) != 0 else [0] * len(sent_offsets))
        item["ans_covered"] = torch.LongTensor([item["ans_covered"]])
    item["label"] = torch.LongTensor([item["label"]])
    return item
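As a rough illustration of why the eval-time branch keeps doc_tokens and tok_to_orig_index, here is a minimal sketch of mapping a predicted span back to answer text. The recover_answer_text name, the start_logits/end_logits inputs, and the plain argmax decoding (independent argmaxes, no start <= end or top-k constraints) are assumptions for illustration; the repository's own answer decoding may differ.

import torch

def recover_answer_text(item, start_logits, end_logits):
    # Restrict the argmax to paragraph positions; start_logits/end_logits are
    # assumed to be 1-D tensors over the encoded sequence (hypothetical inputs).
    mask = item["paragraph_mask"].bool()
    start = int(start_logits.masked_fill(~mask, float("-inf")).argmax())
    end = int(end_logits.masked_fill(~mask, float("-inf")).argmax())
    # Shift back to indices into the context wordpieces.
    s_wp, e_wp = start - item["para_offset"], end - item["para_offset"]
    if s_wp == 0:  # span points at the prepended "yes" token
        return "yes"
    if s_wp == 1:  # span points at the prepended "no" token
        return "no"
    # Map wordpiece indices to whitespace-token indices and join the original tokens.
    s_orig = item["tok_to_orig_index"][s_wp]
    e_orig = item["tok_to_orig_index"][e_wp]
    return " ".join(item["doc_tokens"][s_orig:e_orig + 1])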