vision/m4/evaluation/custom_metrics/utils.py (443 lines of code) (raw):

import re
import string

from m4.sourcing.data_collection.processors import FilteringFunctions


# VQA Normalization utils
ARTICLES = {"a", "an", "the"}
# NOTE: `vqa_normalize_text` lowercases its input before lookup, so the capitalised
# keys ("Id've", "I'dve", "Im", "Ive") can never match there. The entry
# "somebody'd" -> "somebodyd" is inverted compared to its neighbours; the same entry
# exists in `VQANormalizationGtVisionLab` (the exact vqaEval.py replica), so it is
# kept unchanged to avoid shifting scores.
CONTRACTIONS = {
    "aint": "ain't",
    "arent": "aren't",
    "cant": "can't",
    "couldve": "could've",
    "couldnt": "couldn't",
    "couldn'tve": "couldn't've",
    "couldnt've": "couldn't've",
    "didnt": "didn't",
    "doesnt": "doesn't",
    "dont": "don't",
    "hadnt": "hadn't",
    "hadnt've": "hadn't've",
    "hadn'tve": "hadn't've",
    "hasnt": "hasn't",
    "havent": "haven't",
    "hed": "he'd",
    "hed've": "he'd've",
    "he'dve": "he'd've",
    "hes": "he's",
    "howd": "how'd",
    "howll": "how'll",
    "hows": "how's",
    "Id've": "I'd've",
    "I'dve": "I'd've",
    "Im": "I'm",
    "Ive": "I've",
    "isnt": "isn't",
    "itd": "it'd",
    "itd've": "it'd've",
    "it'dve": "it'd've",
    "itll": "it'll",
    "let's": "let's",
    "maam": "ma'am",
    "mightnt": "mightn't",
    "mightnt've": "mightn't've",
    "mightn'tve": "mightn't've",
    "mightve": "might've",
    "mustnt": "mustn't",
    "mustve": "must've",
    "neednt": "needn't",
    "notve": "not've",
    "oclock": "o'clock",
    "oughtnt": "oughtn't",
    "ow's'at": "'ow's'at",
    "'ows'at": "'ow's'at",
    "'ow'sat": "'ow's'at",
    "shant": "shan't",
    "shed've": "she'd've",
    "she'dve": "she'd've",
    "she's": "she's",
    "shouldve": "should've",
    "shouldnt": "shouldn't",
    "shouldnt've": "shouldn't've",
    "shouldn'tve": "shouldn't've",
    "somebody'd": "somebodyd",
    "somebodyd've": "somebody'd've",
    "somebody'dve": "somebody'd've",
    "somebodyll": "somebody'll",
    "somebodys": "somebody's",
    "someoned": "someone'd",
    "someoned've": "someone'd've",
    "someone'dve": "someone'd've",
    "someonell": "someone'll",
    "someones": "someone's",
    "somethingd": "something'd",
    "somethingd've": "something'd've",
    "something'dve": "something'd've",
    "somethingll": "something'll",
    "thats": "that's",
    "thered": "there'd",
    "thered've": "there'd've",
    "there'dve": "there'd've",
    "therere": "there're",
    "theres": "there's",
    "theyd": "they'd",
    "theyd've": "they'd've",
    "they'dve": "they'd've",
    "theyll": "they'll",
    "theyre": "they're",
    "theyve": "they've",
    "twas": "'twas",
    "wasnt": "wasn't",
    "wed've": "we'd've",
    "we'dve": "we'd've",
    "weve": "we've",
    "werent": "weren't",
    "whatll": "what'll",
    "whatre": "what're",
    "whats": "what's",
    "whatve": "what've",
    "whens": "when's",
    "whered": "where'd",
    "wheres": "where's",
    "whereve": "where've",
    "whod": "who'd",
    "whod've": "who'd've",
    "who'dve": "who'd've",
    "wholl": "who'll",
    "whos": "who's",
    "whove": "who've",
    "whyll": "why'll",
    "whyre": "why're",
    "whys": "why's",
    "wont": "won't",
    "wouldve": "would've",
    "wouldnt": "wouldn't",
    "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've",
    "yall": "y'all",
    "yall'll": "y'all'll",
    "y'allll": "y'all'll",
    "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've",
    "y'all'dve": "y'all'd've",
    "youd": "you'd",
    "youd've": "you'd've",
    "you'dve": "you'd've",
    "youll": "you'll",
    "youre": "you're",
    "youve": "you've",
}
NUMBERS_STRING_TO_INT = {
    "none": "0",
    "zero": "0",
    "one": "1",
    "two": "2",
    "three": "3",
    "four": "4",
    "five": "5",
    "six": "6",
    "seven": "7",
    "eight": "8",
    "nine": "9",
    "ten": "10",
}

