bench/generation/metrics/prediction.py:

# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import time

import torch
from datasets import load_dataset


@torch.no_grad()
def prediction_accuracy(model, tokenizer, batch_size, samples=None):
    test_dataset = load_dataset("lambada", split=["test"])[0]
    model.eval()
    # The task is to predict the last token of each input sequence.
    # NOTE: the last-token comparison below is only meaningful if the
    # tokenizer pads on the left, so that position -1 of every padded row
    # holds the true final token.
    total, hit = 0, 0
    start = time.time()
    for batch in test_dataset.iter(batch_size=batch_size):
        inputs = tokenizer(batch["text"], return_tensors="pt", padding=True)
        input_ids = inputs.input_ids.to(model.device)
        attention_mask = inputs.attention_mask.to(model.device)
        # The last token of each sequence is the prediction target.
        labels = input_ids[:, -1]
        # Pass only the preceding tokens and read the logits at the final
        # position to get the model's guess for the last token.
        outputs = model(input_ids[:, :-1], attention_mask=attention_mask[:, :-1])
        preds = outputs.logits[:, -1, :].argmax(dim=-1)
        total += labels.size(0)
        hit += (preds == labels).sum().item()
        if samples is not None and total >= samples:
            break
    end = time.time()
    acc = hit / total
    print(f"{total} sequences evaluated in {end - start:.2f} s. accuracy = {acc:.2f}")
    return acc
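

# --- Usage sketch (added for illustration, not part of the upstream file) ---
# A minimal way this metric could be invoked, assuming a causal LM from
# transformers; the "gpt2" checkpoint, batch size, and sample count below are
# arbitrary placeholders, not values the benchmark prescribes.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model = AutoModelForCausalLM.from_pretrained("gpt2")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    # GPT-2 ships without a pad token: reuse EOS, and pad on the left so the
    # last position of each padded sequence holds the true final token.
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.padding_side = "left"
    prediction_accuracy(model, tokenizer, batch_size=8, samples=64)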