src/data/self_play_all_vilbert.py [105:231]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        feats, _, bboxs, _ = self._image_features_reader['oracle'][image_id]
        image_features_rcnn_oracle = torch.from_numpy(np.array(feats))
        bboxs_rcnn_oracle = torch.from_numpy(np.array(bboxs))
        # bboxs_rcnn_oracle = torch.from_numpy(np.array([
        #     bbox2spatial_vilbert(box, game.image_width, game.image_height, mode='xyxy') 
        #     for box in bboxs]))
        # gt
        feats, bboxs, _ = self._image_features_reader_gt[image_id]
        # image_features_rcnn_gt = torch.from_numpy(np.array(feats))
        image_features_rcnn_gt = torch.from_numpy(np.array(feats))
        bboxs_gt_gw = entry['bboxs_gt_gw']
        bboxs_gt_vb = entry['bboxs_gt_vb']

        tgt_index = entry['target_index']
        cats = entry['categories']
        # bboxs = entry['bboxs']
        
        tgt_cat = cats[tgt_index]
        tgt_bbox_gw = bboxs_gt_gw[tgt_index]
        tgt_bbox_vb = bboxs_gt_vb[tgt_index]
        tgt_img_feat = image_features_rcnn_gt[tgt_index]
        label = torch.LongTensor([tgt_index])

        # Add global information (visual side)
        image_features_rcnn_gt_guesser, bboxs_rcnn_gt_guesser = add_global_vilbert_feats(
            image_features_rcnn_gt.float(), bboxs_gt_vb.float(), input_torch=True)
        cats_guesser = torch.cat([torch.LongTensor([99]) , cats], dim=0)
        qs = entry['qs']
        # q_len = entry['q_len']

        return (
            game,
            image_features_rcnn_qgen,
            bboxs_rcnn_qgen,
            image_features_rcnn_oracle,
            bboxs_rcnn_oracle,
            image_features_rcnn_gt_guesser,
            bboxs_rcnn_gt_guesser,
            # bboxs_gt_gw,
            # bboxs_gt_vb,
            tgt_img_feat,
            # tgt_bbox_gw,
            tgt_bbox_vb,
            tgt_cat,
            cats_guesser,
            label,
            qs,
            # q_len
        )


def collate_fn(batch, wrd_pad_id):
    """Collate self-play samples into batched tensors.

    Args:
        batch: list of 13-tuples as returned by the dataset's ``__getitem__``:
            (game, image_features_rcnn_qgen, bboxs_rcnn_qgen,
            image_features_rcnn_oracle, bboxs_rcnn_oracle,
            image_features_rcnn_gt_guesser, bboxs_rcnn_gt_guesser,
            tgt_img_feat, tgt_bbox_vb, tgt_cat, cats_guesser, label, qs).
            ``qs`` is a list of turns, each a list of token ids; the number
            of turns and the question lengths may differ between samples.
        wrd_pad_id: token id used to pad questions.

    Returns:
        A 15-tuple (game, image_features_rcnn_qgen, bboxs_rcnn_qgen,
        image_features_rcnn_oracle, bboxs_rcnn_oracle,
        image_features_rcnn_gt_guesser, bboxs_rcnn_gt_guesser, tgt_img_feat,
        tgt_bbox_vb, tgt_cat, cats_guesser, bboxs_mask, label, qs, q_len).
        ``game`` is the tuple of per-sample game objects. ``qs`` is a
        LongTensor (batch_size, max_num_turns, max_q_seq_len) and ``q_len``
        a LongTensor (batch_size, max_num_turns) with the unpadded length of
        each question (0 for padding turns). ``bboxs_mask`` is a bool tensor
        (batch_size, padded_num_obj) that is True for real guesser boxes.

    The per-sample question lists are never modified; padding is applied to
    copies so dataset entries stay intact across epochs.
    """
    (game, image_features_rcnn_qgen, bboxs_rcnn_qgen, image_features_rcnn_oracle,
     bboxs_rcnn_oracle, image_features_rcnn_gt_guesser, bboxs_rcnn_gt_guesser,
     tgt_img_feat, tgt_bbox_vb, tgt_cat, cats_guesser, label, qs) = zip(*batch)

    # Ground-truth questions: [batch size, turns (ragged), seq len (ragged)].
    # Pad *copies*: the inner lists are the objects returned from the dataset
    # entries, so extending them in place would leak padding into later
    # epochs and make q_len report padded lengths there.
    max_q_seq_len = max((len(q) for sample_qs in qs for q in sample_qs), default=0)
    max_num_turns = max((len(sample_qs) for sample_qs in qs), default=0)
    padded_qs = []
    q_len = []
    for sample_qs in qs:
        turns = [list(q) + [wrd_pad_id] * (max_q_seq_len - len(q)) for q in sample_qs]
        lens = [len(q) for q in sample_qs]
        n_missing = max_num_turns - len(sample_qs)
        # Pad each sample to max_num_turns with all-pad questions of length 0.
        turns += [[wrd_pad_id] * max_q_seq_len for _ in range(n_missing)]
        lens += [0] * n_missing
        padded_qs.append(turns)
        q_len.append(lens)
    qs = torch.LongTensor(padded_qs)
    q_len = torch.LongTensor(q_len)

    # Per-sample target information, stacked along a new batch dimension.
    tgt_cat = torch.stack(tgt_cat).long()
    tgt_bbox_vb = torch.stack(tgt_bbox_vb).float()
    tgt_img_feat = torch.stack(tgt_img_feat).float()
    # (batch_size, padded_num_obj), zero-padded.
    cats_guesser = pad_sequence(cats_guesser, batch_first=True).long()
    # (batch_size, padded_num_obj): True where a real guesser box exists.
    # Built from the unpadded per-sample boxes, before they are padded below.
    bboxs_mask = [torch.ones(len(xs)) for xs in bboxs_rcnn_gt_guesser]
    bboxs_mask = pad_sequence(bboxs_mask, batch_first=True).bool()
    # (batch_size,)
    label = torch.stack(label).view(-1)

    # QGen / oracle features are stacked directly (same box count per sample
    # is required by torch.stack).
    image_features_rcnn_qgen = torch.stack(image_features_rcnn_qgen).float()
    bboxs_rcnn_qgen = torch.stack(bboxs_rcnn_qgen).float()
    image_features_rcnn_oracle = torch.stack(image_features_rcnn_oracle).float()
    bboxs_rcnn_oracle = torch.stack(bboxs_rcnn_oracle).float()

    # Guesser inputs have a variable number of objects per sample -> pad.
    image_features_rcnn_gt_guesser = pad_sequence(image_features_rcnn_gt_guesser, batch_first=True).float()
    bboxs_rcnn_gt_guesser = pad_sequence(bboxs_rcnn_gt_guesser, batch_first=True).float()

    return (
        game,
        image_features_rcnn_qgen,
        bboxs_rcnn_qgen,
        image_features_rcnn_oracle,
        bboxs_rcnn_oracle,
        image_features_rcnn_gt_guesser,
        bboxs_rcnn_gt_guesser,
        tgt_img_feat,
        tgt_bbox_vb,
        tgt_cat,
        cats_guesser,
        bboxs_mask,
        label,
        qs,
        q_len,
    )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