# Translation table used by `vqa_normalize_text`: every ASCII punctuation character
# except the apostrophe (needed by the contraction table) becomes a space.
# Built once at import time instead of on every call.
_VQA_PUNCTUATION_TO_SPACE = str.maketrans(dict.fromkeys(string.punctuation.replace("'", ""), " "))


def vqa_normalize_text(text: str) -> str:
    """Normalize a VQA answer string.

    Adapted from
    https://github.com/GT-Vision-Lab/VQA/blob/master/PythonEvaluationTools/vqaEvaluation/vqaEval.py
    (`VQANormalizationGtVisionLab` below is the exact replication).

    Steps:
    1. Convert the text to lower case.
    2. Standardize whitespace, then replace line breaks and tabs by a space.
    3. Delete hyphens, then replace every other punctuation character except the
       apostrophe by a space.
    4. Convert number words ("none", "zero" ... "ten") to digits.
    5. Standardize contractions.
    6. Remove articles.
    7. Drop empty tokens, join with single spaces and strip the result.

    Args:
        text: Raw answer text.

    Returns:
        The normalized answer.
    """
    text = text.lower()
    text = FilteringFunctions.standardize_whitespace(text)
    text = text.replace("\n", " ")
    text = text.replace("\t", " ")
    # Hyphens are deleted (not spaced out) so e.g. "t-shirt" becomes "tshirt".
    text = text.replace("-", "")
    text = text.translate(_VQA_PUNCTUATION_TO_SPACE)

    words = []
    for word in text.split(" "):
        # Number words take priority over contractions, which take priority over articles.
        if word in NUMBERS_STRING_TO_INT:
            word = NUMBERS_STRING_TO_INT[word]
        elif word in CONTRACTIONS:
            word = CONTRACTIONS[word]
        elif word in ARTICLES:
            continue
        # `split(" ")` yields empty tokens for runs of spaces; skip them.
        if word:
            words.append(word)
    return " ".join(words).strip()


