in blink/crossencoder/train_cross.py [0:0]
def evaluate(reranker, eval_dataloader, device, logger, context_length, zeshel=False, silent=True):
    """Run the cross-encoder reranker over an evaluation set and report accuracy.

    ``reranker.model`` is switched to eval mode and is NOT switched back; callers
    that continue training must call ``.train()`` themselves.

    Args:
        reranker: model wrapper called as
            ``reranker(context_input, label_input, context_length)`` and expected to
            return a ``(loss, logits)`` pair; the loss is ignored here.
        eval_dataloader: iterable of batches of tensors. ``batch[0]`` is the context
            input and ``batch[1]`` the label input. When ``zeshel`` is True,
            ``batch[2]`` must hold a per-example world index in
            ``range(len(WORLDS))``.
        device: device every tensor of a batch is moved to before the forward pass.
        logger: logger for the summary line(s). May be None, in which case nothing
            is logged.
        context_length: forwarded unchanged to ``reranker``.
        zeshel: when True, also accumulate per-world accuracies and log the macro
            (mean over worlds) and micro (over all examples) accuracies.
        silent: when False, wrap the dataloader in a ``tqdm`` progress bar.

    Returns:
        dict with
          - ``"normalized_accuracy"``: sum of the per-batch accuracy values from
            ``utils.accuracy`` divided by the number of evaluated examples
            (presumably a count of correct predictions, giving a fraction —
            verify against ``utils.accuracy``), or ``-1`` if there were no
            examples;
          - ``"logits"``: list of per-example logit rows (numpy arrays), in
            dataloader order.
    """
    reranker.model.eval()
    iter_ = eval_dataloader if silent else tqdm(eval_dataloader, desc="Evaluation")

    results = {}
    eval_accuracy = 0.0
    nb_eval_examples = 0

    # Per-world running totals, keyed by world index 0 .. len(WORLDS) - 1.
    world_size = len(WORLDS)
    acc = {i: 0.0 for i in range(world_size)}
    tot = {i: 0.0 for i in range(world_size)}

    all_logits = []
    for batch in iter_:
        if zeshel:
            # Keep the world ids from the original (pre-.to(device)) tensor; only
            # their Python values are read below.
            src = batch[2]
        batch = tuple(t.to(device) for t in batch)
        context_input = batch[0]
        label_input = batch[1]
        with torch.no_grad():
            _, logits = reranker(context_input, label_input, context_length)

        logits = logits.detach().cpu().numpy()
        label_ids = label_input.cpu().numpy()
        # eval_result is indexed per example below; presumably 1/0 correctness
        # flags — verify against utils.accuracy.
        tmp_eval_accuracy, eval_result = utils.accuracy(logits, label_ids)

        eval_accuracy += tmp_eval_accuracy
        all_logits.extend(logits)
        nb_eval_examples += context_input.size(0)

        if zeshel:
            for i in range(context_input.size(0)):
                src_w = src[i].item()
                acc[src_w] += eval_result[i]
                tot[src_w] += 1

    normalized_eval_accuracy = -1
    if nb_eval_examples > 0:
        normalized_eval_accuracy = eval_accuracy / nb_eval_examples

    if zeshel:
        macro = 0.0
        num = 0
        for i in range(world_size):
            # Average over every world that HAS examples, including worlds where
            # nothing was predicted correctly (accuracy 0). Filtering on
            # acc[i] > 0 silently dropped those worlds and inflated the macro
            # average.
            if tot[i] > 0:
                acc[i] /= tot[i]
                macro += acc[i]
                num += 1
        if num > 0 and logger:
            logger.info("Macro accuracy: %.5f", macro / num)
            logger.info("Micro accuracy: %.5f", normalized_eval_accuracy)
    elif logger:
        logger.info("Eval accuracy: %.5f", normalized_eval_accuracy)

    results["normalized_accuracy"] = normalized_eval_accuracy
    results["logits"] = all_logits
    return results