src/data/self_play_qgen_vdst_oracle_vilbert_guesser_vilbert.py [105:224]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        feats, _, bboxs, _ = self._image_features_reader['oracle'][image_id]
        image_features_rcnn_oracle = torch.from_numpy(np.array(feats))
        bboxs_rcnn_oracle = torch.from_numpy(np.array(bboxs))
        # gt
        feats, bboxs, _ = self._image_features_reader_gt[image_id]
        # image_features_rcnn_gt = torch.from_numpy(np.array(feats))
        image_features_rcnn_gt = torch.from_numpy(np.array(feats))
        bboxs_gt_gw = entry['bboxs_gt_gw']
        bboxs_gt_vb = entry['bboxs_gt_vb']

        tgt_index = entry['target_index']
        cats = entry['categories']
        # bboxs = entry['bboxs']
        
        tgt_cat = cats[tgt_index]
        tgt_bbox_gw = bboxs_gt_gw[tgt_index]
        tgt_bbox_vb = bboxs_gt_vb[tgt_index]
        tgt_img_feat = image_features_rcnn_gt[tgt_index]
        label = torch.LongTensor([tgt_index])

        # Add global information (visual side)
        image_features_rcnn_gt_guesser, bboxs_rcnn_gt_guesser = add_global_vilbert_feats(
            image_features_rcnn_gt.float(), bboxs_gt_vb.float(), input_torch=True)
        cats_guesser = torch.cat([torch.LongTensor([99]) , cats], dim=0)
        qs = entry['qs']
        # q_len = entry['q_len']

        return (
            game,
            image_features_rcnn_qgen,
            bboxs_rcnn_qgen,
            image_features_rcnn_oracle,
            bboxs_rcnn_oracle,
            image_features_rcnn_gt_guesser,
            bboxs_rcnn_gt_guesser,
            # bboxs_gt_gw,
            # bboxs_gt_vb,
            tgt_img_feat,
            # tgt_bbox_gw,
            tgt_bbox_vb,
            tgt_cat,
            cats_guesser,
            label,
            qs,
            # q_len
        )


