lmms_eval/tasks/mathvista/mathvista_evals.py (311 lines of code) (raw):

import logging
import re
import time

import requests
from Levenshtein import distance

eval_logger = logging.getLogger("lmms-eval")

# Few-shot prompt that teaches the extractor model how to pull the final answer out
# of a free-form model response. It is stripped and placed in front of every
# extraction query by ``MathVistaEvaluator.create_test_prompt``.
DEMO_PROMPT = """
Please read the following example. Then extract the answer from the model response and type it at the end of the prompt.

Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end.
Question: Which number is missing?

Model response: The number missing in the sequence is 14.

Extracted answer: 14

Hint: Please answer the question requiring a floating-point number with one decimal place and provide the final value, e.g., 1.2, 1.3, 1.4, at the end.
Question: What is the fraction of females facing the camera?

Model response: The fraction of females facing the camera is 0.6, which means that six out of ten females in the group are facing the camera.

Extracted answer: 0.6

Hint: Please answer the question requiring a floating-point number with two decimal places and provide the final value, e.g., 1.23, 1.34, 1.45, at the end.
Question: How much money does Luca need to buy a sour apple candy and a butterscotch candy? (Unit: $)

Model response: Luca needs $1.45 to buy a sour apple candy and a butterscotch candy.

Extracted answer: 1.45

Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end.
Question: Between which two years does the line graph saw its maximum peak?

Model response: The line graph saw its maximum peak between 2007 and 2008.

Extracted answer: [2007, 2008]

Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end.
Question: What fraction of the shape is blue?\nChoices:\n(A) 3/11\n(B) 8/11\n(C) 6/11\n(D) 3/5

Model response: The correct answer is (B) 8/11.

Extracted answer: B
"""

# Spelled-out counts used when phrasing float-answer hints ("one decimal place").
_NUMBER_WORDS = {1: "one", 2: "two", 3: "three", 4: "four", 5: "five", 6: "six"}


