in solver/guesser_vilbert.py [0:0]
def validate(self, specified_set, write_log=False):
    """Evaluate the model on ``specified_set`` and report loss / accuracy.

    Puts the model in eval mode and runs ``self.model.forward_session`` on
    every batch with gradients disabled. It accumulates the loss from
    ``self.loss`` and the number of correct argmax predictions. Running
    statistics go to ``self.progress``; final statistics go to
    ``self.verbose``. In ``'train'`` mode, a new best accuracy triggers
    ``self.save_checkpoint('best.pth', score)`` and updates
    ``self.best_score``. The model is always put back into train mode on
    exit, including when an error is raised.

    Args:
        specified_set: sized iterable of batches accepted by
            ``self.fetch_data``. ``len()`` is used in the progress messages.
        write_log: if True, open ``<exp_name>_stat_his.txt``. In ``'test'``
            mode, one ``game_id|label|prediction|state_history`` line is
            written per sample.
    """
    self.model.eval()
    total_loss = 0.0
    total_hit = 0
    cnt = 0        # number of evaluated samples
    n_batches = 0  # number of evaluated batches
    out_file = None
    if write_log:
        out_file = open('{}_stat_his.txt'.format(self.exp_name), 'w')
    try:
        with torch.no_grad():
            for val_step, data in enumerate(specified_set):
                (game, qs, qs_len, answers, end_turn, cats, img_feats, bboxs,
                 bboxs_mask, bboxs_mask_vb, txt_attn_mask, label) = self.fetch_data(data)
                pred = self.model.forward_session(
                    qs, answers, end_turn, cats, img_feats, bboxs,
                    bboxs_mask=bboxs_mask,
                    attention_mask=txt_attn_mask,
                    image_attention_mask=bboxs_mask_vb,
                    return_state_history=self.mode == 'test',
                )
                if self.mode == 'test':
                    # In test mode forward_session also returns per-sample state history.
                    pred, stat_his = pred
                pred_cls = pred.argmax(dim=-1)
                total_loss += self.loss(pred, label).item()
                total_hit += (pred_cls == label).sum().item()
                cnt += len(label)
                n_batches += 1
                if val_step == 0 or (val_step + 1) % self._progress_step == 0:
                    # Loss is averaged per batch, matching the final "Val stat." line.
                    # NOTE(review): assumes self.loss returns a per-batch mean
                    # (e.g. nn.CrossEntropyLoss default) -- confirm.
                    self.progress("Dev stat. ({}/{}) | Loss - {:.4f} | Acc. - {:.4f}".format(
                        val_step + 1, len(specified_set),
                        total_loss / float(n_batches), total_hit / float(cnt)))
                if self.mode == 'test' and out_file is not None:
                    # Convert label/prediction to plain Python values so the log
                    # contains numbers, not tensor reprs (stat_h is converted likewise).
                    for g, l, p, stat_h in zip(game, label.tolist(), pred_cls.tolist(), stat_his):
                        out_file.write("{}|{}|{}|{}\n".format(g.id, l, p, stat_h.tolist()))
        if n_batches == 0 or cnt == 0:
            # Nothing was evaluated: avoid dividing by zero and never save a checkpoint.
            self.verbose(["Val stat. @ step {} | empty validation set, skipped"
                          .format(self.step)])
            return
        score = total_hit / float(cnt)
        loss = total_loss / float(n_batches)
        if score > self.best_score and self.mode == 'train':
            self.save_checkpoint('best.pth', score)
            self.best_score = score
        self.verbose(["Val stat. @ step {} | Loss - {:.4f} | Acc. - {:.4f}"
                      .format(self.step, loss, score)])
    finally:
        # Restore train mode and release the log file even if evaluation fails.
        self.model.train()
        if out_file is not None:
            out_file.close()