in solver/qgen_vdst.py [0:0]
def validate(self, specified_set):
    """Evaluate the model on ``specified_set`` and log the mean loss.

    For every batch the teacher-forced loss is computed via
    ``self.model.forward_dialog`` and ``self.loss``. In ``'train'`` mode the
    first ``NUM_LOG_TEXT_SAMPLES`` batches additionally log decoded text for
    the first item of the batch: the teacher-forced prediction, the target
    (only while ``self.step == 0``) and one greedily generated question.
    When the negated mean loss beats ``self.best_score`` in ``'train'`` mode,
    ``best.pth`` is saved and ``self.best_score`` is updated.

    The model is put in eval mode on entry and back in train mode on exit.

    Args:
        specified_set: Iterable of batches accepted by ``self.fetch_data``.
            The mean is taken over the batches actually iterated, so the
            iterable does not need to support ``len()``.

    Returns:
        None.
    """
    self.model.eval()
    # Accumulate plain Python floats so the average, the score and the
    # checkpointed value are not tensors tied to the model's device.
    total_loss = 0.0
    num_batches = 0

    for val_step, data in enumerate(specified_set):
        # Nothing in validation needs gradients, including free-run generation.
        with torch.no_grad():
            _, qs, tf_input, answers, q_len, obj_feats = self.fetch_data(data)
            pred, _, _ = self.model.forward_dialog(tf_input, q_len, answers, obj_feats)
            loss = self.loss(pred.reshape(-1, pred.size(-1)), qs.reshape(-1))
            total_loss += loss.item()
            num_batches = val_step + 1

            if val_step == 0 or num_batches % self._progress_step == 0:
                self.progress("Dev stat. | Loss - {:.4f}".format(
                    total_loss / num_batches))

            # Text logging only happens while training, for the first few batches.
            if self.mode == 'train' and val_step < NUM_LOG_TEXT_SAMPLES:
                batch_size = pred.size(0)
                num_turns = pred.size(1)
                device = pred.device

                # Teacher-forced prediction of the first item, all turns concatenated.
                dial_pred = "".join(
                    self.tokenizer.decode(pred[0][t].argmax(dim=-1).cpu().numpy())
                    for t in range(num_turns))
                self.write_log('text', '(teacher-forcing)-pred-%d' % val_step, dial_pred)

                # The target is only logged at step 0, so only decode it then.
                if self.step == 0:
                    dial_tgt = "".join(
                        self.tokenizer.decode(qs[0][t].cpu().numpy())
                        for t in range(num_turns))
                    self.write_log('text', '(teacher-forcing)-target-%d' % val_step, dial_tgt)

                # Free-run: greedily generate a single question starting from <sos>
                # (only one sentence is generated here, not a multi-turn dialog).
                last_wrd = torch.zeros(batch_size).fill_(self.tokenizer.sos_id).long().to(device)
                end_of_dialog = torch.zeros(batch_size).bool().to(device)
                q_tokens, actual_length, _, _, _ = self.model.generate_sentence(
                    last_wrd, obj_feats, self.tokenizer.eoq_id, self.tokenizer.eod_id,
                    end_of_dialog, max_q_len=12, pi=None, last_state=None, greedy=True)
                # Take the first one
                pred_sent = self.tokenizer.decode(q_tokens[0].cpu().numpy())
                self.write_log('text', '(free-run)-pred-%d' % val_step, pred_sent)

    if num_batches == 0:
        # No data: there is no meaningful average to log or to checkpoint on.
        self.verbose(["Val stat. @ step {} | No validation batches, skipped."
                      .format(self.step)])
        self.model.train()
        return

    avg_loss = total_loss / num_batches
    self.write_log('scalars', 'loss', {'dev': avg_loss})

    # Higher score is better, hence the negated loss.
    score = -avg_loss
    if score > self.best_score and self.mode == 'train':
        self.save_checkpoint('best.pth', score)
        self.best_score = score

    self.verbose(["Val stat. @ step {} | Loss - {:.4f}"
                  .format(self.step, avg_loss)])
    self.model.train()