in codes/rnn_training/Non_transformers_probe.py [0:0]
def evaluate(nli_net=None, valid_iter=None, inv_label=None, itos_vocab=None, samples_file=None):
    """Run ``nli_net`` over ``valid_iter`` and append one record per example to a file.

    For every example the record contains the context id, both decoded
    sentences, the gold label, the predicted label, whether they match and
    the raw output row of the model.

    Args:
        nli_net: Model called as ``nli_net((s1, s1_len), (s2, s2_len))`` and
            returning ``(output, extra)``; ``output`` is indexed per example
            and arg-maxed over dimension 1.
        valid_iter: Iterable of batches exposing ``Sentence1`` and
            ``Sentence2`` (each a ``(tokens, lengths)`` pair), ``Label`` and
            ``ContextID``.
        inv_label: Lookup from a label index (plain ``int``) to its name.
        itos_vocab: Lookup from a token index (plain ``int``) to its string.
        samples_file: Path of the file the records are appended to; a
            sibling ``<samples_file>.lock`` guards concurrent writers.

    Returns:
        ``(pred, logits)`` accumulated over *all* batches: ``pred`` is a list
        of predicted class indices and ``logits`` a list of detached CPU
        tensors, one output row per example.
    """
    nli_net.eval()
    pred = []
    logits = []
    # One lock object is enough; it is acquired once per batch below.
    lock = FileLock(samples_file + '.lock')
    record_template = (
        '{{ uid:{}, premise: {}, hypothesis: {}, orig_label: {}, '
        'model: {}, is_correct:{}, logits:{}}}\n'
    )

    for batch in valid_iter:
        # prepare batch
        s1_tokens, s1_len = batch.Sentence1
        s2_tokens, s2_len = batch.Sentence2
        s1_batch = Variable(s1_tokens.to(device))
        s2_batch = Variable(s2_tokens.to(device))

        # model forward; the second return value is not needed here
        output, _ = nli_net((s1_batch, s1_len), (s2_batch, s2_len))
        # Detach so the accumulated results do not keep autograd graphs alive.
        output = output.detach()
        batch_pred = output.max(1)[1].tolist()
        batch_logits = output.cpu()

        # Accumulate instead of overwriting, so the caller gets every batch.
        pred.extend(batch_pred)
        logits.extend(batch_logits)  # iterating a tensor yields its rows

        records = []
        for b_index in range(len(batch)):
            uid = batch.ContextID[b_index]
            test_prediction = inv_label[batch_pred[b_index]]
            # .item() turns the 0-dim tensor into an int so that inv_label
            # may be a dict as well as a list.
            target = inv_label[batch.Label[b_index].item()]
            # Tokens are indexed [position, example]; cut at the true length.
            s1_ids = s1_tokens[:s1_len[b_index], b_index].tolist()
            s2_ids = s2_tokens[:s2_len[b_index], b_index].tolist()
            s1 = ' '.join(itos_vocab[idx] for idx in s1_ids).replace('Ġ', ' ')
            s2 = ' '.join(itos_vocab[idx] for idx in s2_ids).replace('Ġ', ' ')
            logit = batch_logits[b_index].tolist()
            is_correct = target == test_prediction
            # NOTE(review): Sentence2 is written as the premise and Sentence1
            # as the hypothesis, as in the original code - confirm this
            # matches the dataset's field layout.
            records.append(record_template.format(
                uid, s2, s1, target, test_prediction, is_correct, logit))

        # The context manager releases the lock itself; no explicit release().
        with lock:
            with open(samples_file, 'a') as f:
                f.writelines(records)

    return pred, logits