solver/self_play_all_vilbert.py [100:211]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            image_features_rcnn_oracle.to(self.device), 
            bboxs_rcnn_oracle.to(self.device), 
            image_features_rcnn_gt_guesser.to(self.device),
            bboxs_rcnn_gt_guesser.to(self.device),
            tgt_img_feat.to(self.device),
            tgt_bbox_vb.to(self.device),
            tgt_cat.to(self.device), 
            cats_guesser.to(self.device), 
            bboxs_mask.to(self.device), 
            label.to(self.device), 
            qs.to(self.device), 
            q_len.to(self.device),
        )

    def set_model(self):
        """Build the self-play model, its optimizer and loss, then optionally resume.

        Ground-truth questions are used whenever the 'model' section of the
        config has no 'qgen' entry. The tokenizer's vocabulary size and pad id
        are written into the config of every participating player before the
        model is constructed, and each player is then loaded from its
        'pretrained_path'. If ``self.args.load`` is set, model/optimizer state
        and the global step are restored from that checkpoint.
        """
        self.verbose(['Set model...'])
        model_cfg = self.config['model']
        self.use_gt_question = 'qgen' not in model_cfg
        if self.use_gt_question:
            players = ['oracle', 'guesser']
        else:
            players = ['qgen', 'oracle', 'guesser']

        # Every player shares the tokenizer's vocabulary size and pad id.
        for name in players:
            player_cfg = model_cfg[name]
            player_cfg['num_wrds'] = len(self.tokenizer)
            player_cfg['wrd_pad_id'] = self.tokenizer.pad_id

        qgen_cfg = None if self.use_gt_question else model_cfg['qgen']
        self.model = SelfPlayModel(
            qgen_kwargs=qgen_cfg,
            oracle_kwargs=model_cfg['oracle'],
            guesser_kwargs=model_cfg['guesser'],
        )

        # Initialise each player from its pretrained weights (mapped to CPU
        # first; the whole model is moved to the target device afterwards).
        for name in players:
            pretrained_path = model_cfg[name]['pretrained_path']
            load_msg = self.model.load_player(
                name, pretrained_path, map_location="cpu")
            self.verbose([load_msg])
        self.model.to(self.device)

        hparas = self.config['hparas']
        self.optimizer = Optimizer(self.model.parameters(), **hparas)
        # Targets equal to the pad id are ignored by the loss.
        self.loss = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_id)

        # Resuming from a full checkpoint is not so useful for self-play.
        if not self.args.load:
            return
        state = torch.load(self.args.load, map_location=self.device)
        self.model.load_state_dict(state['model'])
        self.optimizer.load_opt_state_dict(state['optimizer'])
        self.step = state['global_step']
        resume_msg = 'Load ckpt from {}, restarting at step {}'.format(
            self.args.load, self.step)
        self.verbose(resume_msg)

    def exec(self):
        """Run the solver: train when ``self.mode`` is 'train', else evaluate.

        Any mode other than 'train' evaluates on ``self.test_set``. A notice is
        printed first when ground-truth questions are in use.
        """
        if self.use_gt_question:
            self.verbose("Use ground truth questions.")

        if self.mode != 'train':
            # Evaluation currently targets the test split only.
            self.verbose(["Evaluate on test set..."])
            self.validate(self.test_set)
            return

        self.train()
        

    def train(self):
        """Run the training loop until ``self.max_step`` steps have been taken.

        Each pass over ``self.train_set`` is preceded by a validation run on
        ``self.valid_set``. Progress and scalars are logged on the first step
        and then every ``self._progress_step`` steps.

        NOTE: training is a placeholder in this solver. The loop calls
        ``NOT_IMPLEMENT_YET()`` right after fetching a batch, and the forward
        pass refers to names (``qgen_in``, ``qgen_in_len``, ``img_feat``,
        ``qgen_tgt``) that this method never assigns.
        """
        self.verbose(['Total training epoch/steps: {}/{}'.format(
            self.max_epoch, human_format(self.max_step))])
        self.verbose(['Number of steps per epoch: {}'.format(
            human_format(self.steps_per_epoch))])
        while self.step < self.max_step:
            # Validate every epoch
            self.validate(self.valid_set)
            self.timer.set()
            for data in self.train_set:
                # NOTE(review): the visible tail of fetch_data returns at least
                # 12 values while 11 are unpacked here — confirm before
                # implementing self-play training.
                game, obj_feats, tgt_cat, tgt_bbox, tgt_img_feat, cats, bboxs, bboxs_mask, label, qs, q_len = self.fetch_data(data)
                # Placeholder marker: everything below is not wired up yet.
                NOT_IMPLEMENT_YET()
                self.timer.cnt('rd')
                # Forward
                self.optimizer.pre_step(self.step)
                pred = self.model(qgen_in, qgen_in_len, img_feat, mask=None)
                loss = self.loss(pred.view(-1, pred.size(-1)), qgen_tgt.view(-1))
                hit = cal_hit(pred, qgen_tgt)
                acc = hit / float(qgen_tgt.size(0) * qgen_tgt.size(1))
                hit_nopad = cal_hit(pred, qgen_tgt, pad_id=self.tokenizer.pad_id)
                acc_nopad = hit_nopad / float(qgen_tgt.size(0) * qgen_tgt.size(1))
                self.timer.cnt('fw')
                # Backward
                grad_norm = self.backward(loss)
                self.timer.cnt('bw')

                self.step += 1
                # Log
                if (self.step == 1) or (self.step % self._progress_step == 0):
                    self.progress("Tr stat. | Loss - {:.4f} | Acc.(pad/nopad) - {:.3f}/{:.3f} | Grad. norm - {:.2f} | {}".format(
                        loss.item(), acc, acc_nopad, grad_norm, self.timer.show()))
                    self.write_log('scalars', 'accuracy', {'train': acc})
                    self.write_log('scalars', 'accuracy', {'train-nopad': acc_nopad})
                    self.write_log('scalars', 'loss', {'train': loss})

                # End of step
                self.timer.set()
                # Stop as soon as max_step is reached. The step counter was
                # already incremented above, so ``>=`` (not ``>``) is needed to
                # avoid running one extra step and to make sure the logger is
                # closed even when an epoch ends exactly on max_step.
                if self.step >= self.max_step:
                    self.verbose("Reach max training step.")
                    self.logger.close()
                    break


    def validate(self, specified_set):
        self.model.eval()
        total_hit = 0
        total_cnt = 0
        out_name = self.exp_name + ('_gt' if self.use_gt_question else '') + '.txt'
        out_file = open(out_name, 'w')
        out_file.write('game_id|pred_obj|answer_obj|turn_id|question|answer|answer_confidence\n')

        for val_step, data in enumerate(specified_set):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



