Source code for flagevalmm.evaluator.mmmu_dataset_evaluator

import re
from collections import defaultdict
from flagevalmm.registry import EVALUATORS
from flagevalmm.evaluator import BaseEvaluator


def check_is_number(string):
    """
    Report whether ``string`` can be read as a number.

    Commas are removed first so that thousands-separated values such as
    ``"1,234.5"`` are accepted; whatever is left must be parseable by
    ``float``.

    Args:
        string: The text to test.

    Returns:
        ``True`` when the comma-free text converts to ``float``,
        ``False`` when the conversion raises ``ValueError``.
    """
    without_commas = string.replace(",", "")
    try:
        float(without_commas)
    except ValueError:
        return False
    else:
        return True
def normalize_str(string):
    """
    Turn a raw answer into its comparable form(s).

    The input is stripped of surrounding whitespace. If it then parses as a
    number (commas ignored) the result is that number rounded to two
    decimals; otherwise it is the lower-cased text.

    Args:
        string: The raw answer text.

    Returns:
        A list holding either one ``float`` (numeric input), one lower-cased
        string, or - for a single-character string - two space-padded
        variants (``" x"`` and ``"x "``) so that a lone letter does not match
        trivially inside longer words.
    """
    trimmed = string.strip()

    # Same numeric test as check_is_number: drop commas, then try float().
    try:
        number = float(trimmed.replace(",", ""))
    except ValueError:
        number = None

    if number is not None:
        # Keep two decimals only.
        return [round(number, 2)]

    lowered = trimmed.lower()
    if len(lowered) == 1:
        # Pad single characters to avoid trivial substring matches.
        return [" " + lowered, lowered + " "]
    return [lowered]
def extract_numbers(string):
    """
    Extract all forms of numbers from a string with regular expressions.

    Three patterns are applied independently to the whole input and their
    matches are concatenated in this order: comma-grouped numbers,
    scientific notation, then plain integers/decimals. Because each pattern
    scans the full text on its own, one literal can contribute to more than
    one group (for example ``"1,000"`` yields both ``"1,000"`` and ``"000"``).

    Args:
        string: The text to scan.

    Returns:
        A list of the matched substrings (still as ``str``), grouped by
        pattern in the order described above.
    """
    patterns = (
        # Numbers with thousands separators, e.g. 1,234,567
        r"-?\b\d{1,3}(?:,\d{3})+\b",
        # Scientific notation, e.g. 1.5e-3
        r"-?\d+(?:\.\d+)?[eE][+-]?\d+",
        # Plain numbers without commas, e.g. 42, 3.14, .5
        r"-?(?:\d+\.\d+|\.\d+|\d+\b)(?![eE][+-]?\d+)(?![,\d])",
    )

    found = []
    for pattern in patterns:
        found.extend(re.findall(pattern, string))
    return found
def parse_open_response(response):
    """
    Parse the prediction from a generated free-form response.

    The response is reduced to its "key" sub-responses (the text following
    phrases such as "answer " or "therefore "), every number found in those
    sub-responses is added, and each candidate is normalized with
    ``normalize_str``.

    Args:
        response: The raw model output.

    Returns:
        A de-duplicated list of candidate predictions: lower-cased strings
        and/or floats. The list is built from a ``set``, so its order is not
        meaningful.
    """

    def get_key_subresponses(response):
        """Return the answer-bearing tail of each segment of ``response``.

        If no segment contains an indicator, the whole cleaned response is
        returned as the single candidate.
        """
        cleaned = response.strip().strip(".").lower()
        # NOTE(review): the text is lower-cased before this split, so the
        # ``(?=[A-Z])`` lookahead can never succeed and only newlines split
        # the response. Kept as-is to preserve scoring; confirm whether
        # sentence-level splitting was intended.
        segments = re.split(r"\.\s(?=[A-Z])|\n", cleaned)

        indicators = [
            "could be ",
            "so ",
            "is ",
            "thus ",
            "therefore ",
            "final ",
            "answer ",
            "result ",
        ]
        trivial_tails = (":", ",", ".", "!", "?", ";", "'")

        found = []
        last_position = len(segments) - 1
        for position, segment in enumerate(segments):
            # Only the final segment may also be read as an equation (the
            # entire response can be a single sentence with an equation).
            if position == last_position:
                indicators.append("=")

            # Shortest tail following any indicator: the part of the segment
            # most likely to contain just the answer.
            best_tail = None
            for indicator in indicators:
                if indicator not in segment:
                    continue
                tail = segment.split(indicator)[-1].strip()
                if not best_tail or len(tail) < len(best_tail):
                    best_tail = tail

            # Skip empty tails and tails that are only punctuation.
            if best_tail and best_tail.strip() not in trivial_tails:
                found.append(best_tail)

        if not found:
            # No indicator matched anywhere: fall back to the whole response.
            return [cleaned]
        return found

    key_responses = get_key_subresponses(response)

    # Keep the textual candidates and add every number they contain.
    candidates = list(key_responses)
    for key_response in key_responses:
        candidates.extend(extract_numbers(key_response))

    normalized = [
        form for candidate in candidates for form in normalize_str(candidate)
    ]

    # Remove duplicates.
    return list(set(normalized))