# Exact replication of the normalization at https://github.com/GT-Vision-Lab/VQA/blob/master/PythonEvaluationTools/vqaEvaluation/vqaEval.py
class VQANormalizationGtVisionLab:
    """Answer normalization that reproduces the official VQA evaluation script.

    The lookup tables and regular expressions are kept identical to the reference,
    including its quirks (documented inline), so that scores stay comparable with the
    official implementation. The tables are deliberately independent copies of the
    module-level ones so that edits to those cannot alter this replica.
    """

    def __init__(self):
        self.contractions = {
            "aint": "ain't",
            "arent": "aren't",
            "cant": "can't",
            "couldve": "could've",
            "couldnt": "couldn't",
            "couldn'tve": "couldn't've",
            "couldnt've": "couldn't've",
            "didnt": "didn't",
            "doesnt": "doesn't",
            "dont": "don't",
            "hadnt": "hadn't",
            "hadnt've": "hadn't've",
            "hadn'tve": "hadn't've",
            "hasnt": "hasn't",
            "havent": "haven't",
            "hed": "he'd",
            "hed've": "he'd've",
            "he'dve": "he'd've",
            "hes": "he's",
            "howd": "how'd",
            "howll": "how'll",
            "hows": "how's",
            "Id've": "I'd've",
            "I'dve": "I'd've",
            "Im": "I'm",
            "Ive": "I've",
            "isnt": "isn't",
            "itd": "it'd",
            "itd've": "it'd've",
            "it'dve": "it'd've",
            "itll": "it'll",
            "let's": "let's",
            "maam": "ma'am",
            "mightnt": "mightn't",
            "mightnt've": "mightn't've",
            "mightn'tve": "mightn't've",
            "mightve": "might've",
            "mustnt": "mustn't",
            "mustve": "must've",
            "neednt": "needn't",
            "notve": "not've",
            "oclock": "o'clock",
            "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at",
            "'ow'sat": "'ow's'at",
            "shant": "shan't",
            "shed've": "she'd've",
            "she'dve": "she'd've",
            "she's": "she's",
            "shouldve": "should've",
            "shouldnt": "shouldn't",
            "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've",
            "somebody'd": "somebodyd",
            "somebodyd've": "somebody'd've",
            "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll",
            "somebodys": "somebody's",
            "someoned": "someone'd",
            "someoned've": "someone'd've",
            "someone'dve": "someone'd've",
            "someonell": "someone'll",
            "someones": "someone's",
            "somethingd": "something'd",
            "somethingd've": "something'd've",
            "something'dve": "something'd've",
            "somethingll": "something'll",
            "thats": "that's",
            "thered": "there'd",
            "thered've": "there'd've",
            "there'dve": "there'd've",
            "therere": "there're",
            "theres": "there's",
            "theyd": "they'd",
            "theyd've": "they'd've",
            "they'dve": "they'd've",
            "theyll": "they'll",
            "theyre": "they're",
            "theyve": "they've",
            "twas": "'twas",
            "wasnt": "wasn't",
            "wed've": "we'd've",
            "we'dve": "we'd've",
            "weve": "we've",
            "werent": "weren't",
            "whatll": "what'll",
            "whatre": "what're",
            "whats": "what's",
            "whatve": "what've",
            "whens": "when's",
            "whered": "where'd",
            "wheres": "where's",
            "whereve": "where've",
            "whod": "who'd",
            "whod've": "who'd've",
            "who'dve": "who'd've",
            "wholl": "who'll",
            "whos": "who's",
            "whove": "who've",
            "whyll": "why'll",
            "whyre": "why're",
            "whys": "why's",
            "wont": "won't",
            "wouldve": "would've",
            "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've",
            "yall": "y'all",
            "yall'll": "y'all'll",
            "y'allll": "y'all'll",
            "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've",
            "y'all'dve": "y'all'd've",
            "youd": "you'd",
            "youd've": "you'd've",
            "you'dve": "you'd've",
            "youll": "you'll",
            "youre": "you're",
            "youve": "you've",
        }
        self.manual_map = {
            "none": "0",
            "zero": "0",
            "one": "1",
            "two": "2",
            "three": "3",
            "four": "4",
            "five": "5",
            "six": "6",
            "seven": "7",
            "eight": "8",
            "nine": "9",
            "ten": "10",
        }
        self.articles = ["a", "an", "the"]
        # Raw strings: same patterns as the reference, without invalid-escape warnings.
        # NOTE: "(?!<=\d)" is a negative *lookahead* for the literal "<" (not a
        # lookbehind); it always succeeds before a ".", so the pattern effectively
        # matches any period not followed by a digit. Kept verbatim for exact replication.
        self.period_strip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        self.comma_strip = re.compile(r"(\d)(\,)(\d)")
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def processPunctuation(self, in_text):
        """Remove or space out punctuation the way the reference `processPunctuation` does.

        For each character in `self.punct`, it is deleted when it appears next to a
        space in `in_text` or when `in_text` contains a digit-comma-digit group
        anywhere; otherwise it is replaced by a space. Periods matched by
        `self.period_strip` are then deleted.
        """
        out_text = in_text
        # Loop-invariant: the reference re-runs this search for every punctuation mark.
        has_digit_comma_digit = self.comma_strip.search(in_text) is not None
        for p in self.punct:
            # Conditions are evaluated on the original input, edits applied to `out_text`.
            if (p + " " in in_text or " " + p in in_text) or has_digit_comma_digit:
                out_text = out_text.replace(p, "")
            else:
                out_text = out_text.replace(p, " ")
        # The reference passes `re.UNICODE` positionally, where `Pattern.sub` expects
        # `count`; at most re.UNICODE (== 32) periods are therefore removed. Made
        # explicit here and kept for exact replication.
        out_text = self.period_strip.sub("", out_text, count=re.UNICODE)
        return out_text

    def processDigitArticle(self, in_text):
        """Lowercase, map number words to digits, drop articles and expand contractions.

        Tokens are split on any whitespace and re-joined with single spaces.
        """
        kept_words = []
        for word in in_text.lower().split():
            # The reference uses `manualMap.setdefault(word, word)`, which inserts every
            # unseen word into the map and makes it grow for the lifetime of the
            # instance; `.get` returns the same value without mutating the table.
            word = self.manual_map.get(word, word)
            if word not in self.articles:
                kept_words.append(word)
        return " ".join(self.contractions.get(word, word) for word in kept_words)

    def vqa_normalize_text(self, text):
        """Normalize an answer exactly like the official VQA evaluation script.

        Line breaks and tabs become spaces, the text is stripped, then punctuation
        and digit/article/contraction processing are applied.
        """
        text = text.replace("\n", " ")
        text = text.replace("\t", " ")
        text = text.strip()
        text = self.processPunctuation(text)
        text = self.processDigitArticle(text)
        return text


# Insertion order matters: `convert_to_number` applies the replacements sequentially.
NUMBER_WORD_TO_NUMBER_INT = {
    "one": "1",
    "two": "2",
    "three": "3",
    "four": "4",
    "five": "5",
    "six": "6",
    "seven": "7",
    "eight": "8",
    "nine": "9",
    "zero": "0",
}


# Copy pasted and adapted from https://github.com/MMMU-Benchmark/MMMU/blob/main/eval/utils/eval_utils.py#L65
def convert_to_number(string):
    """Parse `string` as a float after mapping number words to digits.

    The string is lowercased, every occurrence of each key of
    `NUMBER_WORD_TO_NUMBER_INT` is replaced by its digit (as a substring, not a whole
    word: "someone" becomes "some1"), commas are removed, and `float` is applied.
    Note that `float` also accepts strings such as "nan", "inf" and "1_000".

    Raises:
        ValueError: if the resulting string is not a valid float literal.
    """
    string = string.lower()
    for word, digit in NUMBER_WORD_TO_NUMBER_INT.items():
        string = string.replace(word, digit)
    return float(string.replace(",", ""))


