in solver/self_play.py [0:0]
def validate(self, specified_set):
    """Play every batch of ``specified_set`` and report guessing accuracy.

    The model is switched to eval mode, each batch is played under
    ``torch.no_grad()``, and a per-game dump is written to
    ``'<self.exp_name>.txt'``:

    * when ``self.use_gt_question`` is true, one line per game:
      ``<game id>-<predicted index>/<label> | <decoded dialog>``;
    * otherwise a header line followed by one ``|``-separated line per
      question turn (the answer column is left empty on the last turn).

    A sample counts as a hit when the argmax of ``pred`` over its last
    dimension equals ``label``. Running accuracy is sent to
    ``self.progress`` after every batch and the final accuracy to
    ``self.verbose``. The model is put back into train mode at the end.

    Args:
        specified_set: Iterable of raw batches; each item is handed to
            ``self.fetch_data``.
    """
    self.model.eval()
    total_hit = 0
    total_cnt = 0

    # The dump is only written inside the batch loop, so the file is scoped
    # to it: ``with`` guarantees it is flushed and closed even when a batch
    # (or the train-mode logic further down) raises.
    with open('{}.txt'.format(self.exp_name), 'w') as out_file:
        if not self.use_gt_question:
            out_file.write('game_id|pred_obj|answer_obj|turn_id|question|answer\n')

        for data in specified_set:
            (game, img_feat, tgt_cat, tgt_bbox, cats, bboxs, bboxs_mask,
             label, qs, q_len) = self.fetch_data(data)

            with torch.no_grad():
                if self.use_gt_question:
                    pred, dialog = self.model.play_with_gt_question(
                        qs, q_len, tgt_cat, tgt_bbox, cats, bboxs, bboxs_mask,
                        self.tokenizer.pad_id, self.answer2id, self.answer2token)
                else:
                    pred, dialog, q_log, a_log = self.model.play(
                        img_feat, tgt_cat, tgt_bbox, cats, bboxs, bboxs_mask,
                        self.tokenizer.sos_id, self.tokenizer.pad_id,
                        self.tokenizer.eoq_id, self.tokenizer.eod_id,
                        self.answer2id, self.answer2token,
                        max_q_len=20, greedy=True, max_turns=5
                    )

            if self.use_gt_question:
                for b in range(pred.size(0)):
                    out_file.write("{}-{}/{} | {}\n".format(
                        game[b].id, pred[b].argmax(dim=-1).item(), label[b].item(),
                        self.tokenizer.decode(dialog[b].tolist())))
            else:
                for b in range(pred.size(0)):
                    out_prefix = "{}|{}|{}".format(
                        game[b].id, pred[b].argmax(dim=-1).item(), label[b].item())
                    last_turn = len(q_log[b]) - 1
                    for t in range(len(q_log[b])):
                        out_str = out_prefix + "|{}|{}|".format(
                            t, self.tokenizer.decode(q_log[b][t].tolist()))
                        # No answer is written for the final question turn.
                        if t != last_turn:
                            out_str += "{}".format(self.tokenizer.decode(a_log[b][t].tolist()))
                        out_file.write(out_str + '\n')

            total_hit += (pred.argmax(dim=-1) == label).sum().item()
            total_cnt += pred.size(0)
            # Guard against a zero sample count so an empty batch cannot
            # abort the run with ZeroDivisionError.
            running_acc = total_hit / float(total_cnt) if total_cnt else 0.0
            self.progress("Dev stat. | Acc. - {:.3f}".format(running_acc))

    # Log
    if self.mode == 'train':
        # NOTE(review): neither NOT_IMPLEMENT_YET nor avg_loss is defined
        # inside this method; unless both exist at module level this path
        # fails with NameError. Kept unchanged -- confirm the intended
        # train-mode scoring before relying on it.
        NOT_IMPLEMENT_YET()
        score = -avg_loss
        if score > self.best_score:
            self.save_checkpoint('best.pth', score)
            self.best_score = score

    self.model.train()
    # An empty ``specified_set`` is reported as 0.000 instead of crashing.
    final_acc = total_hit / float(total_cnt) if total_cnt else 0.0
    self.verbose(["Val stat. @ step {} | Acc. - {:.3f}"
                  .format(self.step, final_acc)])