def eval_open(gold_i, pred_i):
    """
    Evaluate one open-ended question instance.

    Args:
        gold_i: The reference answer, either a single string or a list of
            acceptable strings. Each is normalized with ``normalize_str``.
        pred_i: Candidate predictions, already normalized (as produced by
            ``parse_open_response``): strings and/or floats.

    Returns:
        ``True`` if any string prediction contains one of the normalized
        string answers, or any non-string prediction equals one of the
        normalized answers; ``False`` otherwise.
    """
    if isinstance(gold_i, list):
        accepted = [form for answer in gold_i for form in normalize_str(answer)]
    else:
        accepted = normalize_str(gold_i)

    # Only textual answers take part in substring matching.
    text_answers = [answer for answer in accepted if isinstance(answer, str)]

    for candidate in pred_i:
        if isinstance(candidate, str):
            # A textual prediction is right when it contains a textual answer.
            if any(answer in candidate for answer in text_answers):
                return True
        elif candidate in accepted:
            # A numeric prediction must equal a normalized answer exactly.
            return True
    return False
@EVALUATORS.register_module()
class MmmuEvaluator(BaseEvaluator):
    """
    The evaluation method is adapted from the official MMMU benchmark
    evaluation code (https://github.com/MMMU-Benchmark/MMMU/tree/main/mmmu)
    with modifications to improve robustness and adapt to the flagevalmm
    framework.
    """

    def cal_accuracy(self, annotation, answers):
        """
        Score every answer and aggregate accuracy overall, per subject and,
        when available, per difficulty level.

        Each entry of ``answers`` is updated in place with the keys
        ``"correct"``, ``"label"`` and ``"question_type"``.

        Args:
            annotation: Mapping from question id (as ``str``) to the ground
                truth record. The keys read here are ``"question_type"``,
                ``"answer"``, ``"subject"`` and, optionally,
                ``"topic_difficulty"``.
            answers: Sequence of prediction dicts with at least
                ``"question_id"`` and ``"answer"``.

        Returns:
            A dict with ``"accuracy"`` (percentage rounded to two decimals),
            ``"subject_score"`` (percentage per subject) and, if any record
            carries ``"topic_difficulty"``, ``"difficulty_score"``. With no
            answers, ``"accuracy"`` is ``0.0`` and ``"subject_score"`` is
            empty.

        Raises:
            KeyError: If an answer's question id is missing from
                ``annotation``.
        """
        right = 0
        subject_score = defaultdict(list)
        difficulty_score = defaultdict(list)

        for answer in answers:
            question_id = str(answer["question_id"])
            gt = annotation[question_id]

            if gt["question_type"] in ("multiple-choice", "vision"):
                # evaluate_multiple_choice is not defined in this block;
                # presumably inherited from BaseEvaluator.
                is_correct = self.evaluate_multiple_choice(gt, answer)
            else:
                pred = parse_open_response(answer["answer"])
                is_correct = eval_open(gt["answer"], pred)

            right += is_correct
            subject_score[gt["subject"]].append(is_correct)
            if "topic_difficulty" in gt:
                difficulty_score[gt["topic_difficulty"]].append(is_correct)

            # Annotate the prediction so callers can inspect per-item results.
            answer["correct"] = is_correct
            answer["label"] = gt["answer"]
            answer["question_type"] = gt["question_type"]

        # Guard against an empty prediction list, which would otherwise
        # raise ZeroDivisionError.
        total = len(answers)
        accuracy = round(right / total * 100, 2) if total else 0.0

        result = {"accuracy": accuracy}
        result["subject_score"] = {
            k: round(sum(v) / len(v) * 100, 2) for k, v in subject_score.items()
        }
        if difficulty_score:
            result["difficulty_score"] = {
                k: round(sum(v) / len(v) * 100, 2)
                for k, v in difficulty_score.items()
            }
        return result