in src/nli/training_extra.py [0:0]
def eval_model(model, dev_dataloader, device_num, args):
    """Run ``model`` over ``dev_dataloader`` and collect per-example predictions.

    The model is switched to eval mode and run under ``torch.no_grad()``.
    Each batch is moved with ``move_to_device(batch, device_num)`` and must
    provide the keys ``'input_ids'``, ``'attention_mask'``, ``'y'`` and
    ``'uid'``. It must also provide ``'token_type_ids'`` unless
    ``args.model_class_name`` is one of the model families listed in
    ``no_token_type_models`` below, which are called without it.

    Args:
        model: Callable invoked as
            ``model(input_ids, attention_mask=..., [token_type_ids=...,] labels=...)``.
            Its return value must be indexable, with the loss at position 0
            and the logits at position 1.
        dev_dataloader: Iterable of batch dicts.
        device_num: Target device, forwarded unchanged to ``move_to_device``.
        args: Object exposing ``model_class_name``.

    Returns:
        A list with one dict per example, in dataloader order, with keys:
        ``'uid'`` (taken from ``batch['uid']``), ``'logits'`` (that example's
        logit row as a Python list) and ``'predicted_label'`` (``id2label``
        applied to the argmax index of the row).

    Note:
        ``model.eval()`` is called and never undone. Callers that continue
        training afterwards must call ``model.train()`` themselves.
    """
    # Model families whose forward pass is called without token_type_ids.
    no_token_type_models = {"distilbert", "bart-large", "lstm-resencoder", "bag-of-words"}

    model.eval()
    uid_list = []
    pred_list = []
    logits_list = []
    with torch.no_grad():
        for batch in dev_dataloader:
            batch = move_to_device(batch, device_num)
            model_kwargs = {
                'attention_mask': batch['attention_mask'],
                'labels': batch['y'],
            }
            if args.model_class_name not in no_token_type_models:
                model_kwargs['token_type_ids'] = batch['token_type_ids']
            outputs = model(batch['input_ids'], **model_kwargs)
            # outputs[:2] is (loss, logits); only the logits are used here.
            _, logits = outputs[:2]

            uid_list.extend(batch['uid'])
            # Argmax over dim 1. The view() keeps exactly one prediction per
            # example; it assumes logits is (batch, num_classes) — TODO confirm
            # for every model class.
            pred_list.extend(torch.argmax(logits, dim=1).view(logits.size(0)).tolist())
            logits_list.extend(logits.tolist())

    # Every example needs a uid, a prediction and a logit row. Otherwise the
    # zip below would silently drop or misalign results.
    assert len(uid_list) == len(pred_list) == len(logits_list)

    return [
        {
            'uid': uid,
            'logits': logit_row,
            'predicted_label': id2label[pred],
        }
        for uid, logit_row, pred in zip(uid_list, logits_list, pred_list)
    ]