vision/m4/evaluation/custom_metrics/utils.py (443 lines of code) (raw):
import re
import string
from m4.sourcing.data_collection.processors import FilteringFunctions
# VQA Normalization utils
# Tokens removed entirely by vqa_normalize_text.
ARTICLES = {"a", "an", "the"}
# Maps contraction spellings without (or with misplaced) apostrophes to a
# canonical form; looked up per token by vqa_normalize_text.
# NOTE: vqa_normalize_text lowercases its input first, so the keys containing
# uppercase letters ("Id've", "I'dve", "Im", "Ive") can never match there.
# Some entries are identity mappings ("let's", "she's") and one looks inverted
# ("somebody'd" -> "somebodyd"). An identical table is defined inside
# VQANormalizationGtVisionLab (presumably both were copied from the reference
# VQA evaluation script), so the entries are kept as-is for comparability.
CONTRACTIONS = {
    "aint": "ain't",
    "arent": "aren't",
    "cant": "can't",
    "couldve": "could've",
    "couldnt": "couldn't",
    "couldn'tve": "couldn't've",
    "couldnt've": "couldn't've",
    "didnt": "didn't",
    "doesnt": "doesn't",
    "dont": "don't",
    "hadnt": "hadn't",
    "hadnt've": "hadn't've",
    "hadn'tve": "hadn't've",
    "hasnt": "hasn't",
    "havent": "haven't",
    "hed": "he'd",
    "hed've": "he'd've",
    "he'dve": "he'd've",
    "hes": "he's",
    "howd": "how'd",
    "howll": "how'll",
    "hows": "how's",
    "Id've": "I'd've",
    "I'dve": "I'd've",
    "Im": "I'm",
    "Ive": "I've",
    "isnt": "isn't",
    "itd": "it'd",
    "itd've": "it'd've",
    "it'dve": "it'd've",
    "itll": "it'll",
    "let's": "let's",
    "maam": "ma'am",
    "mightnt": "mightn't",
    "mightnt've": "mightn't've",
    "mightn'tve": "mightn't've",
    "mightve": "might've",
    "mustnt": "mustn't",
    "mustve": "must've",
    "neednt": "needn't",
    "notve": "not've",
    "oclock": "o'clock",
    "oughtnt": "oughtn't",
    "ow's'at": "'ow's'at",
    "'ows'at": "'ow's'at",
    "'ow'sat": "'ow's'at",
    "shant": "shan't",
    "shed've": "she'd've",
    "she'dve": "she'd've",
    "she's": "she's",
    "shouldve": "should've",
    "shouldnt": "shouldn't",
    "shouldnt've": "shouldn't've",
    "shouldn'tve": "shouldn't've",
    "somebody'd": "somebodyd",
    "somebodyd've": "somebody'd've",
    "somebody'dve": "somebody'd've",
    "somebodyll": "somebody'll",
    "somebodys": "somebody's",
    "someoned": "someone'd",
    "someoned've": "someone'd've",
    "someone'dve": "someone'd've",
    "someonell": "someone'll",
    "someones": "someone's",
    "somethingd": "something'd",
    "somethingd've": "something'd've",
    "something'dve": "something'd've",
    "somethingll": "something'll",
    "thats": "that's",
    "thered": "there'd",
    "thered've": "there'd've",
    "there'dve": "there'd've",
    "therere": "there're",
    "theres": "there's",
    "theyd": "they'd",
    "theyd've": "they'd've",
    "they'dve": "they'd've",
    "theyll": "they'll",
    "theyre": "they're",
    "theyve": "they've",
    "twas": "'twas",
    "wasnt": "wasn't",
    "wed've": "we'd've",
    "we'dve": "we'd've",
    "weve": "we've",
    "werent": "weren't",
    "whatll": "what'll",
    "whatre": "what're",
    "whats": "what's",
    "whatve": "what've",
    "whens": "when's",
    "whered": "where'd",
    "wheres": "where's",
    "whereve": "where've",
    "whod": "who'd",
    "whod've": "who'd've",
    "who'dve": "who'd've",
    "wholl": "who'll",
    "whos": "who's",
    "whove": "who've",
    "whyll": "why'll",
    "whyre": "why're",
    "whys": "why's",
    "wont": "won't",
    "wouldve": "would've",
    "wouldnt": "wouldn't",
    "wouldnt've": "wouldn't've",
    "wouldn'tve": "wouldn't've",
    "yall": "y'all",
    "yall'll": "y'all'll",
    "y'allll": "y'all'll",
    "yall'd've": "y'all'd've",
    "y'alld've": "y'all'd've",
    "y'all'dve": "y'all'd've",
    "youd": "you'd",
    "youd've": "you'd've",
    "you'dve": "you'd've",
    "youll": "you'll",
    "youre": "you're",
    "youve": "you've",
}
# Number words (zero through ten) replaced by their digit strings in
# vqa_normalize_text. "none" is treated as the count 0.
NUMBERS_STRING_TO_INT = {
    "none": "0",
    "zero": "0",
    "one": "1",
    "two": "2",
    "three": "3",
    "four": "4",
    "five": "5",
    "six": "6",
    "seven": "7",
    "eight": "8",
    "nine": "9",
    "ten": "10",
}
def vqa_normalize_text(text: str) -> str:
    """Normalize a free-form VQA answer so that equivalent answers compare equal.

    Adapted from
    https://github.com/GT-Vision-Lab/VQA/blob/master/PythonEvaluationTools/vqaEvaluation/vqaEval.py
    (VQANormalizationGtVisionLab is the exact replication; this is a variant).

    Steps, in order:
    1. Lowercase, then standardize whitespace with
       ``FilteringFunctions.standardize_whitespace``.
    2. Replace newlines and tabs with a space.
    3. Delete hyphens outright ("well-known" -> "wellknown"), then replace
       every other ``string.punctuation`` character except the apostrophe
       with a space.
    4. For each space-separated token: number words become digits
       (``NUMBERS_STRING_TO_INT``), otherwise known contractions are expanded
       (``CONTRACTIONS``), otherwise articles (``ARTICLES``) are dropped.
    5. Join the remaining non-empty tokens with single spaces and strip.
    """
    cleaned = FilteringFunctions.standardize_whitespace(text.lower())
    for target, replacement in (("\n", " "), ("\t", " "), ("-", "")):
        cleaned = cleaned.replace(target, replacement)

    # Every punctuation mark except the apostrophe (needed by contractions) -> space.
    punctuation_to_space = str.maketrans(dict.fromkeys(string.punctuation.replace("'", ""), " "))
    tokens = cleaned.translate(punctuation_to_space).split(" ")

    def _normalize_token(token):
        # Priority: number word, then contraction, then article removal.
        if token in NUMBERS_STRING_TO_INT:
            return NUMBERS_STRING_TO_INT[token]
        if token in CONTRACTIONS:
            return CONTRACTIONS[token]
        if token in ARTICLES:
            return ""
        return token

    # filter(None, ...) discards empty tokens left by repeated spaces or removed articles.
    return " ".join(filter(None, map(_normalize_token, tokens))).strip()
