def play()

in src/model/self_play_qgen_vdst_oracle_vilbert.py [0:0]


    def play(
        self,
        obj_feats,
        image_features_rcnn_oracle,
        bboxs_rcnn_oracle,
        tgt_cat,
        tgt_bbox,
        tgt_img_feat,
        cats,
        bboxs_gt_gw,
        bboxs_mask,
        sos_token,
        pad_token,
        eoq_token,
        eod_token,
        answer2id,
        answer2token,
        max_q_len,
        greedy=True,
        max_turns=8
        ):
        """Run one full self-play game and return the guesser's prediction.

        Alternates between the question generator (``self.qgen``) and the
        ViLBERT oracle (``self.oracle``) for up to ``max_turns`` turns,
        accumulating the generated question/answer tokens per sample, then
        feeds the padded dialogs to ``self.guesser``.

        Args:
            obj_feats: object features; ``size(0)`` is the batch size and
                ``size(1)`` the number of candidate boxes (full shape assumed
                ``(batch, num_bboxs, feat_dim)`` -- confirm with caller).
            image_features_rcnn_oracle, bboxs_rcnn_oracle: RCNN visual inputs
                forwarded verbatim to the oracle.
            tgt_cat, tgt_bbox, tgt_img_feat: target-object inputs for the oracle.
            cats, bboxs_gt_gw, bboxs_mask: inputs forwarded to the guesser.
            sos_token, pad_token, eoq_token, eod_token: special token ids.
            answer2id, answer2token: mappings used to convert oracle logits
                into answer tokens via ``oracle_output_to_answer_token``.
            max_q_len: maximum length of each generated question.
            greedy: if True, the generator decodes greedily.
            max_turns: maximum number of question/answer rounds.

        Returns:
            Tuple ``(guess, dialog, q_log, a_log, a_conf_log)`` where ``guess``
            is the guesser output, ``dialog`` is a list of per-sample token
            tensors, and the three logs hold per-turn questions, answers and
            answer confidences for each sample.
        """
        device = obj_feats.device
        batch_size = obj_feats.size(0)
        num_bboxs = obj_feats.size(1)
        # Per-sample flag set once the generator emits the end-of-dialog token.
        end_of_dialog = torch.zeros(batch_size).bool().to(device)
        # The generator is primed with <sos> on the first turn.
        last_wrd = torch.zeros(batch_size).fill_(sos_token).long().to(device)
        sos = torch.zeros(batch_size, 1).fill_(sos_token).long().to(device)
        last_state = None
        # Uniform initial belief over the candidate objects.
        pi = (torch.ones(batch_size, num_bboxs) / num_bboxs).to(device)

        dialog = [torch.LongTensor(0).to(device) for _ in range(batch_size)]
        q_log = [[] for _ in range(batch_size)]
        a_log = [[] for _ in range(batch_size)]
        a_conf_log = [[] for _ in range(batch_size)]
        for turn in range(max_turns):
            q, q_len, state, obj_repr, end_of_dialog_next = self.qgen.generate_sentence(
                last_wrd, obj_feats, eoq_token, eod_token, end_of_dialog,
                max_q_len=max_q_len, pi=pi, last_state=last_state, greedy=greedy
            )

            # Pad the variable-length questions and prepend <sos> for the oracle.
            pad_q = pad_sequence(q, batch_first=True, padding_value=pad_token)
            pad_q = torch.cat([sos, pad_q], dim=-1)
            # Text attention mask for the oracle ViLBERT.
            # +1 : [CLS] token
            # NOTE(review): pad_q also carries the prepended <sos> column, yet
            # only q_len positions are marked valid -- possible off-by-one;
            # confirm against the oracle's expected input layout.
            txt_attn_mask = [[1] * ql + [0] * (q_len.max().item() + 1 - ql) for ql in q_len]
            txt_attn_mask = torch.tensor(txt_attn_mask).to(device)
            a = self.oracle(
                pad_q,
                tgt_cat,
                tgt_bbox,
                tgt_img_feat,
                bboxs_rcnn_oracle,
                image_features_rcnn_oracle,
                attention_mask=txt_attn_mask
                )
            a_confidence = nn.functional.softmax(a, dim=-1)
            a_idx = a.argmax(dim=-1)
            # Map oracle class indices to the generator's answer token ids.
            a = oracle_output_to_answer_token(a_idx, answer2id, answer2token)
            for b in range(batch_size):
                # Log the question only for samples still in dialog this turn.
                if not end_of_dialog[b]:
                    _q = q[b][:q_len[b]]
                    q_log[b].append(_q)
                    dialog[b] = torch.cat([dialog[b], _q])
                # Log the answer unless the dialog ends with this question
                # (the final end-of-dialog question receives no answer).
                if not end_of_dialog_next[b]:
                    _a = a[b].view(-1)
                    a_log[b].append(_a)
                    a_conf_log[b].append(a_confidence[b, a_idx[b]])
                    dialog[b] = torch.cat([dialog[b], _a])

            # Stop early once every sample in the batch has finished.
            if end_of_dialog_next.sum().item() == batch_size:
                break
            end_of_dialog = end_of_dialog_next
            last_wrd = a
            last_state = state
            # Refresh the belief over objects given the new answer token.
            pi = self.qgen.refresh_pi(pi, a, last_state[0, 0], obj_repr, input_token=True)

        dial_len = torch.LongTensor([len(dial) for dial in dialog]).to(device)
        dial_pad = pad_sequence(dialog, batch_first=True, padding_value=pad_token)
        guess = self.guesser(dial_pad, dial_len, cats, bboxs_gt_gw, bboxs_mask)
        return guess, dialog, q_log, a_log, a_conf_log