src/model/self_play_qgen_vdst_guesser_vilbert.py [31:130]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        getattr(self, player).load_state_dict(
            torch.load(path, map_location=map_location)['model']
        )
        return "Load %s from %s" % (player, path)


    def play_with_gt_questions(
        self,
        qs,
        q_len,
        image_features_rcnn_oracle,
        bboxs_rcnn_oracle,
        image_features_rcnn_gt_guesser,
        bboxs_rcnn_gt_guesser,
        tgt_cat,
        tgt_bbox,
        tgt_img_feat,
        cats_guesser,
        bboxs_mask,
        sos_token,
        pad_token,
        eoq_token,
        eod_token,
        answer2id,
        answer2token,
        ):
        """Replay ground-truth questions through the oracle and the guesser.

        For every turn, the question is prefixed with ``sos_token``, answered
        by ``self.oracle``, and the (question, predicted answer) pair is fed
        to ``self.guesser.forward_turn`` to advance the guesser state.

        Args:
            qs: Question token ids, indexed as ``qs[:, turn]``; each turn slice
                is concatenated with a ``(batch, 1)`` start column, so ``qs``
                is ``(batch, max_turns, max_question_len)``.
            q_len: Integer question lengths, indexed as ``q_len[:, turn]``.
                A length of 0 marks "no question at this turn" for that
                example.
            image_features_rcnn_oracle, bboxs_rcnn_oracle: Forwarded untouched
                to ``self.oracle``.
            image_features_rcnn_gt_guesser, bboxs_rcnn_gt_guesser: Forwarded
                untouched to ``self.guesser.forward_turn``; ``size(1)`` of the
                features is also passed to ``self.guesser.init_state``.
            tgt_cat, tgt_bbox, tgt_img_feat: Forwarded untouched to
                ``self.oracle``.
            cats_guesser: Forwarded untouched to the guesser.
            bboxs_mask: Forwarded to the guesser both as-is and, cast to
                long, as ``image_attention_mask``.
            sos_token: Token id prepended to every question.
            pad_token, eoq_token, eod_token: Not used by this method; kept so
                the signature stays compatible with existing callers.
            answer2id, answer2token: Passed to
                ``oracle_output_to_answer_token`` to turn the oracle's argmax
                index into an answer token.

        Returns:
            Tuple ``(guess, dialog, q_log, a_log, a_conf_log)``:
            ``guess`` is the stored guesser logits with the first column
            dropped; ``dialog`` is one 1-D long tensor per example with the
            questions and answers concatenated in order; ``q_log``, ``a_log``
            and ``a_conf_log`` are per-example lists of the questions, the
            answer tokens and the softmax confidence of each chosen answer.
        """
        device = qs.device
        batch_size = qs.size(0)
        max_turns = qs.size(1)
        sos = torch.zeros(batch_size, 1).fill_(sos_token).long().to(device)

        dialog = [torch.LongTensor(0).to(device) for _ in range(batch_size)]
        q_log = [[] for _ in range(batch_size)]
        a_log = [[] for _ in range(batch_size)]
        a_conf_log = [[] for _ in range(batch_size)]

        guesser_state = self.guesser.init_state(
            batch_size, image_features_rcnn_gt_guesser.size(1), device)
        guesser_final_logits = torch.zeros_like(guesser_state)
        # Explicit bookkeeping of which examples already have their final
        # logits stored, instead of inferring it from the logits summing to 0.
        finalized = [False] * batch_size
        logits = None

        # Token positions of "[CLS] + question" (hence the +1), used to build
        # the text attention mask by broadcasting against the lengths.
        positions = torch.arange(qs.size(-1) + 1, device=device).unsqueeze(0)

        for turn in range(max_turns):
            q_t = qs[:, turn]
            q_len_t = q_len[:, turn]
            # HACK: length == 0 can not forward in RNN
            end_of_dialog = q_len_t == 0
            q_plus_cls_token = torch.cat([sos, q_t], dim=-1)

            # For oracle vilbert: 1 for the first `q_len` positions, 0 after.
            # NOTE(review): the mask has `q_len` ones although the sequence
            # carries an extra leading token; confirm whether `q_len` already
            # accounts for it or the last question token is masked out.
            txt_attn_mask = (
                positions < q_len_t.to(device).unsqueeze(1)
            ).long()

            answer_logits = self.oracle(
                q_plus_cls_token,
                tgt_cat,
                tgt_bbox,
                tgt_img_feat,
                bboxs_rcnn_oracle,
                image_features_rcnn_oracle,
                attention_mask=txt_attn_mask
                )
            a_confidence = nn.functional.softmax(answer_logits, dim=-1)
            a_idx = answer_logits.argmax(dim=-1)

            guesser_state, logits = self.guesser.forward_turn(
                q_plus_cls_token,
                a_idx,
                cats_guesser,
                image_features_rcnn_gt_guesser,
                bboxs_rcnn_gt_guesser,
                curr_state=guesser_state,
                bboxs_mask=bboxs_mask,
                attention_mask=txt_attn_mask,
                image_attention_mask=bboxs_mask.long(),
            )
            answer_tokens = oracle_output_to_answer_token(
                a_idx, answer2id, answer2token)

            # Pull the per-example flags/lengths to Python once per turn
            # rather than synchronising on every element access.
            ended = end_of_dialog.tolist()
            lengths = q_len_t.tolist()
            for b in range(batch_size):
                if ended[b]:
                    if not finalized[b]:
                        # First turn without a question: keep these logits.
                        guesser_final_logits[b] = logits[b]
                        finalized[b] = True
                    continue
                # QUESTION
                _q = q_t[b][:lengths[b]]
                q_log[b].append(_q)
                dialog[b] = torch.cat([dialog[b], _q])
                # ANSWER
                _a = answer_tokens[b].view(-1)
                a_log[b].append(_a)
                a_conf_log[b].append(a_confidence[b, a_idx[b]])
                dialog[b] = torch.cat([dialog[b], _a])
            if all(ended):
                break

        # Examples that had a question on every turn never hit the branch
        # above; give them the logits of the last executed turn instead of
        # leaving their row as all zeros.
        if logits is not None:
            for b in range(batch_size):
                if not finalized[b]:
                    guesser_final_logits[b] = logits[b]

        # First one is global image token
        guess = guesser_final_logits[:, 1:]
        return guess, dialog, q_log, a_log, a_conf_log

    def play(
        self, 
        obj_feats, 
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



src/model/self_play_qgen_vdst_oracle_vilbert_guesser_vilbert.py [44:143]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            getattr(self, player).load_state_dict(
                torch.load(path, map_location=map_location)['model']
            )
        return "Load %s from %s" % (player, path)


    def play_with_gt_questions(
        self,
        qs,
        q_len,
        image_features_rcnn_oracle,
        bboxs_rcnn_oracle,
        image_features_rcnn_gt_guesser,
        bboxs_rcnn_gt_guesser,
        tgt_cat,
        tgt_bbox,
        tgt_img_feat,
        cats_guesser,
        bboxs_mask,
        sos_token,
        pad_token,
        eoq_token,
        eod_token,
        answer2id,
        answer2token,
        ):
        """Replay ground-truth questions through the oracle and the guesser.

        For every turn, the question is prefixed with ``sos_token``, answered
        by ``self.oracle``, and the (question, predicted answer) pair is fed
        to ``self.guesser.forward_turn`` to advance the guesser state.

        Args:
            qs: Question token ids, indexed as ``qs[:, turn]``; each turn slice
                is concatenated with a ``(batch, 1)`` start column, so ``qs``
                is ``(batch, max_turns, max_question_len)``.
            q_len: Integer question lengths, indexed as ``q_len[:, turn]``.
                A length of 0 marks "no question at this turn" for that
                example.
            image_features_rcnn_oracle, bboxs_rcnn_oracle: Forwarded untouched
                to ``self.oracle``.
            image_features_rcnn_gt_guesser, bboxs_rcnn_gt_guesser: Forwarded
                untouched to ``self.guesser.forward_turn``; ``size(1)`` of the
                features is also passed to ``self.guesser.init_state``.
            tgt_cat, tgt_bbox, tgt_img_feat: Forwarded untouched to
                ``self.oracle``.
            cats_guesser: Forwarded untouched to the guesser.
            bboxs_mask: Forwarded to the guesser both as-is and, cast to
                long, as ``image_attention_mask``.
            sos_token: Token id prepended to every question.
            pad_token, eoq_token, eod_token: Not used by this method; kept so
                the signature stays compatible with existing callers.
            answer2id, answer2token: Passed to
                ``oracle_output_to_answer_token`` to turn the oracle's argmax
                index into an answer token.

        Returns:
            Tuple ``(guess, dialog, q_log, a_log, a_conf_log)``:
            ``guess`` is the stored guesser logits with the first column
            dropped; ``dialog`` is one 1-D long tensor per example with the
            questions and answers concatenated in order; ``q_log``, ``a_log``
            and ``a_conf_log`` are per-example lists of the questions, the
            answer tokens and the softmax confidence of each chosen answer.
        """
        device = qs.device
        batch_size = qs.size(0)
        max_turns = qs.size(1)
        sos = torch.zeros(batch_size, 1).fill_(sos_token).long().to(device)

        dialog = [torch.LongTensor(0).to(device) for _ in range(batch_size)]
        q_log = [[] for _ in range(batch_size)]
        a_log = [[] for _ in range(batch_size)]
        a_conf_log = [[] for _ in range(batch_size)]

        guesser_state = self.guesser.init_state(
            batch_size, image_features_rcnn_gt_guesser.size(1), device)
        guesser_final_logits = torch.zeros_like(guesser_state)
        # Explicit bookkeeping of which examples already have their final
        # logits stored, instead of inferring it from the logits summing to 0.
        finalized = [False] * batch_size
        logits = None

        # Token positions of "[CLS] + question" (hence the +1), used to build
        # the text attention mask by broadcasting against the lengths.
        positions = torch.arange(qs.size(-1) + 1, device=device).unsqueeze(0)

        for turn in range(max_turns):
            q_t = qs[:, turn]
            q_len_t = q_len[:, turn]
            # HACK: length == 0 can not forward in RNN
            end_of_dialog = q_len_t == 0
            q_plus_cls_token = torch.cat([sos, q_t], dim=-1)

            # For oracle vilbert: 1 for the first `q_len` positions, 0 after.
            # NOTE(review): the mask has `q_len` ones although the sequence
            # carries an extra leading token; confirm whether `q_len` already
            # accounts for it or the last question token is masked out.
            txt_attn_mask = (
                positions < q_len_t.to(device).unsqueeze(1)
            ).long()

            answer_logits = self.oracle(
                q_plus_cls_token,
                tgt_cat,
                tgt_bbox,
                tgt_img_feat,
                bboxs_rcnn_oracle,
                image_features_rcnn_oracle,
                attention_mask=txt_attn_mask
                )
            a_confidence = nn.functional.softmax(answer_logits, dim=-1)
            a_idx = answer_logits.argmax(dim=-1)

            guesser_state, logits = self.guesser.forward_turn(
                q_plus_cls_token,
                a_idx,
                cats_guesser,
                image_features_rcnn_gt_guesser,
                bboxs_rcnn_gt_guesser,
                curr_state=guesser_state,
                bboxs_mask=bboxs_mask,
                attention_mask=txt_attn_mask,
                image_attention_mask=bboxs_mask.long(),
            )
            answer_tokens = oracle_output_to_answer_token(
                a_idx, answer2id, answer2token)

            # Pull the per-example flags/lengths to Python once per turn
            # rather than synchronising on every element access.
            ended = end_of_dialog.tolist()
            lengths = q_len_t.tolist()
            for b in range(batch_size):
                if ended[b]:
                    if not finalized[b]:
                        # First turn without a question: keep these logits.
                        guesser_final_logits[b] = logits[b]
                        finalized[b] = True
                    continue
                # QUESTION
                _q = q_t[b][:lengths[b]]
                q_log[b].append(_q)
                dialog[b] = torch.cat([dialog[b], _q])
                # ANSWER
                _a = answer_tokens[b].view(-1)
                a_log[b].append(_a)
                a_conf_log[b].append(a_confidence[b, a_idx[b]])
                dialog[b] = torch.cat([dialog[b], _a])
            if all(ended):
                break

        # Examples that had a question on every turn never hit the branch
        # above; give them the logits of the last executed turn instead of
        # leaving their row as all zeros.
        if logits is not None:
            for b in range(batch_size):
                if not finalized[b]:
                    guesser_final_logits[b] = logits[b]

        # First one is global image token
        guess = guesser_final_logits[:, 1:]
        return guess, dialog, q_log, a_log, a_conf_log

    def play(
        self, 
        obj_feats, 
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



