models/utils.py (31 lines of code) (raw):

import re
import torch


def check_multiple_choice_with_regex(model_outputs, correct_answers):
    """Check whether each model output contains its expected multiple-choice answer.

    Used to score the model on multiple-choice tasks. This is a cheap regex
    heuristic; a more involved approach (e.g. LLM-as-a-judge) is possible.

    Args:
        model_outputs: Iterable of generated strings, one per example.
        correct_answers: Iterable of expected answers (e.g. ``"A"``), paired
            positionally with ``model_outputs``. Surrounding whitespace is
            stripped and the answer is upper-cased before matching. The model
            output is matched as-is (case-sensitive).

    Returns:
        list[bool]: ``True`` where the answer appears in the output as a
        standalone token, i.e. not directly preceded or followed by a letter,
        digit or underscore. For answer ``"A"``: ``"A."``, ``"(A)"``,
        ``"A) ..."`` and ``"answer: A"`` match, while ``"Answer"``,
        ``"(BAD)"`` and ``"(Answer: B)"`` do not. An empty answer never
        matches. As with ``zip``, pairing stops at the shorter input.
    """
    results = []
    for model_output, correct_answer in zip(model_outputs, correct_answers):
        # strip() rather than rstrip('\n'): stray spaces or '\r' would
        # otherwise become part of the pattern and prevent any match.
        answer = correct_answer.strip().upper()
        if not answer:
            # An empty pattern would match almost any string.
            results.append(False)
            continue
        # re.escape: the answer is literal text, not regex syntax.
        # Lookarounds instead of \b, so the "standalone token" check also
        # works for answers that start or end with punctuation. The single
        # pattern covers a bare letter, a letter followed by punctuation, and
        # "(A)". It deliberately does not match the letter inside another
        # word (e.g. the "A" of "(Answer: B)").
        pattern = rf"(?<!\w){re.escape(answer)}(?!\w)"
        results.append(re.search(pattern, model_output) is not None)
    return results


def top_k_top_p_filtering(logits, top_k=0, top_p=1.0, filter_value=-float('Inf')):
    """Apply top-k and/or nucleus (top-p) filtering to logits.

    Filtering operates along the last dimension, so ``logits`` may have any
    number of leading dimensions (including none).

    Args:
        logits: Tensor of unnormalized scores. The last dimension is the one
            being filtered.
        top_k: If > 0, keep only the ``top_k`` highest-scoring entries.
            Entries tied with the k-th value are also kept. Values larger than
            the last dimension are clamped.
        top_p: If < 1.0, keep the smallest set of highest-probability entries
            whose cumulative probability exceeds ``top_p``. The entry that
            crosses the threshold is kept, and the most probable entry is
            always kept.
        filter_value: Value written into filtered-out positions.

    Returns:
        A new tensor with the same shape as ``logits``, with filtered
        positions set to ``filter_value``. ``logits`` itself is not modified.
    """
    top_k = min(top_k, logits.size(-1))  # Safety: never ask topk for more than exist
    if top_k > 0:
        # Remove all entries scoring below the k-th largest value.
        kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_largest, filter_value)

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

        # Mark entries whose cumulative probability exceeds top_p ...
        sorted_indices_to_remove = cumulative_probs > top_p
        # ... then shift the mask right by one, so the entry that crosses
        # the threshold is kept and the kept mass reaches top_p.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        # Always keep the most probable entry.
        sorted_indices_to_remove[..., 0] = False

        # Map the mask from sorted order back to the original positions along
        # the last dim. dim=-1 (not 1) makes 1-D and >2-D inputs work too.
        indices_to_remove = sorted_indices_to_remove.scatter(
            -1, sorted_indices, sorted_indices_to_remove
        )
        logits = logits.masked_fill(indices_to_remove, filter_value)

    return logits