in vilbert/datasets/visdial_dataset.py [0:0]
def __getitem__(self, index):
    entry = self._entries[index]
    image_id = entry["image_id"]
    features, num_boxes, boxes, _ = self._image_features_reader[image_id]

    image_mask = [1] * int(num_boxes)
    while len(image_mask) < self._max_region_num:
        image_mask.append(0)

    features = torch.tensor(features).float()
    image_mask = torch.tensor(image_mask).long()
    spatials = torch.tensor(boxes).float()
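    # NOTE: only image_mask is padded to self._max_region_num here; features
    # and boxes are assumed to come back from self._image_features_reader
    # already sized to the fixed region count (not checked in this method).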
    # Build the text inputs for every round of this dialog (VisDial pairs
    # each image with a caption and 10 question/answer rounds).
    caption = self._captions[entry["caption"]]
    input_ids_all = []
    input_mask_all = []
    segment_ids_all = []
    for rnd in range(10):
        ques = self._questions[entry["dialog"][rnd]["question"]]
        # The "fact" is the dialog history: the question/answer pairs from
        # earlier rounds, capped at the most recent self.max_round_num rounds.
        tokens_fact = []
        for j in range(rnd):
            if rnd - self.max_round_num <= j:
                fact_q = self._questions[entry["dialog"][j]["question"]]
                fact_a = self._answers[entry["dialog"][j]["answer"]]
                if len(tokens_fact) == 0:
                    tokens_fact = fact_q + [self.SEP] + fact_a
                else:
                    tokens_fact = (
                        tokens_fact + [self.SEP] + fact_q + [self.SEP] + fact_a
                    )

        token_q = ques
        # Append the caption after the history; in round 0 the history is
        # empty and the caption alone serves as the fact.
        if len(tokens_fact) == 0:
            tokens_f = caption
        else:
            tokens_f = tokens_fact + [self.SEP] + caption
        # The ground-truth answer always goes first; the remaining
        # max_num_option - 1 candidates are distinct negatives drawn from a
        # random permutation of the full option list.
        answer_candidate = [entry["dialog"][rnd]["gt_index"]]
        rand_idx = np.random.permutation(self.ans_option)
        count = 0
        while len(answer_candidate) < self.max_num_option:
            if rand_idx[count] != entry["dialog"][rnd]["gt_index"]:
                answer_candidate.append(rand_idx[count])
            count += 1
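        # Because the ground truth sits at index 0 of every candidate list,
        # the per-round classification target is always 0 (see the `target`
        # tensor built at the end of this method).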
        input_ids_rnd = []
        input_mask_rnd = []
        segment_ids_rnd = []
        for ans_idx in answer_candidate:
            tokens_a = self._answers[
                entry["dialog"][rnd]["answer_options"][ans_idx]
            ]
            tokens_f_new = self._truncate_seq(
                copy.deepcopy(tokens_f),
                self._total_seq_length - len(token_q) - len(tokens_a) - 4,
            )
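            # The "- 4" reserves room for the special tokens added below
            # (one [CLS] and three [SEP]s). _truncate_seq is defined
            # elsewhere in this class; it is assumed to drop history tokens
            # until the capped sequence fits the budget.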
            # Assemble: [CLS] question [SEP] answer [SEP] history [SEP],
            # with segment id 1 on the answer span and 0 everywhere else.
            tokens = (
                [self.CLS]
                + token_q
                + [self.SEP]
                + tokens_a
                + [self.SEP]
                + tokens_f_new
                + [self.SEP]
            )
            segment_ids = (
                [0] * (len(token_q) + 2)
                + [1] * (len(tokens_a) + 1)
                + [0] * (len(tokens_f_new) + 1)
            )
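            # Example (hypothetical 2-token question, 1-token answer,
            # 3-token history):
            #   tokens:      [CLS] q1 q2 [SEP] a1 [SEP] f1 f2 f3 [SEP]
            #   segment_ids:   0   0  0   0    1   1    0  0  0   0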
            input_mask = [1] * len(tokens)
            # Zero-pad token ids, mask, and segment ids up to the fixed
            # sequence length (the pad token id is assumed to be 0).
            while len(tokens) < self._total_seq_length:
                tokens.append(0)
                input_mask.append(0)
                segment_ids.append(0)
            input_ids_rnd.append(tokens)
            input_mask_rnd.append(input_mask)
            segment_ids_rnd.append(segment_ids)
        input_ids_all.append(input_ids_rnd)
        input_mask_all.append(input_mask_rnd)
        segment_ids_all.append(segment_ids_rnd)
    # Each text tensor has shape (10, max_num_option, _total_seq_length).
    input_ids = torch.from_numpy(np.array(input_ids_all))
    input_mask = torch.from_numpy(np.array(input_mask_all))
    segment_ids = torch.from_numpy(np.array(segment_ids_all))
    co_attention_mask = torch.zeros(
        (10, self.max_num_option, self._max_region_num, self._total_seq_length)
    )
    # All-zero targets: the ground-truth answer was placed at candidate
    # index 0 for every round.
    target = torch.zeros(10).long()

    return (
        features,
        spatials,
        image_mask,
        input_ids,
        target,
        input_mask,
        segment_ids,
        co_attention_mask,
        image_id,
    )
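# Usage sketch (illustrative, not part of the original file): pull one batch
# through a DataLoader to see the shapes produced by __getitem__. It assumes
# the enclosing class is named VisDialDataset, as the file name suggests; the
# constructor arguments are not shown in this excerpt, so they are elided.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = VisDialDataset(...)  # hypothetical: real args not shown here
    loader = DataLoader(dataset, batch_size=2)
    (features, spatials, image_mask, input_ids, target,
     input_mask, segment_ids, co_attention_mask, image_id) = next(iter(loader))
    # input_ids / input_mask / segment_ids:
    #     (batch, 10, max_num_option, _total_seq_length)
    # target: (batch, 10), all zeros -- the gt answer is always candidate 0.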