def collate_fn(batch, wrd_pad_id):
    """Collate self-play samples into batched tensors.

    Args:
        batch: list of 13-tuples as returned by the dataset's ``__getitem__``:
            (game, image_features_rcnn_qgen, bboxs_rcnn_qgen,
            image_features_rcnn_oracle, bboxs_rcnn_oracle,
            image_features_rcnn_gt_guesser, bboxs_rcnn_gt_guesser,
            tgt_img_feat, tgt_bbox_vb, tgt_cat, cats_guesser, label, qs).
            ``qs`` is a list of turns, each a list of token ids; the number
            of turns and the question lengths may differ between samples.
        wrd_pad_id: token id used to pad questions.

    Returns:
        A 15-tuple (game, image_features_rcnn_qgen, bboxs_rcnn_qgen,
        image_features_rcnn_oracle, bboxs_rcnn_oracle,
        image_features_rcnn_gt_guesser, bboxs_rcnn_gt_guesser, tgt_img_feat,
        tgt_bbox_vb, tgt_cat, cats_guesser, bboxs_mask, label, qs, q_len).
        ``game`` is the tuple of per-sample game objects. ``qs`` is a
        LongTensor (batch_size, max_num_turns, max_q_seq_len) and ``q_len``
        a LongTensor (batch_size, max_num_turns) with the unpadded length of
        each question (0 for padding turns). ``bboxs_mask`` is a bool tensor
        (batch_size, padded_num_obj) that is True for real guesser boxes.

    The per-sample question lists are never modified; padding is applied to
    copies so dataset entries stay intact across epochs.
    """
    (game, image_features_rcnn_qgen, bboxs_rcnn_qgen, image_features_rcnn_oracle,
     bboxs_rcnn_oracle, image_features_rcnn_gt_guesser, bboxs_rcnn_gt_guesser,
     tgt_img_feat, tgt_bbox_vb, tgt_cat, cats_guesser, label, qs) = zip(*batch)

    # Ground-truth questions: [batch size, turns (ragged), seq len (ragged)].
    # Pad *copies*: the inner lists are the objects returned from the dataset
    # entries, so extending them in place would leak padding into later
    # epochs and make q_len report padded lengths there.
    max_q_seq_len = max((len(q) for sample_qs in qs for q in sample_qs), default=0)
    max_num_turns = max((len(sample_qs) for sample_qs in qs), default=0)
    padded_qs = []
    q_len = []
    for sample_qs in qs:
        turns = [list(q) + [wrd_pad_id] * (max_q_seq_len - len(q)) for q in sample_qs]
        lens = [len(q) for q in sample_qs]
        n_missing = max_num_turns - len(sample_qs)
        # Pad each sample to max_num_turns with all-pad questions of length 0.
        turns += [[wrd_pad_id] * max_q_seq_len for _ in range(n_missing)]
        lens += [0] * n_missing
        padded_qs.append(turns)
        q_len.append(lens)
    qs = torch.LongTensor(padded_qs)
    q_len = torch.LongTensor(q_len)

    # Per-sample target information, stacked along a new batch dimension.
    tgt_cat = torch.stack(tgt_cat).long()
    tgt_bbox_vb = torch.stack(tgt_bbox_vb).float()
    tgt_img_feat = torch.stack(tgt_img_feat).float()
    # (batch_size, padded_num_obj), zero-padded.
    cats_guesser = pad_sequence(cats_guesser, batch_first=True).long()
    # (batch_size, padded_num_obj): True where a real guesser box exists.
    # Built from the unpadded per-sample boxes, before they are padded below.
    bboxs_mask = [torch.ones(len(xs)) for xs in bboxs_rcnn_gt_guesser]
    bboxs_mask = pad_sequence(bboxs_mask, batch_first=True).bool()
    # (batch_size,)
    label = torch.stack(label).view(-1)

    # QGen / oracle features are stacked directly (same box count per sample
    # is required by torch.stack).
    image_features_rcnn_qgen = torch.stack(image_features_rcnn_qgen).float()
    bboxs_rcnn_qgen = torch.stack(bboxs_rcnn_qgen).float()
    image_features_rcnn_oracle = torch.stack(image_features_rcnn_oracle).float()
    bboxs_rcnn_oracle = torch.stack(bboxs_rcnn_oracle).float()

    # Guesser inputs have a variable number of objects per sample -> pad.
    image_features_rcnn_gt_guesser = pad_sequence(image_features_rcnn_gt_guesser, batch_first=True).float()
    bboxs_rcnn_gt_guesser = pad_sequence(bboxs_rcnn_gt_guesser, batch_first=True).float()

    return (
        game,
        image_features_rcnn_qgen,
        bboxs_rcnn_qgen,
        image_features_rcnn_oracle,
        bboxs_rcnn_oracle,
        image_features_rcnn_gt_guesser,
        bboxs_rcnn_gt_guesser,
        tgt_img_feat,
        tgt_bbox_vb,
        tgt_cat,
        cats_guesser,
        bboxs_mask,
        label,
        qs,
        q_len,
    )
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