# Exact replication of the normalization at https://github.com/GT-Vision-Lab/VQA/blob/master/PythonEvaluationTools/vqaEvaluation/vqaEval.py
class VQANormalizationGtVisionLab:
    """Answer normalization reproducing the official VQA evaluation script.

    Outputs are intentionally identical to the reference implementation,
    including its quirks (documented inline), so that scores stay comparable
    with published VQA numbers. Use ``vqa_normalize_text`` as the entry point.
    """

    def __init__(self):
        # Contraction table copied from the reference script. Keys containing
        # uppercase letters never match because processDigitArticle lowercases
        # first, and "somebody'd" -> "somebodyd" looks inverted; all entries
        # are kept verbatim for fidelity.
        self.contractions = {
            "aint": "ain't",
            "arent": "aren't",
            "cant": "can't",
            "couldve": "could've",
            "couldnt": "couldn't",
            "couldn'tve": "couldn't've",
            "couldnt've": "couldn't've",
            "didnt": "didn't",
            "doesnt": "doesn't",
            "dont": "don't",
            "hadnt": "hadn't",
            "hadnt've": "hadn't've",
            "hadn'tve": "hadn't've",
            "hasnt": "hasn't",
            "havent": "haven't",
            "hed": "he'd",
            "hed've": "he'd've",
            "he'dve": "he'd've",
            "hes": "he's",
            "howd": "how'd",
            "howll": "how'll",
            "hows": "how's",
            "Id've": "I'd've",
            "I'dve": "I'd've",
            "Im": "I'm",
            "Ive": "I've",
            "isnt": "isn't",
            "itd": "it'd",
            "itd've": "it'd've",
            "it'dve": "it'd've",
            "itll": "it'll",
            "let's": "let's",
            "maam": "ma'am",
            "mightnt": "mightn't",
            "mightnt've": "mightn't've",
            "mightn'tve": "mightn't've",
            "mightve": "might've",
            "mustnt": "mustn't",
            "mustve": "must've",
            "neednt": "needn't",
            "notve": "not've",
            "oclock": "o'clock",
            "oughtnt": "oughtn't",
            "ow's'at": "'ow's'at",
            "'ows'at": "'ow's'at",
            "'ow'sat": "'ow's'at",
            "shant": "shan't",
            "shed've": "she'd've",
            "she'dve": "she'd've",
            "she's": "she's",
            "shouldve": "should've",
            "shouldnt": "shouldn't",
            "shouldnt've": "shouldn't've",
            "shouldn'tve": "shouldn't've",
            "somebody'd": "somebodyd",
            "somebodyd've": "somebody'd've",
            "somebody'dve": "somebody'd've",
            "somebodyll": "somebody'll",
            "somebodys": "somebody's",
            "someoned": "someone'd",
            "someoned've": "someone'd've",
            "someone'dve": "someone'd've",
            "someonell": "someone'll",
            "someones": "someone's",
            "somethingd": "something'd",
            "somethingd've": "something'd've",
            "something'dve": "something'd've",
            "somethingll": "something'll",
            "thats": "that's",
            "thered": "there'd",
            "thered've": "there'd've",
            "there'dve": "there'd've",
            "therere": "there're",
            "theres": "there's",
            "theyd": "they'd",
            "theyd've": "they'd've",
            "they'dve": "they'd've",
            "theyll": "they'll",
            "theyre": "they're",
            "theyve": "they've",
            "twas": "'twas",
            "wasnt": "wasn't",
            "wed've": "we'd've",
            "we'dve": "we'd've",
            "weve": "we've",
            "werent": "weren't",
            "whatll": "what'll",
            "whatre": "what're",
            "whats": "what's",
            "whatve": "what've",
            "whens": "when's",
            "whered": "where'd",
            "wheres": "where's",
            "whereve": "where've",
            "whod": "who'd",
            "whod've": "who'd've",
            "who'dve": "who'd've",
            "wholl": "who'll",
            "whos": "who's",
            "whove": "who've",
            "whyll": "why'll",
            "whyre": "why're",
            "whys": "why's",
            "wont": "won't",
            "wouldve": "would've",
            "wouldnt": "wouldn't",
            "wouldnt've": "wouldn't've",
            "wouldn'tve": "wouldn't've",
            "yall": "y'all",
            "yall'll": "y'all'll",
            "y'allll": "y'all'll",
            "yall'd've": "y'all'd've",
            "y'alld've": "y'all'd've",
            "y'all'dve": "y'all'd've",
            "youd": "you'd",
            "youd've": "you'd've",
            "you'dve": "you'd've",
            "youll": "you'll",
            "youre": "you're",
            "youve": "you've",
        }
        # Number words -> digit strings ("none" counts as 0).
        self.manual_map = {
            "none": "0",
            "zero": "0",
            "one": "1",
            "two": "2",
            "three": "3",
            "four": "4",
            "five": "5",
            "six": "6",
            "seven": "7",
            "eight": "8",
            "nine": "9",
            "ten": "10",
        }
        self.articles = ["a", "an", "the"]
        # NOTE: "(?!<=\d)" is a negative *lookahead* for the literal text "<="
        # plus a digit (probably meant as the lookbehind "(?<!\d)"). It always
        # succeeds in front of a ".", so this pattern behaves like r"\.(?!\d)":
        # any period not followed by a digit is removed ("5." -> "5").
        # Kept as in the reference. Raw strings avoid invalid-escape warnings;
        # the compiled patterns are the same as the original non-raw literals.
        self.period_strip = re.compile(r"(?!<=\d)(\.)(?!\d)")
        # A comma between two digits, e.g. "1,000".
        self.comma_strip = re.compile(r"(\d)(\,)(\d)")
        self.punct = [
            ";",
            r"/",
            "[",
            "]",
            '"',
            "{",
            "}",
            "(",
            ")",
            "=",
            "+",
            "\\",
            "_",
            "-",
            ">",
            "<",
            "@",
            "`",
            ",",
            "?",
            "!",
        ]

    def processPunctuation(self, in_text):
        """Remove or space out punctuation exactly as the reference script does.

        Each mark in ``self.punct`` is deleted when it is adjacent to a space
        somewhere in ``in_text``, or when ``in_text`` contains a
        digit-comma-digit sequence; otherwise it is replaced by a space. Then
        periods not followed by a digit are deleted (at most 32 of them, see
        below).
        """
        out_text = in_text
        # Loop-invariant: depends only on in_text, so compute it once.
        has_digit_comma_digit = self.comma_strip.search(in_text) is not None
        for p in self.punct:
            if has_digit_comma_digit or p + " " in in_text or " " + p in in_text:
                out_text = out_text.replace(p, "")
            else:
                out_text = out_text.replace(p, " ")
        # The reference passes re.UNICODE positionally, where it lands in the
        # `count` parameter (value 32): at most 32 periods are removed. That
        # behaviour is preserved; the keyword form avoids the Python 3.13
        # deprecation of a positional `count`.
        out_text = self.period_strip.sub("", out_text, count=re.UNICODE)
        return out_text

    def processDigitArticle(self, in_text):
        """Lowercase, map number words to digits, drop articles, expand contractions.

        Tokens come from ``str.split()``; the result is joined with single spaces.
        """
        out_words = []
        for word in in_text.lower().split():
            # The reference uses setdefault, which inserts every unseen word into
            # manual_map (unbounded growth over a dataset). get() returns the same value.
            word = self.manual_map.get(word, word)
            if word not in self.articles:
                out_words.append(word)
        out_words = [self.contractions.get(word, word) for word in out_words]
        return " ".join(out_words)

    def vqa_normalize_text(self, text):
        """Normalize an answer string the way the reference VQA evaluation does.

        Newlines and tabs become spaces, the text is stripped, then
        ``processPunctuation`` and ``processDigitArticle`` are applied.
        """
        text = text.replace("\n", " ").replace("\t", " ").strip()
        text = self.processPunctuation(text)
        return self.processDigitArticle(text)