def check_is_number(string):
    """Return True if `convert_to_number` can parse `string`, False otherwise."""
    try:
        convert_to_number(string)
    except ValueError:
        return False
    return True


# Taken and modified from https://github.com/MMMU-Benchmark/MMMU/blob/main/eval/utils/eval_utils.py#L76
def normalize_str_mmmu(string):
    """Normalize an MMMU answer string.

    Surrounding whitespace and a leading "Answer: " prefix are removed. If the rest
    parses as a number (see `convert_to_number`) it is returned as a float rounded to
    2 decimals; otherwise the lowercased string is returned.
    """
    string = string.strip()
    prefix = "Answer: "
    if string.startswith(prefix):
        # Remove only the leading prefix; `str.replace` would also delete any later
        # occurrence inside the answer itself.
        string = string[len(prefix):]
    try:
        number = convert_to_number(string)
    except ValueError:
        return string.lower()
    # Keep 2 decimals.
    return round(number, 2)


def extract_numbers_mmmu(string):
    """Extract all forms of numbers from a string with regexes.

    Returns a list of the matched substrings: numbers with thousands separators
    first, then scientific-notation numbers, then plain integers/decimals.
    """
    # Pattern for numbers with commas
    pattern_commas = r"-?\b\d{1,3}(?:,\d{3})+\b"
    # Pattern for scientific notation
    pattern_scientific = r"-?\d+(?:\.\d+)?[eE][+-]?\d+"
    # Pattern for simple numbers without commas (not part of the two forms above)
    pattern_simple = r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])"

    numbers_with_commas = re.findall(pattern_commas, string)
    numbers_scientific = re.findall(pattern_scientific, string)
    numbers_simple = re.findall(pattern_simple, string)
    return numbers_with_commas + numbers_scientific + numbers_simple


# Copied from https://github.com/MMMU-Benchmark/MMMU/blob/36adc047118013d66c225ebfa352ebbf44740ce4/eval/utils/eval_utils.py#L122
# Except added "answer: " as key indicator
def parse_open_response_mmmu(response, normalize_text_fn):
    """Parse the candidate answers from a generated open-ended response.

    Args:
        response: The raw model response.
        normalize_text_fn: Callable applied to every candidate string (e.g.
            `normalize_str_mmmu`); its results must be hashable.

    Returns:
        A list of unique normalized candidates (strings or numbers), in the order
        they were first produced.
    """

    def get_key_subresponses(response):
        """Return the tails of sub-responses that follow an answer indicator.

        Falls back to `[response]` (stripped, lowercased) when none is found.
        """
        response = response.strip().strip(".").lower()
        # NOTE: the response is lowercased before splitting, so the `(?=[A-Z])` branch
        # never matches and sentences are effectively split on newlines only. Kept as
        # in the copied MMMU code.
        sub_responses = re.split(r"\.\s(?=[A-Z])|\n", response)
        indicators_of_keys = [
            "could be ",
            "so ",
            "is ",
            "thus ",
            "therefore ",
            "final ",
            "answer ",
            "result ",
            "answer: ",
        ]
        trivial_responses = {":", ",", ".", "!", "?", ";", "'"}
        last_index = len(sub_responses) - 1

        key_responses = []
        for index, resp in enumerate(sub_responses):
            # On the last sub-response also accept an equation: the entire response can
            # be a single sentence such as "x = 5".
            indicators = indicators_of_keys + ["="] if index == last_index else indicators_of_keys
            # The shortest tail that may contain the answer.
            shortest_key_response = None
            for indicator in indicators:
                if indicator in resp:
                    candidate = resp.split(indicator)[-1].strip()
                    if not shortest_key_response or len(candidate) < len(shortest_key_response):
                        shortest_key_response = candidate
            # Keep it only if it is non-empty and not a lone punctuation mark.
            if shortest_key_response and shortest_key_response.strip() not in trivial_responses:
                key_responses.append(shortest_key_response)

        if not key_responses:  # did not find any
            return [response]
        return key_responses

    key_responses = get_key_subresponses(response)
    # Keep the original key strings, then add every number found in them.
    pred_list = list(key_responses)
    for resp in key_responses:
        pred_list.extend(extract_numbers_mmmu(resp))
    # `normalize_text_fn` maps one string to one value, hence a 1:1 comprehension.
    normalized = [normalize_text_fn(pred) for pred in pred_list]
    # Remove duplicates while keeping first-seen order; iterating a plain `set` of
    # strings can yield a different order from run to run (hash randomization).
    return list(dict.fromkeys(normalized))