in src/model/self_play_all_vilbert.py [0:0]
def play(
    self,
    qgen_img_feats,
    qgen_bboxs,
    image_features_rcnn_oracle,
    bboxs_rcnn_oracle,
    image_features_rcnn_gt_guesser,
    bboxs_rcnn_gt_guesser,
    tgt_cat,
    tgt_bbox,
    tgt_img_feat,
    cats_guesser,
    bboxs_mask,
    sos_token,
    pad_token,
    eoq_token,
    eod_token,
    answer2id,
    answer2token,
    max_q_len,
    greedy=True,
    max_turns=8
):
    """Run a full multi-turn self-play dialog between QGen, Oracle and Guesser.

    Each turn: QGen generates a question, the Oracle answers it, the Guesser
    and the QGen state handler consume the (question, answer) pair to update
    their per-object states. The per-sample guesser logits are frozen at the
    turn in which that sample's dialog ends.

    Args:
        qgen_img_feats: visual features fed to QGen; assumes shape
            (batch, num_bboxs, feat_dim) — batch/num_bboxs are read from
            dims 0 and 1 below; TODO confirm feat_dim semantics.
        qgen_bboxs: box coordinates paired with ``qgen_img_feats``.
        image_features_rcnn_oracle / bboxs_rcnn_oracle: RCNN visual inputs
            for the Oracle.
        image_features_rcnn_gt_guesser / bboxs_rcnn_gt_guesser: visual
            inputs for the Guesser; dim 1 sizes the guesser state.
        tgt_cat, tgt_bbox, tgt_img_feat: target-object category/box/feature
            given to the Oracle so it can answer truthfully.
        cats_guesser: candidate-object categories given to the Guesser.
        bboxs_mask: boolean mask over candidate boxes (also used as the
            guesser's image attention mask after cast to long).
        sos_token, pad_token, eoq_token, eod_token: special token ids
            (start-of-sentence / padding / end-of-question / end-of-dialog).
        answer2id, answer2token: mappings used to turn the Oracle's class
            output into a dialog token.
        max_q_len: maximum generated question length per turn.
        greedy: if True, QGen decodes greedily (passed through to
            ``generate_sentence``).
        max_turns: maximum number of question/answer turns.

    Returns:
        (guess, dialog, q_log, a_log, a_conf_log) where ``guess`` is the
        frozen guesser logits with the first (global-image) slot dropped,
        ``dialog`` is the per-sample concatenated token sequence, and the
        three logs hold per-turn questions, answers and answer confidences.
    """
    device = qgen_img_feats.device
    batch_size = qgen_img_feats.size(0)
    num_bboxs = qgen_img_feats.size(1)
    # Per-sample flag: True once a sample's dialog has terminated.
    end_of_dialog = torch.zeros(batch_size).bool().to(device)
    # QGen decoding starts from the start-of-sentence token.
    last_wrd = torch.zeros(batch_size).fill_(sos_token).long().to(device)
    # Column of SOS tokens prepended to each padded question as the
    # [CLS]-position token (see the attention-mask comment below).
    sos = torch.zeros(batch_size, 1).fill_(sos_token).long().to(device)
    last_state = None
    # pi = (torch.ones(batch_size, num_bboxs) / num_bboxs).to(device)
    # pi: QGen's per-object belief state, threaded through the turns.
    pi = self.qgen.state_handler.init_state(batch_size, num_bboxs, device)
    dialog = [torch.LongTensor(0).to(device) for _ in range(batch_size)]
    q_log = [[] for _ in range(batch_size)]
    a_log = [[] for _ in range(batch_size)]
    a_conf_log = [[] for _ in range(batch_size)]
    guesser_state = self.guesser.init_state(
        batch_size, image_features_rcnn_gt_guesser.size(1), device)
    # guesser_final_logits[b] stays all-zero until sample b's dialog ends,
    # at which point the previous turn's logits are frozen into it.
    # NOTE(review): this uses "all zeros" as the not-yet-set sentinel —
    # assumes real logits never sum to exactly 0.
    guesser_final_logits = torch.zeros_like(guesser_state)
    logits = torch.ones_like(guesser_final_logits)
    for turn in range(max_turns):
        # Generate one question per sample; end_of_dialog_next marks the
        # samples whose dialog terminates with this generation.
        q, q_len, state, end_of_dialog_next = self.qgen.generate_sentence(
            last_wrd, qgen_img_feats, qgen_bboxs, eoq_token, eod_token, end_of_dialog,
            max_q_len=max_q_len, pi=pi, last_state=last_state, greedy=greedy
        )
        pad_q = pad_sequence(q, batch_first=True, padding_value=pad_token)
        q_plus_cls_token = torch.cat([sos, pad_q], dim=-1)
        # For oracle vilbert
        # +1 : [CLS] token
        txt_attn_mask = [[1] * (ql+1) + [0] * (q_len.max().item() - ql) for ql in q_len]
        txt_attn_mask = torch.tensor(txt_attn_mask).to(device)
        # a = self.oracle(pad_q, tgt_cat, tgt_bbox, tgt_img_feat, fake_q_len)
        # Oracle answers the question given the ground-truth target object.
        a = self.oracle(
            q_plus_cls_token,
            tgt_cat,
            tgt_bbox,
            tgt_img_feat,
            bboxs_rcnn_oracle,
            image_features_rcnn_oracle,
            attention_mask=txt_attn_mask
        )
        a_confidence = nn.functional.softmax(a, dim=-1)
        a_idx = a.argmax(dim=-1)
        # Freeze guesser logits for samples ending this turn BEFORE the
        # guesser update below, i.e. keep the previous turn's logits.
        for b in range(batch_size):
            if end_of_dialog_next[b]:
                if guesser_final_logits[b].sum() == 0:
                    # last turn
                    guesser_final_logits[b] = logits[b]
        guesser_state, logits = self.guesser.forward_turn(
            q_plus_cls_token,
            a_idx,
            cats_guesser,
            image_features_rcnn_gt_guesser,
            bboxs_rcnn_gt_guesser,
            curr_state=guesser_state,
            bboxs_mask=bboxs_mask,
            attention_mask=txt_attn_mask,
            image_attention_mask=bboxs_mask.long(),
        )
        # Update QGen's belief state from the same (q, a) pair; the
        # category input is unused here and its ViLBERT is not updated.
        pi, _ = self.qgen.state_handler.forward_turn(
            q_plus_cls_token,
            a_idx,
            None, # cats
            qgen_img_feats,
            qgen_bboxs,
            curr_state=pi,
            attention_mask=txt_attn_mask,
            update_vilbert=False,
        )
        # Map the oracle's class index to a dialog vocabulary token.
        a = oracle_output_to_answer_token(a_idx, answer2id, answer2token)
        # Append this turn's question (and, if the dialog continues,
        # the answer) to the per-sample logs.
        for b in range(batch_size):
            if not end_of_dialog[b]:
                _q = q[b][:q_len[b]]
                q_log[b].append(_q)
                dialog[b] = torch.cat([dialog[b], _q])
                if not end_of_dialog_next[b]:
                    _a = a[b].view(-1)
                    a_log[b].append(_a)
                    a_conf_log[b].append(a_confidence[b, a_idx[b]])
                    dialog[b] = torch.cat([dialog[b], _a])
        # Stop early once every sample's dialog has ended.
        if end_of_dialog_next.sum().item() == batch_size:
            break
        end_of_dialog = end_of_dialog_next
        # The answer token seeds the next turn's generation.
        last_wrd = a
        last_state = state
        # pi = self.qgen.refresh_pi(pi, a, last_state[0,0], obj_repr, input_token=True)
    # Samples that never terminated within max_turns take the last
    # computed logits.
    for b in range(batch_size):
        if guesser_final_logits[b].sum() == 0:
            guesser_final_logits[b] = logits[b]
    # First one is global image token
    guess = guesser_final_logits[:, 1:]
    return guess, dialog, q_log, a_log, a_conf_log