# vilbert/datasets/visdial_dataset.py — Dataset.__getitem__ and helpers

    def _image_inputs(self, image_id):
        """Load region features for ``image_id`` as tensors.

        Returns:
            (features, spatials, image_mask): float features, float box
            coordinates, and a long mask with 1 for real regions and 0
            padding up to ``self._max_region_num``.
        """
        features, num_boxes, boxes, _ = self._image_features_reader[image_id]
        image_mask = [1] * int(num_boxes)
        # Pad mask with zeros up to the fixed region budget (no-op if already full).
        image_mask += [0] * (self._max_region_num - len(image_mask))

        features = torch.tensor(features).float()
        image_mask = torch.tensor(image_mask).long()
        spatials = torch.tensor(boxes).float()
        return features, spatials, image_mask

    def _dialog_history(self, entry, rnd):
        """Concatenate the Q/A pairs of the last ``self.max_round_num`` rounds
        before ``rnd``, joined with [SEP] tokens. Empty list for round 0."""
        tokens = []
        for j in range(max(0, rnd - self.max_round_num), rnd):
            fact_q = self._questions[entry["dialog"][j]["question"]]
            fact_a = self._answers[entry["dialog"][j]["answer"]]
            if tokens:
                tokens.append(self.SEP)
            tokens += fact_q + [self.SEP] + fact_a
        return tokens

    def _candidate_indices(self, entry, rnd):
        """Sample ``self.max_num_option`` answer-option indices for ``rnd``.

        The ground-truth index is always first (so the classification target
        is 0); the rest are random non-GT distractors drawn from a single
        permutation of ``self.ans_option`` indices.
        """
        gt_index = entry["dialog"][rnd]["gt_index"]
        candidates = [gt_index]
        for idx in np.random.permutation(self.ans_option):
            if len(candidates) >= self.max_num_option:
                break
            if idx != gt_index:
                candidates.append(idx)
        return candidates

    def _encode_option(self, token_q, tokens_a, tokens_f):
        """Encode one (question, answer, history) triple as a BERT-style pair.

        Layout: [CLS] Q [SEP] A [SEP] history [SEP], zero-padded to
        ``self._total_seq_length``. Segment ids are 0 for the question and
        history parts and 1 for the answer part.

        Returns:
            (input_ids, input_mask, segment_ids) as plain Python lists.
        """
        # Reserve room for the question, answer and the 4 special tokens;
        # deepcopy because _truncate_seq mutates its argument in place.
        tokens_f_new = self._truncate_seq(
            copy.deepcopy(tokens_f),
            self._total_seq_length - len(token_q) - len(tokens_a) - 4,
        )

        tokens = (
            [self.CLS] + list(token_q) + [self.SEP]
            + list(tokens_a) + [self.SEP]
            + list(tokens_f_new) + [self.SEP]
        )
        segment_ids = (
            [0] * (len(token_q) + 2)
            + [1] * (len(tokens_a) + 1)
            + [0] * (len(tokens_f_new) + 1)
        )
        input_mask = [1] * len(tokens)

        # Zero-pad up to the sequence length (no-op if already full).
        pad = self._total_seq_length - len(tokens)
        tokens += [0] * pad
        input_mask += [0] * pad
        segment_ids += [0] * pad
        return tokens, input_mask, segment_ids

    def __getitem__(self, index):
        """Return the tensors for one full dialog (10 rounds).

        For each round, ``self.max_num_option`` candidate sequences are built
        (ground-truth answer first), each encoded as
        [CLS] question [SEP] answer [SEP] history+caption [SEP].

        Returns:
            (features, spatials, image_mask, input_ids, target, input_mask,
            segment_ids, co_attention_mask, image_id). ``target`` is all
            zeros because the GT answer is always candidate 0.
        """
        entry = self._entries[index]
        image_id = entry["image_id"]
        features, spatials, image_mask = self._image_inputs(image_id)

        # Let's sample one dialog at a time.
        caption = self._captions[entry["caption"]]

        input_ids_all = []
        input_mask_all = []
        segment_ids_all = []

        for rnd in range(10):
            token_q = self._questions[entry["dialog"][rnd]["question"]]
            history = self._dialog_history(entry, rnd)
            # History is followed by the caption; with no history, the
            # caption alone serves as the context segment.
            if history:
                tokens_f = history + [self.SEP] + caption
            else:
                tokens_f = caption

            input_ids_rnd = []
            input_mask_rnd = []
            segment_ids_rnd = []

            for ans_idx in self._candidate_indices(entry, rnd):
                tokens_a = self._answers[
                    entry["dialog"][rnd]["answer_options"][ans_idx]
                ]
                ids, mask, segs = self._encode_option(token_q, tokens_a, tokens_f)
                input_ids_rnd.append(ids)
                input_mask_rnd.append(mask)
                segment_ids_rnd.append(segs)

            input_ids_all.append(input_ids_rnd)
            input_mask_all.append(input_mask_rnd)
            segment_ids_all.append(segment_ids_rnd)

        input_ids = torch.from_numpy(np.array(input_ids_all))
        input_mask = torch.from_numpy(np.array(input_mask_all))
        segment_ids = torch.from_numpy(np.array(segment_ids_all))
        co_attention_mask = torch.zeros(
            (10, self.max_num_option, self._max_region_num, self._total_seq_length)
        )
        # GT answer is always at candidate position 0 in every round.
        target = torch.zeros(10).long()
        return (
            features,
            spatials,
            image_mask,
            input_ids,
            target,
            input_mask,
            segment_ids,
            co_attention_mask,
            image_id,
        )