# src/model/guesser_vilbert.py
def forward_turn(
self,
q,
ans,
cats,
img_feats,
bboxs,
curr_state=None,
bboxs_mask=None,
token_type_ids=None,
attention_mask=None,
image_attention_mask=None,
co_attention_mask=None,
decode_mask=None,
task_ids=None,
output_all_encoded_layers=False,
output_all_attention_masks=False,
update_vilbert=True,
):
if update_vilbert:
seq_out_txt, seq_out_vis, pooled_out_txt, pooled_out_vis, _ = self.bert(
q,
img_feats,
bboxs,
token_type_ids,
attention_mask,
image_attention_mask,
co_attention_mask,
task_ids,
output_all_encoded_layers=output_all_encoded_layers,
output_all_attention_masks=output_all_attention_masks,
)
else:
with torch.no_grad():
# output of each token, output of each object, output of cls token, output of global vis
seq_out_txt, seq_out_vis, pooled_out_txt, pooled_out_vis, _ = self.bert(
q,
img_feats,
bboxs,
token_type_ids,
attention_mask,
image_attention_mask,
co_attention_mask,
task_ids,
output_all_encoded_layers=output_all_encoded_layers,
output_all_attention_masks=output_all_attention_masks,
)
ans = self.ans_embed(ans)
cats = self.cat_embed(cats) if self.use_category else None
if curr_state is None:
curr_state = self.init_state(
img_feats.size(0), img_feats.size(1), img_feats.device)
stat, logits = self.compute_next_state(
curr_state, seq_out_vis, pooled_out_txt, ans, cats, bboxs_mask)
return stat, logits