src/data/self_play_qgen_vdst_guesser_vilbert.py [13:102]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    Game, 
    bbox2spatial_gw, 
    bbox2spatial_vilbert,
    add_global_vilbert_feats
)
from src.data.dataset import GuessWhatDataset

from tqdm import tqdm

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class SelfPlayDataset(GuessWhatDataset):
    def __init__(self, dataroot, split, image_features_reader, tokenizer, **kwargs):
        """Self-play variant of GuessWhatDataset: fixes the task name to 'selfplay'.

        All arguments are forwarded unchanged to ``GuessWhatDataset.__init__``.
        """
        super().__init__(dataroot, 'selfplay', split, image_features_reader, tokenizer, **kwargs)

    def _load_dataset(self):
        """Load self-play games from ``self.data_path`` (jsonlines format).

        Only games whose status is 'success' are kept.  Each resulting entry
        bundles the raw ``Game`` object, image metadata, the target object
        index, object categories, both spatial bbox encodings, and the
        tokenized questions of the dialogue.

        Returns:
            list[dict]: one entry per successful game.
        """
        entries = []
        with jsonlines.open(self.data_path) as reader:
            for annotation in tqdm(reader):
                game = Game.from_annotation(annotation)
                # Self-play training only uses games the players won.
                if game.status != 'success':
                    continue

                qs = []
                for qa in game.qas:
                    q = self._tokenizer.encode(qa['question'])
                    # Every encoded question must end with the end-of-question
                    # token; malformed data would corrupt downstream decoding.
                    # NOTE(review): assert is stripped under `python -O`.
                    assert q[-1] == self.eoq_id,\
                        "There is question not ended with question mark in game-%d." % game.id
                    qs.append(q)

                entries.append({
                    'game': game,
                    'image_id': game.image_id,
                    'image_height': game.image_height,
                    'image_width': game.image_width,
                    'target_index': game.target_index,
                    'categories': game.categories,
                    # GuessWhat-style spatial encoding of the ground-truth boxes.
                    'bboxs_gt_gw': [
                        bbox2spatial_gw(box, game.image_width, game.image_height)
                        for box in game.bboxs],
                    # ViLBERT-style spatial encoding of the same boxes.
                    'bboxs_gt_vb': [
                        bbox2spatial_vilbert(box, game.image_width, game.image_height)
                        for box in game.bboxs],
                    'qs': qs,
                })
        return entries

    def tensorize(self):
        """Convert the numeric fields of every entry into torch tensors in place."""
        tensor_keys = ('target_index', 'categories', 'bboxs_gt_gw', 'bboxs_gt_vb')
        for entry in self.entries:
            for key in tensor_keys:
                entry[key] = torch.from_numpy(np.array(entry[key]))

    def __getitem__(self, index):
        entry = self.entries[index]
        #image_id = entry['image_id']
        game = entry['game']
        image_id = entry['image_id']
        # qgen_state_track
        feats, bboxs, _ = self._image_features_reader['qgen'][image_id]
        image_features_rcnn_qgen = torch.from_numpy(np.array(feats))
        bboxs_rcnn_qgen = torch.from_numpy(np.array([
            bbox2spatial_gw(box, game.image_width, game.image_height, mode='xyxy') 
            for box in bboxs]))
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



src/data/self_play_qgen_vdst_oracle_vilbert_guesser_vilbert.py [13:102]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    Game, 
    bbox2spatial_gw, 
    bbox2spatial_vilbert,
    add_global_vilbert_feats
)
from src.data.dataset import GuessWhatDataset

from tqdm import tqdm

logger = logging.getLogger(__name__)  # pylint: disable=invalid-name


class SelfPlayDataset(GuessWhatDataset):
    def __init__(self, dataroot, split, image_features_reader, tokenizer, **kwargs):
        """Self-play variant of GuessWhatDataset: fixes the task name to 'selfplay'.

        All arguments are forwarded unchanged to ``GuessWhatDataset.__init__``.
        """
        super().__init__(dataroot, 'selfplay', split, image_features_reader, tokenizer, **kwargs)

    def _load_dataset(self):
        """Load self-play games from ``self.data_path`` (jsonlines format).

        Only games whose status is 'success' are kept.  Each resulting entry
        bundles the raw ``Game`` object, image metadata, the target object
        index, object categories, both spatial bbox encodings, and the
        tokenized questions of the dialogue.

        Returns:
            list[dict]: one entry per successful game.
        """
        entries = []
        with jsonlines.open(self.data_path) as reader:
            for annotation in tqdm(reader):
                game = Game.from_annotation(annotation)
                # Self-play training only uses games the players won.
                if game.status != 'success':
                    continue

                qs = []
                for qa in game.qas:
                    q = self._tokenizer.encode(qa['question'])
                    # Every encoded question must end with the end-of-question
                    # token; malformed data would corrupt downstream decoding.
                    # NOTE(review): assert is stripped under `python -O`.
                    assert q[-1] == self.eoq_id,\
                        "There is question not ended with question mark in game-%d." % game.id
                    qs.append(q)

                entries.append({
                    'game': game,
                    'image_id': game.image_id,
                    'image_height': game.image_height,
                    'image_width': game.image_width,
                    'target_index': game.target_index,
                    'categories': game.categories,
                    # GuessWhat-style spatial encoding of the ground-truth boxes.
                    'bboxs_gt_gw': [
                        bbox2spatial_gw(box, game.image_width, game.image_height)
                        for box in game.bboxs],
                    # ViLBERT-style spatial encoding of the same boxes.
                    'bboxs_gt_vb': [
                        bbox2spatial_vilbert(box, game.image_width, game.image_height)
                        for box in game.bboxs],
                    'qs': qs,
                })
        return entries

    def tensorize(self):
        """Convert the numeric fields of every entry into torch tensors in place."""
        tensor_keys = ('target_index', 'categories', 'bboxs_gt_gw', 'bboxs_gt_vb')
        for entry in self.entries:
            for key in tensor_keys:
                entry[key] = torch.from_numpy(np.array(entry[key]))

    def __getitem__(self, index):
        entry = self.entries[index]
        #image_id = entry['image_id']
        game = entry['game']
        image_id = entry['image_id']
        # qgen_state_track
        feats, bboxs, _ = self._image_features_reader['qgen'][image_id]
        image_features_rcnn_qgen = torch.from_numpy(np.array(feats))
        bboxs_rcnn_qgen = torch.from_numpy(np.array([
            bbox2spatial_gw(box, game.image_width, game.image_height, mode='xyxy') 
            for box in bboxs]))
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



