Source code for flagevalmm.evaluator.mmmu_dataset_evaluator
import re
from collections import defaultdict
from flagevalmm.registry import EVALUATORS
from flagevalmm.evaluator import BaseEvaluator
[docs]
def check_is_number(string):
    """
    Tell whether the given string parses as a number.

    Commas are removed first (so "1,234.5" counts as a number), then the
    remaining text is handed to ``float()``.

    Args:
        string: Text to test.

    Returns:
        True when ``float()`` accepts the comma-free text, False when it
        raises ``ValueError``.
    """
    without_commas = string.replace(",", "")
    try:
        float(without_commas)
    except ValueError:
        return False
    else:
        return True
[docs]
def normalize_str(string):
    """
    Normalize a string: numbers become rounded floats, other text is lower-cased.

    Args:
        string: Text to normalize; surrounding whitespace is stripped first.

    Returns:
        A list of normalized candidates:
        - ``[float]`` rounded to 2 decimals when the text is numeric
          (commas are ignored);
        - ``[" c", "c "]`` for a single-character string, so that the
          character only matches when it stands next to a space;
        - ``[lowercased_text]`` otherwise.
    """
    text = string.strip()

    if not check_is_number(text):
        # Not numeric: treat it as plain text.
        lowered = text.lower()
        if len(lowered) != 1:
            return [lowered]
        # Pad single characters to avoid trivial substring matches.
        return [" " + lowered, lowered + " "]

    # Numeric: drop thousands separators and keep two decimals.
    return [round(float(text.replace(",", "")), 2)]
[docs]
def parse_open_response(response):
    """
    Parse the prediction from the generated response.

    The response is split into sub-responses (sentences and lines). From each
    sub-response the text following an answer indicator such as "so " or
    "answer " is kept as a "key response". Those key responses, plus whatever
    ``extract_numbers`` finds in them, are normalized with ``normalize_str``.

    Args:
        response: Raw text generated by the model.

    Returns:
        A de-duplicated list of predicted strings or numbers, in first-seen
        order.
    """
    indicators_of_keys = (
        "could be ",
        "so ",
        "is ",
        "thus ",
        "therefore ",
        "final ",
        "answer ",
        "result ",
    )
    # Tails that consist only of punctuation carry no answer.
    trivial_responses = frozenset({":", ",", ".", "!", "?", ";", "'"})

    def get_key_subresponses(response):
        """Return the answer-bearing tails of each sub-response (lower-cased)."""
        cleaned = response.strip().strip(".")
        # Split on sentence boundaries BEFORE lower-casing: the lookahead needs
        # the capital letter that starts the next sentence, so lower-casing
        # first would make that alternative impossible to match.
        sub_responses = [
            part.lower() for part in re.split(r"\.\s(?=[A-Z])|\n", cleaned)
        ]
        last_index = len(sub_responses) - 1

        key_responses = []
        for index, resp in enumerate(sub_responses):
            indicators = indicators_of_keys
            if index == last_index:
                # The last sub-response may be a bare equation (the entire
                # response can be a single sentence with an equation).
                indicators = indicators_of_keys + ("=",)

            # The shortest tail that may contain the answer.
            shortest_key_response = None
            for indicator in indicators:
                if indicator not in resp:
                    continue
                candidate = resp.split(indicator)[-1].strip()
                if not shortest_key_response or len(candidate) < len(
                    shortest_key_response
                ):
                    shortest_key_response = candidate

            if shortest_key_response and shortest_key_response not in trivial_responses:
                key_responses.append(shortest_key_response)

        if not key_responses:
            # No indicator found anywhere: fall back to the whole response.
            return [cleaned.lower()]
        return key_responses

    key_responses = get_key_subresponses(response)

    # Keep the original string responses and add the numbers found in them.
    # NOTE(review): extract_numbers is defined outside this block; it is
    # assumed to return a list of strings accepted by normalize_str.
    candidates = list(key_responses)
    for resp in key_responses:
        candidates.extend(extract_numbers(resp))

    normalized = []
    for candidate in candidates:
        normalized.extend(normalize_str(candidate))

    # Remove duplicates while keeping a deterministic (first-seen) order.
    return list(dict.fromkeys(normalized))
