def __getitem__()

in mdr/qa/qa_dataset.py [0:0]


    def __getitem__(self, index):
        # preprocess the example: tokenize the context and build char/word/wordpiece offset maps
        item = prepare(self.data[index], self.tokenizer)
        context_ann = item["context_processed"]
        q_toks = self.tokenizer.tokenize(item["question"])[:self.max_q_len]
        para_offset = len(q_toks) + 2  # [CLS] and [SEP]
        item["wp_tokens"] = context_ann["all_doc_tokens"]
        assert item["wp_tokens"][0] == "yes" and item["wp_tokens"][1] == "no"
        item["para_offset"] = para_offset
        max_toks_for_doc = self.max_seq_len - para_offset - 1  # leave room for the final [SEP]
        if len(item["wp_tokens"]) > max_toks_for_doc:
            item["wp_tokens"] = item["wp_tokens"][:max_toks_for_doc]
        item["encodings"] = self.tokenizer.encode_plus(q_toks, text_pair=item["wp_tokens"], max_length=self.max_seq_len, return_tensors="pt", is_pretokenized=True)

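        # paragraph_mask marks the context positions (including the leading yes/no
        # placeholders) where answer spans may be predicted; question tokens and
        # [CLS]/[SEP] are excluded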
        item["paragraph_mask"] = torch.zeros(item["encodings"]["input_ids"].size()).view(-1)
        item["paragraph_mask"][para_offset:-1] = 1
        
        if self.train:
            # if item["label"] == 1:
            if item["ans_covered"]:
                if item["gold_answer"][0] == "yes":
                    # ans_type = 0
                    starts, ends = [para_offset], [para_offset]
                elif item["gold_answer"][0] == "no":
                    # ans_type = 1
                    starts, ends = [para_offset + 1], [para_offset + 1]
                else:
                    # ans_type = 2
                    matched_spans = match_answer_span(context_ann["context"], item["gold_answer"], self.simple_tok)
                    ans_starts, ans_ends = [], []
                    for span in matched_spans:
                        char_starts = [i for i in range(len(context_ann["context"]))
                                       if context_ann["context"].startswith(span, i)]
                        if len(char_starts) > 0:
                            char_ends = [start + len(span) - 1 for start in char_starts]
                            answer = {"text": span, "char_spans": list(zip(char_starts, char_ends))}
                            ans_spans = find_ans_span_with_char_offsets(
                                answer, context_ann["char_to_word_offset"],
                                context_ann["doc_tokens"], context_ann["all_doc_tokens"],
                                context_ann["orig_to_tok_index"], self.tokenizer)
                            for s, e in ans_spans:
                                ans_starts.append(s)
                                ans_ends.append(e)
                    # map token spans into the truncated context and shift by the question offset
                    starts, ends = [], []
                    for s, e in zip(ans_starts, ans_ends):
                        if s >= len(item["wp_tokens"]):
                            continue
                        e = min(e, len(item["wp_tokens"]) - 1)
                        starts.append(s + para_offset)
                        ends.append(e + para_offset)
                    if len(starts) == 0:
                        starts, ends = [-1], [-1]
            else:
                starts, ends = [-1], [-1]
                # ans_type = -1
                        
            item["starts"] = torch.LongTensor(starts)
            item["ends"] = torch.LongTensor(ends)
            # item["ans_type"] = torch.LongTensor([ans_type])

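            # sanity check: one supporting-fact label per sentence start in the processed context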
            if item["label"]:
                assert len(item["sp_sent_labels"]) == len(item["context_processed"]["sent_starts"])
        else:
            # keep original token info for mapping predicted spans back to text at eval time
            item["doc_tokens"] = context_ann["doc_tokens"]
            item["tok_to_orig_index"] = context_ann["tok_to_orig_index"]

        # filter sentence offsets exceeding max sequence length
        sent_labels, sent_offsets = [], []
        for idx, s in enumerate(item["context_processed"]["sent_starts"]):
            if s >= len(item["wp_tokens"]):
                break
            if "sp_sent_labels" in item:
                sent_labels.append(item["sp_sent_labels"][idx])
            sent_offsets.append(s + para_offset)
            assert item["encodings"]["input_ids"].view(-1)[s+para_offset] == self.tokenizer.convert_tokens_to_ids("[unused1]")

        # supporting-fact sentence offsets (labels are added below when training)
        item["sent_offsets"] = torch.LongTensor(sent_offsets)
        if self.train:
            item["sent_labels"] = sent_labels if len(sent_labels) != 0 else [0] * len(sent_offsets)
            item["sent_labels"] = torch.LongTensor(item["sent_labels"])
            item["ans_covered"] = torch.LongTensor([item["ans_covered"]])

        item["label"] = torch.LongTensor([item["label"]])
        return item
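
Usage sketch (illustrative, not part of the repository): the snippet below shows one way a DataLoader could consume items produced by this __getitem__, assuming the surrounding class is constructed with the attributes referenced above (data, tokenizer, max_q_len, max_seq_len, train). The pad_collate helper and the dataset constructor call are hypothetical; they only illustrate that the per-example tensors (encodings, starts, ends, sent_offsets, label) have variable lengths and must be padded before batching.

    import torch
    from torch.utils.data import DataLoader

    def pad_collate(batch, pad_id=0):
        # hypothetical collate: pad every encoded sequence to the longest one in
        # the batch; the variable-length span/sentence targets stay per-example
        max_len = max(b["encodings"]["input_ids"].size(1) for b in batch)
        input_ids, attention_mask = [], []
        for b in batch:
            ids = b["encodings"]["input_ids"].view(-1)
            pad = torch.full((max_len - ids.size(0),), pad_id, dtype=torch.long)
            input_ids.append(torch.cat([ids, pad]))
            attention_mask.append(torch.cat([torch.ones_like(ids), torch.zeros_like(pad)]))
        return {
            "input_ids": torch.stack(input_ids),
            "attention_mask": torch.stack(attention_mask),
            "paragraph_mask": [b["paragraph_mask"] for b in batch],
            "starts": [b.get("starts") for b in batch],  # present only when train=True
            "ends": [b.get("ends") for b in batch],      # present only when train=True
            "sent_offsets": [b["sent_offsets"] for b in batch],
            "labels": torch.cat([b["label"] for b in batch]),
        }

    # dataset = QADataset(data, tokenizer, max_q_len=64, max_seq_len=512, train=True)  # hypothetical signature
    # loader = DataLoader(dataset, batch_size=8, collate_fn=pad_collate)
    # batch = next(iter(loader))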