# Digit words substituted by convert_to_number. Unlike NUMBERS_STRING_TO_INT
# there is no "none" or "ten" entry. convert_to_number applies these as plain
# substring replacements in this insertion order, so the order is significant.
NUMBER_WORD_TO_NUMBER_INT = {
    "one": "1",
    "two": "2",
    "three": "3",
    "four": "4",
    "five": "5",
    "six": "6",
    "seven": "7",
    "eight": "8",
    "nine": "9",
    "zero": "0",
}
# Copy pasted and adapted from https://github.com/MMMU-Benchmark/MMMU/blob/main/eval/utils/eval_utils.py#L65
def convert_to_number(string):
    """Parse ``string`` as a float, accepting digit words and comma separators.

    The text is lowercased. Each key of ``NUMBER_WORD_TO_NUMBER_INT`` is then
    replaced by its digit as a plain substring, in dict order, and all commas
    are dropped before calling ``float``. Because the replacement is
    substring-based, inputs such as "onetwo" parse as 12.0. ``float`` also
    accepts "nan" and "inf".

    Raises:
        ValueError: if the rewritten text is not a valid float literal.
    """
    normalized = string.lower()
    for word in NUMBER_WORD_TO_NUMBER_INT:
        normalized = normalized.replace(word, NUMBER_WORD_TO_NUMBER_INT[word])
    without_separators = normalized.replace(",", "")
    return float(without_separators)
