in scripts/reader/train.py [0:0]
def eval_accuracies(pred_s, target_s, pred_e, target_e):
    """An unofficial evaluation helper.

    Compute exact start/end/complete match accuracies for a batch.

    Args:
        pred_s: Predicted start position for each example; iterated in order
            and its length defines the batch size.
        target_s: Gold start positions. Either a 1D tensor (one gold value per
            example) or a sequence of lists (possibly several gold values per
            example).
        pred_e: Predicted end position for each example, aligned with
            ``pred_s``.
        target_e: Gold end positions, in the same format as ``target_s``.

    Returns:
        A tuple ``(start_acc, end_acc, exact_match_acc)``, each the mean of
        per-example 0/1 hits scaled by 100.
    """
    # Convert 1D tensors to lists of lists (compatibility). Each target is
    # checked on its own so that a tensor is normalised even when the other
    # target has already been given as a list.
    if torch.is_tensor(target_s):
        target_s = [[e.item()] for e in target_s]
    if torch.is_tensor(target_e):
        target_e = [[e.item()] for e in target_e]

    # Compute accuracies from targets
    start = utils.AverageMeter()
    end = utils.AverageMeter()
    em = utils.AverageMeter()
    # The batch size is taken from pred_s; targets are indexed in lockstep so a
    # shorter target list still raises instead of being silently truncated.
    for i, p_s in enumerate(pred_s):
        p_e = pred_e[i]
        golds_s = target_s[i]
        golds_e = target_e[i]

        # Start matches
        start.update(1 if p_s in golds_s else 0)
        # End matches
        end.update(1 if p_e in golds_e else 0)
        # Both start and end match the same gold (start, end) pair; the
        # generator short-circuits on the first matching pair.
        both = any(_s == p_s and _e == p_e for _s, _e in zip(golds_s, golds_e))
        em.update(1 if both else 0)
    return start.avg * 100, end.avg * 100, em.avg * 100