in solver/oracle_vilbert.py [0:0]
def validate(self, specified_set, write_log=False, epoch=0):
    """Evaluate the model on ``specified_set`` and optionally build a prediction log.

    Runs one pass over ``specified_set`` with the model in eval mode and
    autograd disabled, accumulating loss and accuracy. Running statistics are
    reported through ``self.progress`` on the first step and every
    ``self._progress_step`` steps. Final statistics are reported through
    ``self.verbose``.

    Checkpointing:
      * distributed runs: the main process saves ``epoch-<epoch>.pth``;
      * non-distributed runs in ``'train'`` mode: ``best.pth`` is saved and
        ``self.best_score`` updated when accuracy improves (skipped if no
        samples were evaluated).

    The model is switched back to train mode before returning, even if an
    exception is raised part-way through.

    Args:
        specified_set: Sized iterable of batches; each batch is unpacked by
            ``self.fetch_data``.
        write_log: If True, build a ``|``-separated log with one line per sample.
        epoch: Epoch number used in the distributed checkpoint filename.

    Returns:
        The log text (header plus one line per sample) when ``write_log`` is
        True, otherwise None.
    """
    def _mean(total, count):
        # Guarded division so an empty set / empty batch does not raise.
        return total / float(count) if count else 0.0

    log_lines = (["game_id|question|answer|pred_answer|pred_confidence\n"]
                 if write_log else None)
    total_loss = 0.0
    total_hit = 0
    cnt = 0
    n_batches = len(specified_set)

    self.model.eval()
    try:
        with torch.no_grad():
            for val_step, data in enumerate(specified_set):
                (game, tgt_cat, tgt_bbox, tgt_img_feat, bg_bboxs, bg_img_feats,
                 q_tokens, q_len, txt_attn_mask, answer) = self.fetch_data(data)
                # NOTE(review): update_vilbert=True is passed as during training;
                # weights cannot change under no_grad, but confirm the flag has
                # no other effect at eval time.
                pred = self.model(
                    q_tokens,
                    tgt_cat,
                    tgt_bbox,
                    tgt_img_feat,
                    bg_bboxs,
                    bg_img_feats,
                    update_vilbert=True,
                    attention_mask=txt_attn_mask,
                )
                # pred is indexed as pred[b] -> per-class scores, i.e. (batch, classes).
                pred_idx = pred.argmax(dim=-1)
                total_loss += self.loss(pred, answer).item()
                total_hit += (pred_idx == answer).sum().item()
                cnt += len(answer)

                if val_step == 0 or (val_step + 1) % self._progress_step == 0:
                    # Loss is averaged over batches seen so far, matching the
                    # final figure reported below; accuracy is per sample.
                    self.progress("({}/{}) Dev stat. | Loss - {:.4f} | Acc. - {:.4f}".format(
                        val_step + 1, n_batches,
                        _mean(total_loss, val_step + 1), _mean(total_hit, cnt)))

                if write_log:
                    # Softmax once per batch; confidence is the probability of
                    # the predicted class.
                    probs = nn.functional.softmax(pred, dim=-1)
                    for b, game_item in enumerate(game):
                        # Label mapping per original comment: {'Yes': 0, 'No': 1, 'N/A': 2}
                        idx = pred_idx[b].item()
                        pred_ans = ans_idx2str(idx)
                        ans = ans_idx2str(answer[b].item())
                        log_lines.append("{}|{}|{}|{}|{:.3f}\n".format(
                            game_item.id,
                            self.tokenizer.decode(q_tokens[b].tolist(), ignore_pad=True),
                            ans,
                            pred_ans,
                            probs[b, idx].item(),
                        ))

        score = _mean(total_hit, cnt)
        # NOTE(review): this averages per-batch loss values; it equals a
        # per-sample mean only if self.loss uses mean reduction — confirm.
        loss = _mean(total_loss, n_batches)
        if self.distributed:
            if self.main_proc:
                self.save_checkpoint('epoch-%d.pth' % (epoch), score)
        elif cnt and score > self.best_score and self.mode == 'train':
            self.save_checkpoint('best.pth', score)
            self.best_score = score
        self.verbose(["Val stat. @ step {} | Loss - {:.4f} | Acc. - {:.4f}"
                      .format(self.step, loss, score)])
        return "".join(log_lines) if write_log else None
    finally:
        # Always restore train mode, even if evaluation raised.
        self.model.train()