[docs]
def eval_open(gold_i, pred_i):
    """
    Evaluate one open-ended question instance.

    Args:
        gold_i: Reference answer, either a single string or a list of
            strings; each is passed through ``normalize_str``.
        pred_i: Iterable of predictions that were already normalized while
            parsing the response (strings and/or numbers).

    Returns:
        True if any string prediction contains one of the normalized string
        answers, or any non-string prediction equals one of the normalized
        answers; False otherwise.
    """
    # Normalizing the gold answers turns numeric ones into floats, which
    # avoids trivial substring matches on digits.
    gold_items = gold_i if isinstance(gold_i, list) else [gold_i]
    norm_answers = []
    for gold in gold_items:
        norm_answers += normalize_str(gold)

    def is_hit(pred):
        if not isinstance(pred, str):
            # A number must equal one of the normalized answers.
            return pred in norm_answers
        # A string only needs to contain one of the string answers.
        return any(
            isinstance(norm_ans, str) and norm_ans in pred
            for norm_ans in norm_answers
        )

    return any(is_hit(pred) for pred in pred_i)
[docs]
@EVALUATORS.register_module()
class MmmuEvaluator(BaseEvaluator):
    """
    The evaluation method is adapted from the official MMMU benchmark evaluation code
    (https://github.com/MMMU-Benchmark/MMMU/tree/main/mmmu) with modifications to improve
    robustness and adapt to the flagevalmm framework.
    """

    def cal_accuracy(self, annotation, answers):
        """
        Score every answer and aggregate accuracy overall, per subject and per difficulty.

        Args:
            annotation: Mapping from question id (as ``str``) to a ground-truth
                dict with the keys "question_type", "answer", "subject" and,
                optionally, "topic_difficulty".
            answers: Sequence of dicts with the keys "question_id" and
                "answer". Each dict is updated in place with "correct",
                "label" and "question_type".

        Returns:
            A dict with "accuracy" and "subject_score" (percentages rounded to
            2 decimals), plus "difficulty_score" when at least one annotation
            carries "topic_difficulty". With no answers the accuracy is 0.0.

        Raises:
            KeyError: If an answer's question id is missing from ``annotation``.
        """

        def percentage(flags):
            # An empty group scores 0.0 instead of raising ZeroDivisionError.
            if not flags:
                return 0.0
            return round(sum(flags) / len(flags) * 100, 2)

        all_scores = []
        subject_score = defaultdict(list)
        difficulty_score = defaultdict(list)

        for answer in answers:
            question_id = str(answer["question_id"])
            gt = annotation[question_id]

            if gt["question_type"] in ("multiple-choice", "vision"):
                # evaluate_multiple_choice is not defined in this class;
                # presumably it is inherited from BaseEvaluator.
                is_correct = self.evaluate_multiple_choice(gt, answer)
            else:
                pred = parse_open_response(answer["answer"])
                is_correct = eval_open(gt["answer"], pred)

            all_scores.append(is_correct)
            subject_score[gt["subject"]].append(is_correct)
            if "topic_difficulty" in gt:
                difficulty_score[gt["topic_difficulty"]].append(is_correct)

            # Attach the verdict and reference data to the answer record.
            answer["correct"] = is_correct
            answer["label"] = gt["answer"]
            answer["question_type"] = gt["question_type"]

        result = {"accuracy": percentage(all_scores)}
        result["subject_score"] = {
            subject: percentage(flags) for subject, flags in subject_score.items()
        }
        if difficulty_score:
            result["difficulty_score"] = {
                level: percentage(flags) for level, flags in difficulty_score.items()
            }
        return result