def forward_turn()

in src/model/guesser_vilbert.py [0:0]


    def forward_turn(
        self, 
        q, 
        ans, 
        cats, 
        img_feats,
        bboxs, 
        curr_state=None,
        bboxs_mask=None,
        token_type_ids=None,
        attention_mask=None,
        image_attention_mask=None,
        co_attention_mask=None,
        decode_mask=None,
        task_ids=None,
        output_all_encoded_layers=False,
        output_all_attention_masks=False,
        update_vilbert=True,
        ):
        """Run one turn: encode with ``self.bert``, then advance the state.

        Args:
            q: Text input forwarded to ``self.bert`` as its first argument.
            ans: Answer input; passed through ``self.ans_embed``.
            cats: Category input; passed through ``self.cat_embed`` when
                ``self.use_category`` is truthy, otherwise replaced by ``None``.
            img_feats: Visual features forwarded to ``self.bert``. Its first
                two sizes and its device are used to build the initial state
                when ``curr_state`` is ``None``.
            bboxs: Box input forwarded to ``self.bert``.
            curr_state: State from the previous turn. When ``None`` a fresh one
                is created with ``self.init_state``.
            bboxs_mask: Forwarded unchanged to ``self.compute_next_state``.
            token_type_ids, attention_mask, image_attention_mask,
            co_attention_mask, task_ids: Forwarded positionally to
                ``self.bert``.
            decode_mask: Accepted for interface compatibility; not read by
                this method.
            output_all_encoded_layers, output_all_attention_masks: Forwarded
                to ``self.bert`` as keyword arguments.
            update_vilbert: When falsy, the ``self.bert`` call runs under
                ``torch.no_grad()``; when truthy the ambient grad mode is
                left as is.

        Returns:
            The ``(state, logits)`` pair produced by
            ``self.compute_next_state``.
        """
        # Assemble the encoder call once so both grad modes share it.
        bert_args = (
            q,
            img_feats,
            bboxs,
            token_type_ids,
            attention_mask,
            image_attention_mask,
            co_attention_mask,
            task_ids,
        )
        bert_kwargs = {
            "output_all_encoded_layers": output_all_encoded_layers,
            "output_all_attention_masks": output_all_attention_masks,
        }

        if not update_vilbert:
            # Encoder is frozen for this call: skip autograd bookkeeping.
            with torch.no_grad():
                encoded = self.bert(*bert_args, **bert_kwargs)
        else:
            encoded = self.bert(*bert_args, **bert_kwargs)

        # Per the original note, the five outputs are: output of each token,
        # output of each object, output of cls token, output of global vis,
        # plus one extra value. Only the per-object and cls outputs are used.
        _, seq_out_vis, pooled_out_txt, _, _ = encoded

        ans = self.ans_embed(ans)
        if self.use_category:
            cats = self.cat_embed(cats)
        else:
            cats = None

        if curr_state is None:
            dim0, dim1 = img_feats.size(0), img_feats.size(1)
            curr_state = self.init_state(dim0, dim1, img_feats.device)

        return self.compute_next_state(
            curr_state, seq_out_vis, pooled_out_txt, ans, cats, bboxs_mask)