solver/self_play_qgen_vdst_oracle_vilbert_guesser_vilbert.py [100:211]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            image_features_rcnn_oracle.to(self.device), 
            bboxs_rcnn_oracle.to(self.device), 
            image_features_rcnn_gt_guesser.to(self.device),
            bboxs_rcnn_gt_guesser.to(self.device),
            tgt_img_feat.to(self.device),
            tgt_bbox_vb.to(self.device),
            tgt_cat.to(self.device), 
            cats_guesser.to(self.device), 
            bboxs_mask.to(self.device), 
            label.to(self.device), 
            qs.to(self.device), 
            q_len.to(self.device),
        )

    def set_model(self):
        """Build the self-play model, its optimizer and loss, then optionally resume.

        Ground-truth questions are used whenever the 'model' section of the
        config has no 'qgen' entry. The tokenizer's vocabulary size and pad id
        are written into the config of every participating player before the
        model is constructed, and each player is then loaded from its
        'pretrained_path'. If ``self.args.load`` is set, model/optimizer state
        and the global step are restored from that checkpoint.
        """
        self.verbose(['Set model...'])
        model_cfg = self.config['model']
        self.use_gt_question = 'qgen' not in model_cfg
        if self.use_gt_question:
            players = ['oracle', 'guesser']
        else:
            players = ['qgen', 'oracle', 'guesser']

        # Every player shares the tokenizer's vocabulary size and pad id.
        for name in players:
            player_cfg = model_cfg[name]
            player_cfg['num_wrds'] = len(self.tokenizer)
            player_cfg['wrd_pad_id'] = self.tokenizer.pad_id

        qgen_cfg = None if self.use_gt_question else model_cfg['qgen']
        self.model = SelfPlayModel(
            qgen_kwargs=qgen_cfg,
            oracle_kwargs=model_cfg['oracle'],
            guesser_kwargs=model_cfg['guesser'],
        )

        # Initialise each player from its pretrained weights (mapped to CPU
        # first; the whole model is moved to the target device afterwards).
        for name in players:
            pretrained_path = model_cfg[name]['pretrained_path']
            load_msg = self.model.load_player(
                name, pretrained_path, map_location="cpu")
            self.verbose([load_msg])
        self.model.to(self.device)

        hparas = self.config['hparas']
        self.optimizer = Optimizer(self.model.parameters(), **hparas)
        # Targets equal to the pad id are ignored by the loss.
        self.loss = nn.CrossEntropyLoss(ignore_index=self.tokenizer.pad_id)

        # Optionally resume model, optimizer and step counter from a checkpoint.
        if not self.args.load:
            return
        state = torch.load(self.args.load, map_location=self.device)
        self.model.load_state_dict(state['model'])
        self.optimizer.load_opt_state_dict(state['optimizer'])
        self.step = state['global_step']
        resume_msg = 'Load ckpt from {}, restarting at step {}'.format(
            self.args.load, self.step)
        self.verbose(resume_msg)

    def exec(self):
        """Run the solver: train when ``self.mode`` is 'train', else evaluate.

        Any mode other than 'train' evaluates on ``self.test_set``. A notice is
        printed first when ground-truth questions are in use.
        """
        if self.use_gt_question:
            self.verbose("Use ground truth questions.")

        if self.mode != 'train':
            # Evaluation currently targets the test split only.
            self.verbose(["Evaluate on test set..."])
            self.validate(self.test_set)
            return

        self.train()
        

    def train(self):
        """Run the training loop until ``self.max_step`` steps have been taken.

        Each pass over ``self.train_set`` is preceded by a validation run on
        ``self.valid_set``. Progress and scalars are logged on the first step
        and then every ``self._progress_step`` steps.

        NOTE: training is a placeholder in this solver. The loop calls
        ``NOT_IMPLEMENT_YET()`` right after fetching a batch, and the forward
        pass refers to names (``qgen_in``, ``qgen_in_len``, ``img_feat``,
        ``qgen_tgt``) that this method never assigns.
        """
        self.verbose(['Total training epoch/steps: {}/{}'.format(
            self.max_epoch, human_format(self.max_step))])
        self.verbose(['Number of steps per epoch: {}'.format(
            human_format(self.steps_per_epoch))])
        while self.step < self.max_step:
            # Validate every epoch
            self.validate(self.valid_set)
            self.timer.set()
            for data in self.train_set:
                # NOTE(review): the visible tail of fetch_data returns at least
                # 12 values while 11 are unpacked here — confirm before
                # implementing self-play training.
                game, obj_feats, tgt_cat, tgt_bbox, tgt_img_feat, cats, bboxs, bboxs_mask, label, qs, q_len = self.fetch_data(data)
                # Placeholder marker: everything below is not wired up yet.
                NOT_IMPLEMENT_YET()
                self.timer.cnt('rd')
                # Forward
                self.optimizer.pre_step(self.step)
                pred = self.model(qgen_in, qgen_in_len, img_feat, mask=None)
                loss = self.loss(pred.view(-1, pred.size(-1)), qgen_tgt.view(-1))
                hit = cal_hit(pred, qgen_tgt)
                acc = hit / float(qgen_tgt.size(0) * qgen_tgt.size(1))
                hit_nopad = cal_hit(pred, qgen_tgt, pad_id=self.tokenizer.pad_id)
                acc_nopad = hit_nopad / float(qgen_tgt.size(0) * qgen_tgt.size(1))
                self.timer.cnt('fw')
                # Backward
                grad_norm = self.backward(loss)
                self.timer.cnt('bw')

                self.step += 1
                # Log
                if (self.step == 1) or (self.step % self._progress_step == 0):
                    self.progress("Tr stat. | Loss - {:.4f} | Acc.(pad/nopad) - {:.3f}/{:.3f} | Grad. norm - {:.2f} | {}".format(
                        loss.item(), acc, acc_nopad, grad_norm, self.timer.show()))
                    self.write_log('scalars', 'accuracy', {'train': acc})
                    self.write_log('scalars', 'accuracy', {'train-nopad': acc_nopad})
                    self.write_log('scalars', 'loss', {'train': loss})

                # End of step
                self.timer.set()
                # Stop as soon as max_step is reached. The step counter was
                # already incremented above, so ``>=`` (not ``>``) is needed to
                # avoid running one extra step and to make sure the logger is
                # closed even when an epoch ends exactly on max_step.
                if self.step >= self.max_step:
                    self.verbose("Reach max training step.")
                    self.logger.close()
                    break


    def validate(self, specified_set):
        self.model.eval()
        total_hit = 0
        total_cnt = 0
        out_name = self.exp_name + ('_gt' if self.use_gt_question else '') + '.txt'
        out_file = open(out_name, 'w')
        out_file.write('game_id|pred_obj|answer_obj|turn_id|question|answer|answer_confidence\n')

        for val_step, data in enumerate(specified_set):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



