# weak_to_strong/eval.py
import datasets
import numpy as np
import torch
from torch import nn


def to_batch(x, batch_size):
    # Yield successive batch_size-sized slices of x. When x is a datasets.Dataset,
    # slicing returns a dict of columns, which is how eval_model_acc consumes batches.
    for i in range(0, len(x), batch_size):
        yield x[i : i + batch_size]


def unpack(x):
    # Convert a tensor to a nested Python list, moving it to CPU float32 first.
    assert isinstance(x, torch.Tensor), type(x)
    return x.detach().float().cpu().numpy().tolist()
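
# A minimal sketch of what the helpers above produce (illustrative only, not part
# of the module's API):
#
#   >>> list(to_batch([1, 2, 3, 4, 5], batch_size=2))
#   [[1, 2], [3, 4], [5]]
#   >>> unpack(torch.tensor([[0.25, 0.75]]))
#   [[0.25, 0.75]]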


def eval_model_acc(
    model: nn.Module, ds: datasets.Dataset, eval_batch_size: int = 16
) -> datasets.Dataset:
    """
    Evaluate the accuracy of a model on a dataset.

    Parameters:
        model (nn.Module): The model to be evaluated.
        ds (datasets.Dataset): The dataset on which the model is evaluated. Each
            example must provide "input_ids", "soft_label", and "txt" fields.
        eval_batch_size (int): Number of examples per forward pass.

    Returns:
        datasets.Dataset: One record per example, containing the input_ids,
        ground-truth label, predicted (hard) label, accuracy of the prediction,
        logits, and soft label.
    """
    model.eval()

    with torch.no_grad():
        results = []
        for batch in to_batch(ds, eval_batch_size):
            # pad input_ids to a common length within the batch
            input_ids = torch.nn.utils.rnn.pad_sequence(
                [torch.tensor(ex) for ex in batch["input_ids"]], batch_first=True
            ).to(model.device if hasattr(model, "device") else "cpu")
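            # NOTE (assumption): pad_sequence fills with 0 by default, so this presumes
            # the model treats id 0 as padding or is insensitive to trailing pad
            # positions when producing per-example logits.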
            labels = batch["soft_label"]
            # run forward pass
            raw_logits = model(input_ids)

            probs = unpack(torch.nn.functional.softmax(raw_logits, dim=-1))
            logits = unpack(raw_logits)

            # reduce to hard predictions/labels by taking the argmax over classes
            preds = np.argmax(probs, axis=-1)
            labels = np.argmax(labels, axis=-1)
            results.extend(
                [
                    dict(
                        txt=txt,
                        input_ids=input_id,
                        gt_label=label,
                        hard_label=pred,
                        acc=label == pred,
                        logits=logit,
                        soft_label=prob,
                    )
                    for input_id, txt, label, pred, prob, logit in zip(
                        batch["input_ids"], batch["txt"], labels, preds, probs, logits
                    )
                ]
            )
        accs = [r["acc"] for r in results]
        # report mean accuracy with its standard error
        print("Accuracy:", np.mean(accs), "+/-", np.std(accs) / np.sqrt(len(accs)))

        return datasets.Dataset.from_list(results)
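

# A minimal usage sketch (illustrative, not part of this module). `my_classifier`
# is a hypothetical stand-in: eval_model_acc only assumes model(input_ids) returns
# per-example logits, and that the dataset has "input_ids", "soft_label", and
# "txt" columns.
#
#   import datasets
#
#   ds = datasets.Dataset.from_list(
#       [
#           {"txt": "good movie", "input_ids": [101, 2204, 102], "soft_label": [0.1, 0.9]},
#           {"txt": "bad movie", "input_ids": [101, 2919, 102], "soft_label": [0.8, 0.2]},
#       ]
#   )
#   results = eval_model_acc(my_classifier, ds, eval_batch_size=2)
#   print(results[0]["hard_label"], results[0]["acc"])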