def __getitem__()

in vilbert/datasets/retreival_dataset.py [0:0]


    def __getitem__(self, index):
        """Build the 4-candidate retrieval sample for entry ``index``.

        Candidates are stacked along a new leading dimension, in this order:

        0. the entry's own image with its own caption (the positive pair);
        1. the same image with the caption of a randomly chosen other image;
        2. a randomly chosen other image with the entry's own caption;
        3. the same image with the caption of a "hard" negative image, drawn
           from ``self.train_hard_pool`` when ``self._split == "train"`` and
           chosen uniformly from the other images otherwise.

        Args:
            index: position in ``self._entries``.

        Returns:
            ``(features, spatials, image_mask, caption, target, input_mask,
            segment_ids, co_attention_mask, image_id)``. Every tensor has a
            leading dimension of 4, one row per candidate above. ``target`` is
            ``0``, the position of the positive candidate. ``co_attention_mask``
            is all zeros with shape
            ``(4, self._max_region_num, self._max_seq_length)``. ``image_id``
            is the entry's own image id.
        """
        entry = self._entries[index]
        image_id = entry["image_id"]

        # Candidate 1 image: the entry's own image.
        image1 = self._load_padded_image(image_id)

        # Candidate 2 text: caption of a random entry of a different image.
        img_id2 = self._sample_other_image_id(image_id)
        entry2 = self._entries[random.choice(self.imgid2entry[img_id2])]

        # Candidate 3 image: a different image, loaded into its own buffers
        # so none of image 1's regions leak into its padding.
        img_id3 = self._sample_other_image_id(image_id)
        image3 = self._load_padded_image(img_id3)

        # Candidate 4 text: caption of a hard-negative image.
        if self._split == "train":
            rand_img_id_pool = self.train_hard_pool[self.train_imgId2pool[image_id]]
            # Sampling starts at pool position 1, skipping position 0 —
            # presumably the query image itself; TODO confirm pool layout.
            pool_img_idx = int(
                rand_img_id_pool[np.random.randint(1, len(rand_img_id_pool))]
            )
            img_id4 = self.train_image_list[pool_img_idx]
        else:
            img_id4 = self._sample_other_image_id(image_id)
        entry4 = self._entries[random.choice(self.imgid2entry[img_id4])]

        # Image side of each candidate: (features, image_mask, spatials).
        images = (image1, image1, image3, image1)
        # Text side of each candidate: the entry whose caption fields are used.
        # Assumes "token"/"input_mask"/"segment_ids" are equal-length tensors.
        texts = (entry, entry2, entry, entry4)

        features = torch.stack([img[0] for img in images], dim=0)
        image_mask = torch.stack([img[1] for img in images], dim=0)
        spatials = torch.stack([img[2] for img in images], dim=0)
        caption = torch.stack([t["token"] for t in texts], dim=0)
        input_mask = torch.stack([t["input_mask"] for t in texts], dim=0)
        segment_ids = torch.stack([t["segment_ids"] for t in texts], dim=0)

        co_attention_mask = torch.zeros((4, self._max_region_num, self._max_seq_length))
        target = 0

        return (
            features,
            spatials,
            image_mask,
            caption,
            target,
            input_mask,
            segment_ids,
            co_attention_mask,
            image_id,
        )

    def _load_padded_image(self, image_id):
        """Load ``image_id``'s regions and pad them to ``self._max_region_num``.

        Reads ``self._image_features_reader[image_id]``, unpacked as
        ``(features, num_boxes, boxes, _)``. Keeps the first
        ``min(num_boxes, self._max_region_num)`` regions and zero-fills the
        remaining rows. The feature width 2048 and box width 5 are hard-coded;
        this assumes the reader yields arrays of those widths — TODO confirm.

        Returns:
            ``(features, image_mask, spatials)`` as tensors. ``features`` is a
            float tensor of shape ``(max_region_num, 2048)``. ``image_mask`` is
            a long tensor of length ``max_region_num``: 1 for kept regions, 0
            for padding. ``spatials`` is a float tensor of shape
            ``(max_region_num, 5)``.
        """
        features, num_boxes, boxes, _ = self._image_features_reader[image_id]
        num_kept = min(int(num_boxes), self._max_region_num)

        features_pad = np.zeros((self._max_region_num, 2048))
        boxes_pad = np.zeros((self._max_region_num, 5))
        features_pad[:num_kept] = features[:num_kept]
        boxes_pad[:num_kept] = boxes[:num_kept]

        # The mask length must match the padded buffers, so it uses the
        # clipped count rather than the raw num_boxes.
        image_mask = [1] * num_kept + [0] * (self._max_region_num - num_kept)

        return (
            torch.tensor(features_pad).float(),
            torch.tensor(image_mask).long(),
            torch.tensor(boxes_pad).float(),
        )

    def _sample_other_image_id(self, image_id):
        """Draw ids uniformly from ``self.image_id_list`` until one differs from ``image_id``.

        Note: this loops forever if ``self.image_id_list`` contains no id other
        than ``image_id``.
        """
        while True:
            candidate = random.choice(self.image_id_list)
            if candidate != image_id:
                return candidate