lm_eval/tasks/mmlu/flan_cot_zeroshot/utils.py (79 lines of code) (raw):

import re import sys import unicodedata from lm_eval.filters.extraction import RegexFilter class MultiChoiceRegexFilter(RegexFilter): """ """ def __init__( self, regex_pattern: str = r"#### (\-?[0-9\.\,]+)", group_select=0, fallback: str = "[invalid]", ignore_case=False, ignore_punctuation=False, regexes_to_ignore=None, ) -> None: """ regex_pattern: The basic regex pattern to use. If fails to match, we will use the customized match procedure - step 1 : We parse the choices between ([A-Z])s then try to find these choices in the response. - step 2 : We parse the choice with regex :[\s]*([A-?]), where ? varies by number of choices. group_select: Selects the (group_select)th match from the findall result. ignore_case: Ignores the case during step 1 matching ignore_punctuation: Remove the punctuation during step 1 matching regexes_to_ignore: Remove these regexes during step 1 matching """ super().__init__(regex_pattern, group_select, fallback) self.ignore_case = ignore_case self.ignore_punctuation = ignore_punctuation self.regexes_to_ignore = regexes_to_ignore def apply(self, resps, docs): # here, we assume we have a list, in which each element is # a list of model responses for some particular input/target pair. # so we process each of these (same input/target response sets) # independently (and keep them a list.) def find_match(regex, resp, convert_dict={}): match = regex.findall(resp) if match: match = match[self.group_select] if isinstance(match, tuple): match = [m for m in match if m][0] match = match.strip() if match and match in convert_dict: match = convert_dict[match] return match punct_tbl = dict.fromkeys( i for i in range(sys.maxunicode) if unicodedata.category(chr(i)).startswith("P") ) def filter_ignores(st): if self.regexes_to_ignore is not None: for s in self.regexes_to_ignore: st = re.sub(s, "", st) if self.ignore_case: st = st.lower() if self.ignore_punctuation: # https://stackoverflow.com/a/266162 st = st.translate(punct_tbl) return st filtered_resps = [] for r, doc in zip(resps, docs): fallback_regexes = [] choice_to_alpha = {} next_alpha = "A" without_paren_fallback_regexes = [] without_paren_to_target = {} choices = doc["choices"] for c in choices: m = filter_ignores(c.strip()) fallback_regexes.append(f"{re.escape(m)}") choice_to_alpha[m] = f"({next_alpha})" without_paren_fallback_regexes.append(next_alpha) without_paren_to_target[next_alpha] = f"({next_alpha})" next_alpha = chr(ord(next_alpha) + 1) fallback_regex = re.compile("|".join(fallback_regexes)) without_paren_fallback_regex = "|".join(without_paren_fallback_regexes) without_paren_fallback_regex = re.compile( f":[\s]*({without_paren_fallback_regex})" ) filtered = [] for resp in r: match = find_match(self.regex, resp) if not match: match = find_match( fallback_regex, filter_ignores(resp), choice_to_alpha ) if not match: match = find_match( without_paren_fallback_regex, resp, without_paren_to_target ) if not match: match = self.fallback filtered.append(match) filtered_resps.append(filtered) return filtered_resps