in vilbert/datasets/retreival_dataset.py [0:0]
def __getitem__(self, index):
    """Build one 4-way retrieval example for the entry at ``index``.

    The four stacked candidates are, in order:

    0. the entry's own image with its own caption (the aligned pair),
    1. the entry's image with a caption drawn from a random other image,
    2. a random other image with the entry's caption,
    3. the entry's image with a caption drawn from another image; on the
       ``"train"`` split that image is taken from the hard-negative pool
       of the entry's image, otherwise it is sampled at random.

    Args:
        index: Position of the entry in ``self._entries``.

    Returns:
        A tuple ``(features, spatials, image_mask, caption, target,
        input_mask, segment_ids, co_attention_mask, image_id)`` where the
        first three and the three text tensors are stacked along a new
        leading dimension of size 4, ``target`` is ``0`` (the position of
        the aligned pair), ``co_attention_mask`` is an all-zero tensor of
        shape ``(4, max_region_num, max_seq_length)`` and ``image_id`` is
        the entry's image id.
    """
    entry = self._entries[index]
    image_id = entry["image_id"]
    max_regions = self._max_region_num

    def load_regions(img_id):
        """Return (features, spatials, mask) for ``img_id``, padded/truncated to ``max_regions`` rows."""
        raw_features, num_boxes, raw_boxes, _ = self._image_features_reader[img_id]
        kept = min(int(num_boxes), max_regions)

        # Each image gets its own zero-initialised buffers so that no rows
        # from a previously loaded image can leak into the padding area.
        boxes_pad = np.zeros((max_regions, 5))
        features_pad = np.zeros((max_regions, 2048))
        boxes_pad[:kept] = raw_boxes[:kept]
        features_pad[:kept] = raw_features[:kept]

        # The mask is built from the truncated count so that every
        # candidate's mask has exactly ``max_regions`` entries.
        mask = [1] * kept + [0] * (max_regions - kept)

        return (
            torch.tensor(features_pad).float(),
            torch.tensor(boxes_pad).float(),
            torch.tensor(mask).long(),
        )

    def sample_other_image():
        """Return a random image id from ``self.image_id_list`` that differs from ``image_id``."""
        while True:
            candidate = random.choice(self.image_id_list)
            if candidate != image_id:
                return candidate

    # Candidate 1: the aligned image/caption pair.
    features1, spatials1, image_mask1 = load_regions(image_id)
    caption1 = entry["token"]
    input_mask1 = entry["input_mask"]
    segment_ids1 = entry["segment_ids"]

    # negative samples.
    # 1: correct one, 2: random caption wrong, 3: random image wrong. 4: hard image wrong.

    # Candidate 2: same image, caption of a random other image.
    img_id2 = sample_other_image()
    entry2 = self._entries[random.choice(self.imgid2entry[img_id2])]
    features2 = features1
    image_mask2 = image_mask1
    spatials2 = spatials1
    caption2 = entry2["token"]
    input_mask2 = entry2["input_mask"]
    segment_ids2 = entry2["segment_ids"]

    # Candidate 3: random other image, same caption.
    img_id3 = sample_other_image()
    features3, spatials3, image_mask3 = load_regions(img_id3)
    caption3 = caption1
    input_mask3 = input_mask1
    segment_ids3 = segment_ids1

    # Candidate 4: same image, caption of a hard (train) or random image.
    if self._split == "train":
        rand_img_id_pool = self.train_hard_pool[self.train_imgId2pool[image_id]]
        # Sampling starts at position 1; presumably position 0 of the pool
        # is the query image itself -- TODO confirm against pool creation.
        pool_img_idx = int(
            rand_img_id_pool[np.random.randint(1, len(rand_img_id_pool))]
        )
        img_id4 = self.train_image_list[pool_img_idx]
    else:
        img_id4 = sample_other_image()

    entry4 = self._entries[random.choice(self.imgid2entry[img_id4])]
    features4 = features1
    image_mask4 = image_mask1
    spatials4 = spatials1
    caption4 = entry4["token"]
    input_mask4 = entry4["input_mask"]
    segment_ids4 = entry4["segment_ids"]

    features = torch.stack([features1, features2, features3, features4], dim=0)
    spatials = torch.stack([spatials1, spatials2, spatials3, spatials4], dim=0)
    image_mask = torch.stack(
        [image_mask1, image_mask2, image_mask3, image_mask4], dim=0
    )
    caption = torch.stack([caption1, caption2, caption3, caption4], dim=0)
    input_mask = torch.stack(
        [input_mask1, input_mask2, input_mask3, input_mask4], dim=0
    )
    segment_ids = torch.stack(
        [segment_ids1, segment_ids2, segment_ids3, segment_ids4], dim=0
    )
    co_attention_mask = torch.zeros((4, max_regions, self._max_seq_length))

    # The aligned pair is always stacked first.
    target = 0

    return (
        features,
        spatials,
        image_mask,
        caption,
        target,
        input_mask,
        segment_ids,
        co_attention_mask,
        image_id,
    )