class MathVistaEvaluator:
    """Prompt construction, answer extraction and scoring helpers for MathVista.

    Free-form answer extraction is delegated to an OpenAI chat-completions model
    reached over HTTP at ``API_URL``.
    """

    API_URL = "https://api.openai.com/v1/chat/completions"

    def __init__(self, api_key, gpt_model="gpt-3.5-turbo", quick_extract=False):
        """
        Args:
            api_key: OpenAI API key, sent as a Bearer token.
            gpt_model: Chat model used for answer extraction.
            quick_extract: Default used by ``extract_answer`` when its own
                ``quick_extract`` argument is not given.
        """
        self.api_key = api_key
        self.gpt_model = gpt_model
        self.quick_extract = quick_extract

    def _post_request(self, payload):
        """POST ``payload`` to the chat-completions endpoint and return the decoded JSON.

        Raises:
            requests.HTTPError: on a non-2xx status. The response body is appended
                to the message because ``raise_for_status`` only reports status,
                reason and URL; ``get_chat_response`` matches on the API's error
                text ("Rate limit", "Please reduce the length of the messages").
            requests.RequestException: on connection errors or the 30 s timeout.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        response = requests.post(self.API_URL, headers=headers, json=payload, timeout=30)
        try:
            response.raise_for_status()
        except requests.HTTPError as err:
            raise requests.HTTPError(f"{err} | {response.text}", response=response) from err
        return response.json()

    def get_chat_response(self, prompt, temperature=0, max_tokens=256, n=1, patience=10000000, sleep_time=0):
        """Query the chat model, retrying until a non-empty answer arrives.

        Args:
            prompt: User message content.
            temperature, max_tokens, n: Forwarded to the API unchanged.
            patience: Maximum number of attempts.
            sleep_time: Seconds to sleep after a failed attempt (0 = no sleep).

        Returns:
            The stripped completion text when ``n == 1``, a list of stripped texts
            otherwise, or ``""`` once ``patience`` is exhausted.
        """
        payload = {
            "model": self.gpt_model,
            "messages": [{"role": "user", "content": prompt}],
            "temperature": temperature,
            "max_tokens": max_tokens,
            "n": n,
        }
        while patience > 0:
            patience -= 1
            try:
                response = self._post_request(payload)
                if n == 1:
                    prediction = response["choices"][0]["message"]["content"].strip()
                    if prediction:
                        return prediction
                else:
                    prediction = [choice["message"]["content"].strip() for choice in response["choices"]]
                    if prediction and prediction[0] != "":
                        return prediction
            except Exception as e:
                message = str(e)
                # Rate-limit errors are expected under load; don't spam the log.
                if "Rate limit" not in message:
                    eval_logger.error(e)
                if "Please reduce the length of the messages" in message:
                    eval_logger.error("!!Reduce prompt size")
                    # Context too long: drop the first 10% and keep the tail, which
                    # holds the query and model response to extract from.
                    new_size = int(len(prompt) * 0.9)
                    prompt = prompt[len(prompt) - new_size :]
                    payload["messages"] = [{"role": "user", "content": prompt}]
                if sleep_time > 0:
                    time.sleep(sleep_time)
        return ""

    def verify_extraction(self, extraction):
        """Return True when ``extraction`` is non-empty after stripping whitespace."""
        return bool(extraction.strip())

    def create_test_prompt(self, demo_prompt, query, response):
        """Build the extraction prompt: few-shot demo, the query, the model response."""
        demo_prompt = demo_prompt.strip()
        test_prompt = f"{query}\n\n{response}"
        full_prompt = f"{demo_prompt}\n\n{test_prompt}\n\nExtracted answer: "
        return full_prompt

    def extract_answer(self, response, problem, quick_extract=None):
        """Extract the final answer from a model ``response`` for ``problem``.

        Cheap paths are tried first: an exact choice match, a bare integer/float,
        and (when quick extraction is on) the pattern ``The answer is "..."``.
        Otherwise the GPT extractor is queried with ``DEMO_PROMPT``.

        Args:
            response: Raw model output.
            problem: Dict with ``question_type``, ``answer_type``, ``query`` and
                optionally ``choices``.
            quick_extract: Enable the regex shortcut. ``None`` (default) uses the
                value given to the constructor.

        Returns:
            The extracted answer, or ``""`` when nothing could be extracted.
        """
        if quick_extract is None:
            quick_extract = self.quick_extract

        question_type = problem["question_type"]
        answer_type = problem["answer_type"]
        choices = problem.get("choices", [])
        query = problem["query"]

        if not response:
            return ""

        if question_type == "multi_choice" and response in choices:
            return response

        if answer_type == "integer":
            try:
                return str(int(response))
            except ValueError:
                pass

        if answer_type == "float":
            try:
                return str(float(response))
            except ValueError:
                pass

        # quick extraction: The answer is "text". -> "text"
        if quick_extract:
            eval_logger.info("Quickly extracting answer...")
            result = re.search(r'The answer is "(.*)"\.', response)
            if result:
                return result.group(1)

        # general extraction via the GPT extractor
        try:
            full_prompt = self.create_test_prompt(DEMO_PROMPT, query, response)
            return self.get_chat_response(full_prompt, temperature=0, max_tokens=256, n=1)
        except Exception as e:
            eval_logger.error(e)
            eval_logger.error("Error in extracting answer for problem")

        return ""

    def get_most_similar(self, prediction, choices):
        """Return the choice with the smallest Levenshtein distance to ``prediction``.

        Ties resolve to the earliest choice. Raises ValueError if ``choices`` is empty.
        """
        return min(choices, key=lambda choice: distance(prediction, choice))

    def normalize_extracted_answer(self, extraction, choices, question_type, answer_type, precision):
        """Normalize an extracted answer to the problem's answer type.

        Multi-choice answers are mapped to the choice text (via option letter or
        nearest choice). Integers and floats are canonicalized as strings; floats
        are rounded to ``precision`` digits. Unparseable numbers become ``None``.
        """
        if question_type == "multi_choice":
            # make sure the extraction is a string
            if isinstance(extraction, str):
                extraction = extraction.strip()
            else:
                try:
                    extraction = str(extraction)
                except Exception:
                    extraction = ""

            # extract "A" from "(A) text"
            letter = re.findall(r"\(([a-zA-Z])\)", extraction)
            if letter:
                extraction = letter[0].upper()

            options = [chr(ord("A") + i) for i in range(len(choices))]
            if extraction in options:
                # convert option letter to text, e.g. "A" -> "text"
                extraction = choices[options.index(extraction)]
            else:
                # select the most similar option
                extraction = self.get_most_similar(extraction, choices)
            assert extraction in choices

        elif answer_type == "integer":
            try:
                extraction = str(int(float(extraction)))
            except (ValueError, TypeError, OverflowError):
                extraction = None

        elif answer_type == "float":
            try:
                # ``precision`` may arrive as a float (e.g. 2.0); round() rejects a
                # float ndigits with TypeError, which used to turn every float answer
                # into None. ``None`` keeps round()'s integer-rounding behaviour.
                ndigits = None if precision is None else int(precision)
                extraction = str(round(float(extraction), ndigits))
            except (ValueError, TypeError, OverflowError):
                extraction = None

        elif answer_type == "list":
            try:
                extraction = str(extraction)
            except Exception:
                extraction = None

        return extraction

    def safe_equal(self, prediction, answer):
        """Compare prediction and answer by their stripped string forms."""
        try:
            return str(prediction).strip() == str(answer).strip()
        except Exception as e:
            eval_logger.info(e)
            return False

    def get_acc_with_contion(self, res_pd, key, value):
        """Accuracy over the rows of ``res_pd`` matching ``key == value``.

        For ``key == "skills"`` a row matches when ``value`` is contained in the
        row's value. Assumes ``res_pd`` is a pandas DataFrame with a boolean
        ``true_false`` column — confirm against callers.

        Returns:
            (num_correct, num_total, accuracy formatted as a percentage "xx.xx").
        """
        if key == "skills":
            total_pd = res_pd[res_pd[key].apply(lambda x: value in x)]
        else:
            total_pd = res_pd[res_pd[key] == value]

        # Element-wise comparison; ``is True`` would not work on a Series.
        correct_pd = total_pd[total_pd["true_false"] == True]  # noqa: E712
        acc = "{:.2f}".format(len(correct_pd) / len(total_pd) * 100) if len(total_pd) > 0 else "0.00"
        return len(correct_pd), len(total_pd), acc

    # Correctly spelled alias; the original name is kept for existing callers.
    get_acc_with_condition = get_acc_with_contion

    @staticmethod
    def _float_hint_parts(precision):
        """Return (description, examples) phrasing for a float answer's precision.

        precision 1 -> ("one decimal place", "1.2, 1.3, 1.4")
        precision 2 -> ("two decimal places", "1.23, 1.34, 1.45")
        Higher precisions extend the same pattern.

        Raises:
            ValueError: if ``precision`` is not a positive whole number.
        """
        try:
            places = int(precision)
        except (TypeError, ValueError, OverflowError):
            raise ValueError(f"Unsupported float precision: {precision!r}") from None
        if places != precision or places < 1:
            raise ValueError(f"Unsupported float precision: {precision!r}")

        word = _NUMBER_WORDS.get(places, str(places))
        description = f"{word} decimal place" if places == 1 else f"{word} decimal places"
        examples = ", ".join("1." + "".join(str((start + i) % 10) for i in range(places)) for start in (2, 3, 4))
        return description, examples

    def _build_demo_prompt(self, examples, shot_type, shot_num, use_caption, use_ocr):
        """Render up to ``shot_num`` solved examples, separated by blank lines."""
        if shot_num == 0:
            return ""
        examples = examples or []

        demos = []
        for example in examples[: min(shot_num, len(examples))]:
            lines = [f"Question: {example['question']}"]

            if "choices" in example:
                texts = ["Choices:"]
                texts.extend(f"({chr(ord('A') + i)}) {choice}" for i, choice in enumerate(example["choices"]))
                lines.append("\n".join(texts))

            if use_caption:
                caption = example.get("caption", "")
                if caption != "":
                    lines.append(f"Image description: {caption}")

            if use_ocr:
                ocr = example.get("ocr", "")
                if ocr != "":
                    lines.append(f"Image detected text: {ocr}")

            if shot_type == "solution":
                lines.append(f"Solution: {example['solution'].strip()}")
            elif shot_type in ("step-by-step", "think-step-by-step", "direct"):
                lines.append(f"{example['solution'].strip()}")
            elif shot_type == "code":
                lines.append(f"Python code: {example['code'].strip()}")

            demos.append("\n".join(lines))

        return "\n\n".join(demos)

    def _build_hint(self, shot_type, question_type, answer_type, precision):
        """Return the instruction line that tells the model how to format its answer.

        Raises:
            AssertionError: on an unknown ``shot_type`` or an inconsistent
                question/answer type (same checks as before).
            ValueError: on an unsupported answer type for ``format-prompt`` or an
                unsupported float precision.
        """
        if shot_type in ("solution", "step-by-step"):
            if question_type == "multi_choice":
                assert answer_type == "text"
                return "Hint: Please answer the question and provide the correct option letter, e.g., A, B, C, D, at the end."
            assert answer_type in ["integer", "float", "list"]
            if answer_type == "integer":
                return "Hint: Please answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end."
            if answer_type == "float":
                description, examples = self._float_hint_parts(precision)
                return f"Hint: Please answer the question requiring a floating-point number with {description} and provide the final value, e.g., {examples}, at the end."
            return "Hint: Please answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end."

        if shot_type == "format-prompt":
            if question_type == "multi_choice":
                assert answer_type == "text"
                return "Answer with the option's letter from the given choices directly."
            if answer_type == "integer":
                return "Answer the question using a single integer number."
            if answer_type == "float":
                description, _ = self._float_hint_parts(precision)
                return f"Answer the question using a single floating-point number with {description}."
            if answer_type == "list":
                return "Answer the question using a Python list."
            raise ValueError(f"Unsupported answer_type for format-prompt: {answer_type!r}")

        if shot_type == "reason-first":
            if question_type == "multi_choice":
                assert answer_type == "text"
                return "First perform reasoning, then finally select the question from the choices in the following format: Answer: xxx."
            assert answer_type in ["integer", "float", "list"]
            if answer_type == "integer":
                return "First perform reasoning, then finally answer the question requiring an integer answer and provide the final value, e.g., 1, 2, 3, at the end in the following format: Answer: xxx."
            if answer_type == "float":
                description, examples = self._float_hint_parts(precision)
                return f"First perform reasoning, then finally answer the question requiring a floating-point number with {description} and provide the final value, e.g., {examples}, at the end in the following format: Answer: xxx."
            return "First perform reasoning, then finally answer the question requiring a Python list as an answer and provide the final list, e.g., [1, 2, 3], [1.2, 1.3, 1.4], at the end in the following format: Answer: xxx."

        if shot_type == "direct":
            return ""

        assert shot_type == "code"
        return "Hint: Please generate a python code to solve the problem"

    def create_one_query(self, problem, shot_type, examples=None, shot_num=0, use_caption=False, use_ocr=False):
        """Build the full text query for ``problem``.

        Args:
            problem: Dict with ``question``, ``unit``, ``choices``, ``caption``,
                ``ocr``, ``precision``, ``question_type`` and ``answer_type``.
            shot_type: One of "solution", "format-prompt", "step-by-step",
                "reason-first", "direct" or "code"; controls hint and answer prefix.
            examples: Solved examples used as few-shot demos (``None`` = none).
            shot_num: Number of demos to include (0 = zero-shot).
            use_caption, use_ocr: Append the image caption / OCR text when non-empty.

        Returns:
            The demo prompt (if any) followed by the test query, stripped.
        """
        ### [1] Demo prompt
        demo_prompt = self._build_demo_prompt(examples, shot_type, shot_num, use_caption, use_ocr)

        ### [2] Test query
        question = problem["question"]
        unit = problem["unit"]
        choices = problem["choices"]
        caption = problem["caption"]
        ocr = problem["ocr"]
        precision = problem["precision"]
        question_type = problem["question_type"]
        answer_type = problem["answer_type"]

        hint_text = self._build_hint(shot_type, question_type, answer_type, precision)

        # question
        question_text = f"{question}" if shot_type == "format-prompt" else f"Question: {question}"
        if unit:
            question_text += f" (Unit: {unit})"

        # choices
        if not choices:
            choices_text = ""
        elif shot_type == "format-prompt":
            # A. 1.2 / B. 1.3 / ...
            choices_text = "\n".join(f"{chr(ord('A') + i)}. {choice}" for i, choice in enumerate(choices))
        else:
            # Choices: (A) 1.2 (B) 1.3 (C) 1.4 (D) 1.5
            texts = ["Choices:"]
            texts.extend(f"({chr(ord('A') + i)}) {choice}" for i, choice in enumerate(choices))
            choices_text = "\n".join(texts)

        caption_text = f"Image description: {caption}" if use_caption and caption != "" else ""
        ocr_text = f"Image detected text: {ocr}" if use_ocr and ocr != "" else ""

        # answer prefix; unknown shot types were already rejected by _build_hint
        if shot_type == "solution":
            prompt = "Solution: "
        elif shot_type == "code":
            prompt = "Python code: "
        else:
            prompt = ""

        if shot_type == "reason-first":
            elements = [hint_text, question_text, choices_text, caption_text, ocr_text, prompt]
        else:
            elements = [question_text, choices_text, caption_text, ocr_text, hint_text, prompt]
        test_query = "\n".join(e for e in elements if e != "")

        ### [3] Final query
        query = demo_prompt + "\n\n" + test_query
        return query.strip()