def check_is_number(string):
    """Return True if ``convert_to_number`` can parse ``string``, else False.

    Only ``ValueError`` counts as "not a number"; any other exception (for
    example ``AttributeError`` on non-string input) propagates.
    """
    try:
        convert_to_number(string)
    except ValueError:
        return False
    else:
        return True
# Taken and modified from https://github.com/MMMU-Benchmark/MMMU/blob/main/eval/utils/eval_utils.py#L76
def normalize_str_mmmu(string):
    """Normalize an answer for MMMU-style matching.

    Surrounding whitespace is stripped. If the text then starts with
    "Answer: ", every occurrence of "Answer: " is removed. If the result
    parses via ``convert_to_number``, it is returned as a float rounded to
    2 decimals; otherwise it is returned lowercased.

    Returns:
        float | str
    """
    cleaned = string.strip()
    if cleaned.startswith("Answer: "):
        cleaned = cleaned.replace("Answer: ", "")
    if not check_is_number(cleaned):
        return cleaned.lower()
    # Keep 2 decimals so that e.g. 3.14159 and 3.14 compare equal.
    return round(convert_to_number(cleaned), 2)
def extract_numbers_mmmu(string):
    """Extract every number-like substring from ``string``.

    Three regexes run independently and their matches are concatenated in
    this order: comma-grouped numbers, scientific notation, then plain
    integers/decimals. The passes are not mutually exclusive, so part of a
    number can be reported more than once (e.g. "1,234" yields both "1,234"
    and "234").

    Returns:
        list[str]: the matched substrings, not converted to numbers.
    """
    patterns = (
        # Comma-grouped numbers such as "12,345".
        r"-?\b\d{1,3}(?:,\d{3})+\b",
        # Scientific notation such as "6.02e23" or "1E-5".
        r"-?\d+(?:\.\d+)?[eE][+-]?\d+",
        # Plain integers/decimals not immediately followed by an exponent,
        # a comma, or another digit.
        r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])",
    )
    return [match for pattern in patterns for match in re.findall(pattern, string)]
