in src/model/self_play_qgen_vdst_oracle_vilbert.py
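
# NOTE: this excerpt assumes the file's module-level imports (defined elsewhere
# in the file): torch, torch.nn as nn, torch.nn.utils.rnn.pad_sequence, and the
# repository's oracle_output_to_answer_token helper.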
def play_with_gt_questions(
    self,
    qs,
    q_len,
    image_features_rcnn_oracle,
    bboxs_rcnn_oracle,
    tgt_cat,
    tgt_bbox,
    tgt_img_feat,
    cats,
    bboxs_gt_gw,
    bboxs_mask,
    sos_token,
    pad_token,
    eoq_token,
    eod_token,
    answer2id,
    answer2token,
):
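    """Run self-play using ground-truth questions instead of generated ones.

    Each turn, the ground-truth question is fed to the ViLBERT oracle to obtain
    an answer, and the question/answer pair is appended to every dialog that has
    not yet ended. After the rollout, the guesser is run on the padded dialogs.

    Returns the guesser output together with, per example, the flat dialog token
    sequence and the per-turn question, answer, and answer-confidence logs.
    """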
    device = qs.device
    batch_size = qs.size(0)
    max_turns = qs.size(1)
    end_of_dialog = torch.zeros(batch_size).bool().to(device)
    sos = torch.zeros(batch_size, 1).fill_(sos_token).long().to(device)
    dialog = [torch.LongTensor(0).to(device) for _ in range(batch_size)]
    q_log = [[] for _ in range(batch_size)]
    a_log = [[] for _ in range(batch_size)]
    a_conf_log = [[] for _ in range(batch_size)]
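    # Roll the dialog out turn by turn, using the ground-truth question for each turn.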
    for turn in range(max_turns):
        q_t = qs[:, turn]
        q_len_t = q_len[:, turn]
        # HACK: a length of 0 cannot be passed through the RNN, so dialogs that
        # have ended are flagged and given a fake length of 1.
        end_of_dialog = q_len_t == 0
        fake_q_len = q_len_t.clone()
        fake_q_len[fake_q_len == 0] = 1
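        # Prepend the SOS token (used as the [CLS] position by the ViLBERT oracle).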
        oracle_input_q = torch.cat([sos, q_t], dim=-1)
        # Text attention mask for the ViLBERT oracle; the +1 accounts for the
        # prepended [CLS]/SOS token.
        txt_attn_mask = [[1] * ql + [0] * (qs.size(-1) + 1 - ql) for ql in q_len_t]
        txt_attn_mask = torch.tensor(txt_attn_mask).to(device)
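        # Query the oracle with the question, the target object, and the R-CNN image features.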
        a = self.oracle(
            oracle_input_q,
            tgt_cat,
            tgt_bbox,
            tgt_img_feat,
            bboxs_rcnn_oracle,
            image_features_rcnn_oracle,
            fake_q_len,
            attention_mask=txt_attn_mask,
        )
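        # Convert the oracle logits into an answer token and keep its confidence.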
        a_confidence = nn.functional.softmax(a, dim=-1)
        a_idx = a.argmax(dim=-1)
        a = oracle_output_to_answer_token(a_idx, answer2id, answer2token)
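        # Append this turn's question and answer to every dialog that has not ended.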
        for b in range(batch_size):
            if not end_of_dialog[b]:
                _q = q_t[b][:q_len_t[b]]
                q_log[b].append(_q)
                dialog[b] = torch.cat([dialog[b], _q])
                _a = a[b].view(-1)
                a_log[b].append(_a)
                a_conf_log[b].append(a_confidence[b, a_idx[b]])
                dialog[b] = torch.cat([dialog[b], _a])
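        # Stop early once every dialog in the batch has ended.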
        if end_of_dialog.sum() == batch_size:
            break
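    # Pad the completed dialogs and let the guesser pick the target object.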
    dial_len = torch.LongTensor([len(dial) for dial in dialog]).to(device)
    dial_pad = pad_sequence(dialog, batch_first=True, padding_value=pad_token)
    guess = self.guesser(dial_pad, dial_len, cats, bboxs_gt_gw, bboxs_mask)
    return guess, dialog, q_log, a_log, a_conf_log
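

# Minimal invocation sketch (hypothetical caller name; shapes inferred from the
# code above, for illustration only):
#   guess, dialog, q_log, a_log, a_conf_log = model.play_with_gt_questions(
#       qs,      # (batch, max_turns, max_q_len) ground-truth question tokens
#       q_len,   # (batch, max_turns) per-turn question lengths; 0 marks "no more turns"
#       ...,     # remaining oracle/guesser inputs as listed in the signature
#   )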