def validate()

in solver/oracle_vilbert.py


    def validate(self, specified_set, write_log=False, epoch=0):
        """Evaluate the model on `specified_set`, tracking loss and accuracy.

        When `write_log` is True, also build a pipe-delimited prediction log
        (one row per game) and return it; otherwise return None.
        """
        self.model.eval()
        total_loss = 0
        total_hit = 0
        cnt = 0
        log = None
        if write_log:
            # Header row of the pipe-delimited prediction log.
            log = "game_id|question|answer|pred_answer|pred_confidence\n"
        for val_step, data in enumerate(specified_set):
            game, tgt_cat, tgt_bbox, tgt_img_feat, bg_bboxs, bg_img_feats, q_tokens, q_len, txt_attn_mask, answer = self.fetch_data(data)
            with torch.no_grad():
                # Forward pass over the question tokens plus target and background region features.
                pred = self.model(
                    q_tokens, 
                    tgt_cat, 
                    tgt_bbox, 
                    tgt_img_feat, 
                    bg_bboxs, 
                    bg_img_feats, 
                    update_vilbert=True, 
                    attention_mask=txt_attn_mask,
                )
                loss = self.loss(pred, answer).item()
                hit = (pred.argmax(dim=-1) == answer).sum().item()
                total_hit += hit
                total_loss += loss
                cnt += len(answer)
                # Report running averages on the first batch and every `_progress_step` batches.
                if (val_step == 0) or ((val_step + 1) % self._progress_step == 0):
                    self.progress("({}/{}) Dev stat. | Loss - {:.4f} | Acc. - {:.4f}".format(
                        val_step + 1, len(specified_set), total_loss / float(cnt), total_hit / float(cnt)))

                if write_log:
                    # Append one log row per game in the batch.
                    for b in range(len(game)):
                        # Answer index mapping: {'Yes': 0, 'No': 1, 'N/A': 2}
                        pred_idx = pred[b].argmax(dim=-1).item()
                        pred_ans = ans_idx2str(pred_idx)
                        # Softmax confidence of the predicted class.
                        pred_conf = nn.functional.softmax(pred[b], dim=-1)[pred_idx].item()
                        ans = ans_idx2str(answer[b].item())
                        log += "{}|{}|{}|{}|{:.3f}\n".format(
                            game[b].id,
                            self.tokenizer.decode(q_tokens[b].tolist(), ignore_pad=True),
                            ans,
                            pred_ans,
                            pred_conf,
                        )

        # Accuracy is averaged over samples; loss is averaged over batches.
        score = total_hit / float(cnt)
        loss = total_loss / float(len(specified_set))
        if self.distributed:
            # Distributed run: only the main process saves a per-epoch checkpoint.
            if self.main_proc:
                self.save_checkpoint('epoch-%d.pth' % (epoch), score)
        else:
            # Single-process run: keep the best-scoring checkpoint seen so far.
            if score > self.best_score and self.mode == 'train':
                self.save_checkpoint('best.pth', score)
                self.best_score = score

        self.verbose(["Val stat. @ step {} | Loss - {:.4f} | Acc. - {:.4f}"
                      .format(self.step, loss, score)])

        self.model.train()
        return log
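
A minimal usage sketch follows, assuming a constructed solver object that exposes this method and a `val_loader` yielding batches in the format expected by `fetch_data`; the names and output path below are illustrative, and the `ans_idx2str` stub only mirrors the index mapping noted in the comment inside the loop.

# Usage sketch (illustrative): run validation with logging enabled and persist the log.
# `solver`, `val_loader`, `epoch`, and the file name are assumptions, not repository code.
log = solver.validate(val_loader, write_log=True, epoch=epoch)
if log is not None:
    with open("val_predictions_epoch{}.log".format(epoch), "w") as f:
        f.write(log)

# Hypothetical stand-in for ans_idx2str, consistent with the {'Yes': 0, 'No': 1, 'N/A': 2} comment.
def ans_idx2str(idx):
    return {0: 'Yes', 1: 'No', 2: 'N/A'}[idx]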