# vilbert/datasets/flickr_grounding_dataset.py
def __getitem__(self, index):
    entry = self.entries[index]
    image_id = entry["image_id"]
    ref_box = entry["refBox"]

    # Load the pre-extracted region features for this image.
    features, num_boxes, boxes, boxes_ori = self._image_features_reader[image_id]

    boxes_ori = boxes_ori[:num_boxes]
    boxes = boxes[:num_boxes]
    features = features[:num_boxes]

    if self.split == "train":
        gt_features, gt_num_boxes, gt_boxes, gt_boxes_ori = self._gt_image_features_reader[
            image_id
        ]

        # Merge the detected boxes with the ground-truth boxes and assign labels.
        # The first ground-truth entry corresponds to the whole image, so skip it.
        gt_boxes_ori = gt_boxes_ori[1:gt_num_boxes]
        gt_boxes = gt_boxes[1:gt_num_boxes]
        gt_features = gt_features[1:gt_num_boxes]

        # Concatenate the detected and ground-truth boxes.
        mix_boxes_ori = np.concatenate((boxes_ori, gt_boxes_ori), axis=0)
        mix_boxes = np.concatenate((boxes, gt_boxes), axis=0)
        mix_features = np.concatenate((features, gt_features), axis=0)
        mix_num_boxes = min(
            int(num_boxes) + int(gt_num_boxes) - 1, self.max_region_num
        )

        # Given the merged boxes and the reference box, compute the IoU overlap;
        # regions with IoU below 0.5 are treated as negatives.
        mix_target = iou(
            torch.tensor(mix_boxes_ori[:, :4]).float(),
            torch.tensor([ref_box]).float(),
        )
        mix_target[mix_target < 0.5] = 0
    else:
        mix_boxes_ori = boxes_ori
        mix_boxes = boxes
        mix_features = features
        mix_num_boxes = min(int(num_boxes), self.max_region_num)

        mix_target = iou(
            torch.tensor(mix_boxes_ori[:, :4]).float(),
            torch.tensor([ref_box]).float(),
        )

    # Build the image mask and pad boxes/features up to max_region_num.
    image_mask = [1] * (mix_num_boxes)
    while len(image_mask) < self.max_region_num:
        image_mask.append(0)

    mix_boxes_pad = np.zeros((self.max_region_num, 5))
    mix_features_pad = np.zeros((self.max_region_num, 2048))

    mix_boxes_pad[:mix_num_boxes] = mix_boxes[:mix_num_boxes]
    mix_features_pad[:mix_num_boxes] = mix_features[:mix_num_boxes]

    # Convert everything to tensors; the target holds the IoU of each kept
    # region with the reference box.
    features = torch.tensor(mix_features_pad).float()
    image_mask = torch.tensor(image_mask).long()
    spatials = torch.tensor(mix_boxes_pad).float()

    target = torch.zeros((self.max_region_num, 1)).float()
    target[:mix_num_boxes] = mix_target[:mix_num_boxes]

    spatials_ori = torch.tensor(mix_boxes_ori).float()
    co_attention_mask = torch.zeros((self.max_region_num, self._max_seq_length))

    caption = entry["token"]
    input_mask = entry["input_mask"]
    segment_ids = entry["segment_ids"]

    return (
        features,
        spatials,
        image_mask,
        caption,
        target,
        input_mask,
        segment_ids,
        co_attention_mask,
        image_id,
    )
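

# NOTE (illustrative sketch, not part of this file): the `iou` helper used above is
# assumed to return the pairwise IoU matrix between N candidate boxes and K reference
# boxes, so that `mix_target` has shape (num_regions, 1) when a single refBox is passed.
# A minimal PyTorch version under that assumption (torch assumed imported at module top)
# could look like this; the real helper lives elsewhere in the repo.
def _iou_sketch(anchors, gt_boxes):
    """anchors: (N, 4), gt_boxes: (K, 4), both in (x1, y1, x2, y2); returns (N, K) IoU."""
    N, K = anchors.size(0), gt_boxes.size(0)
    gt_area = (
        (gt_boxes[:, 2] - gt_boxes[:, 0] + 1) * (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)
    ).view(1, K)
    anchor_area = (
        (anchors[:, 2] - anchors[:, 0] + 1) * (anchors[:, 3] - anchors[:, 1] + 1)
    ).view(N, 1)

    # Broadcast every anchor against every reference box.
    boxes = anchors.view(N, 1, 4).expand(N, K, 4)
    query = gt_boxes.view(1, K, 4).expand(N, K, 4)

    # Width/height of the intersection rectangles, clamped at zero for disjoint boxes.
    iw = (
        torch.min(boxes[:, :, 2], query[:, :, 2])
        - torch.max(boxes[:, :, 0], query[:, :, 0])
        + 1
    ).clamp(min=0)
    ih = (
        torch.min(boxes[:, :, 3], query[:, :, 3])
        - torch.max(boxes[:, :, 1], query[:, :, 1])
        + 1
    ).clamp(min=0)
    inter = iw * ih

    # IoU = intersection / union.
    return inter / (anchor_area + gt_area - inter)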