# Copied from https://github.com/MMMU-Benchmark/MMMU/blob/36adc047118013d66c225ebfa352ebbf44740ce4/eval/utils/eval_utils.py#L122
# Except added "answer: " as key indicator
def parse_open_response_mmmu(response, normalize_text_fn):
    """Parse candidate answers out of a free-form generated response.

    The response is split into sub-responses. For each one, the shortest tail
    that follows a key indicator ("so ", "is ", "answer: ", ..., plus "=" for
    the last sub-response) is kept as a candidate. Numbers found in those
    candidates by ``extract_numbers_mmmu`` are added as extra candidates.
    Every candidate is then passed through ``normalize_text_fn`` and
    duplicates are removed.

    Args:
        response: the generated text.
        normalize_text_fn: callable applied to each candidate string (e.g.
            ``normalize_str_mmmu``). Its results must be hashable.

    Returns:
        list: the unique normalized candidates, in first-seen order.
    """

    def get_key_subresponses(response):
        """Return the shortest indicator-led tail of each sub-response.

        Falls back to ``[response]`` (stripped, trailing/leading periods
        removed, lowercased) when no usable tail is found.
        """
        response = response.strip().strip(".").lower()
        # NOTE: the text is already lowercased, so the "(?=[A-Z])" branch can
        # never match and the split effectively happens on newlines only.
        # Kept as in the copied upstream code.
        sub_responses = re.split(r"\.\s(?=[A-Z])|\n", response)
        indicators_of_keys = [
            "could be ",
            "so ",
            "is ",
            "thus ",
            "therefore ",
            "final ",
            "answer ",
            "result ",
            "answer: ",
        ]
        # Tails made of a single punctuation mark carry no answer.
        trivial_responses = {":", ",", ".", "!", "?", ";", "'"}
        key_responses = []
        last_index = len(sub_responses) - 1
        for index, resp in enumerate(sub_responses):
            # The last sub-response may be a bare equation (the whole response
            # can be a single sentence), so "=" is accepted there as well.
            if index == last_index:
                indicators_of_keys.append("=")
            # Shortest tail after any indicator: most likely just the answer.
            shortest_key_response = None
            for indicator in indicators_of_keys:
                if indicator not in resp:
                    continue
                candidate = resp.split(indicator)[-1].strip()
                if not shortest_key_response or len(candidate) < len(shortest_key_response):
                    shortest_key_response = candidate
            if shortest_key_response and shortest_key_response not in trivial_responses:
                key_responses.append(shortest_key_response)
        return key_responses or [response]

    key_responses = get_key_subresponses(response)
    # Keep the original string candidates and add any numbers found in them.
    candidates = list(key_responses)
    for resp in key_responses:
        candidates.extend(extract_numbers_mmmu(resp))
    # normalize_text_fn returns a single value (str or number) per candidate.
    normalized = [normalize_text_fn(candidate) for candidate in candidates]
    # dict.fromkeys drops duplicates like set() did, but keeps first-seen order
    # so the result does not vary with per-process string hash randomization.
    return list(dict.fromkeys(normalized))