| import regex |
| from copy import deepcopy |
| from eval.eval_utils import math_equal |
| from eval.ocwcourses_eval_utils import ( |
| normalize_numeric, |
| numeric_equality, |
| normalize_symbolic_equation, |
| SymbolicMathMixin, |
| ) |
|
|
|
|
def is_correct(item, pred_key="prediction", prec=1e-3):
    """Return whether ``item[pred_key]`` matches ``item["answer"]``.

    Supported shapes:

    * two lists -- every prediction must match at least one answer and every
      answer must be matched by at least one prediction (order-insensitive,
      checked pairwise by recursing on the elements);
    * two strings -- if both contain ``\\cup`` they are split on it and
      compared as lists; otherwise they match when their float values (with
      ``,`` removed) differ by less than ``prec``, when they are identical
      non-empty strings, or when ``math_equal`` accepts them.

    Args:
        item: Dict holding the prediction under ``pred_key`` and the
            reference under ``"answer"``. It is not modified.
        pred_key: Key of the prediction inside ``item``.
        prec: Absolute tolerance for the numeric comparison.

    Returns:
        ``True`` if the prediction is judged correct, ``False`` otherwise.

    Raises:
        NotImplementedError: If prediction and answer are not both lists or
            both strings.
    """
    pred = item[pred_key]
    ans = item["answer"]

    if isinstance(pred, list) and isinstance(ans, list):
        pred_matched = set()
        ans_matched = set()
        for i, p in enumerate(pred):
            for j, a in enumerate(ans):
                if i in pred_matched and j in ans_matched:
                    # Both sides already matched: re-checking cannot change the result.
                    continue
                # Shallow copy is enough: only two top-level keys are replaced,
                # so the caller's item and its nested values stay untouched.
                sub_item = {**item, pred_key: p, "answer": a}
                if is_correct(sub_item, pred_key=pred_key, prec=prec):
                    pred_matched.add(i)
                    ans_matched.add(j)
        return len(pred_matched) == len(pred) and len(ans_matched) == len(ans)

    if isinstance(pred, str) and isinstance(ans, str):
        if "\\cup" in pred and "\\cup" in ans:
            # Unions of sets/intervals are compared component-wise, order-insensitively.
            sub_item = {
                **item,
                pred_key: pred.split("\\cup"),
                "answer": ans.split("\\cup"),
            }
            return is_correct(sub_item, pred_key=pred_key, prec=prec)

        try:
            # Commas are stripped so thousands separators ("1,000") parse.
            if abs(float(pred.replace(",", "")) - float(ans.replace(",", ""))) < prec:
                return True
        except ValueError:
            # Not plain numbers; fall through to the textual/symbolic checks.
            pass
        if ans and pred == ans:
            return True
        return bool(math_equal(pred, ans))

    raise NotImplementedError(
        f"cannot compare {type(pred).__name__} prediction with "
        f"{type(ans).__name__} answer: {item!r}"
    )
|
|
|
|
def _unique_in_order(values):
    """Return a list of ``values`` without repeats, keeping first occurrences.

    Membership is tested against a list, so elements need not be hashable.
    """
    unique = []
    for value in values:
        if value not in unique:
            unique.append(value)
    return unique


def eval_math(item, pred_key="prediction", prec=1e-3):
    """Score a MATH-style prediction against the reference answer(s).

    When ``pred_key`` is ``"program_output"``, a bare string prediction is
    wrapped in a one-element list. If prediction and answer are both lists,
    each is de-duplicated (first occurrence wins). Only the last
    ``len(answer)`` unique predictions are kept; an empty answer list keeps
    every prediction, because the slice start is then ``-0``.

    The (possibly rewritten) prediction and answer are written back into
    ``item`` before delegating to ``is_correct``, whose result is returned.
    """
    prediction = item[pred_key]
    if pred_key == "program_output" and isinstance(prediction, str):
        prediction = [prediction]
    reference = item["answer"]

    if isinstance(prediction, list) and isinstance(reference, list):
        reference = _unique_in_order(reference)
        prediction = _unique_in_order(prediction)[-len(reference) :]

    item.update({pred_key: prediction, "answer": reference})
    return is_correct(item, pred_key=pred_key, prec=prec)
|
|
|
|
def eval_last_single_answer(item, pred_key="prediction", prec=1e-3):
    """Check a single string prediction against a single string answer.

    Both ``item[pred_key]`` and ``item["answer"]`` must be ``str``; an
    ``AssertionError`` naming the offending field is raised otherwise.
    The comparison itself is delegated to ``is_correct``.
    """
    for field in (pred_key, "answer"):
        value = item[field]
        assert isinstance(value, str), f"{field} = `{value}` is not a str"
    return is_correct(item, pred_key=pred_key, prec=prec)
|
|
|
|
def _split_cloze_predictions(preds):
    """Split prediction strings into stripped, non-empty, de-duplicated pieces.

    A piece ends at any ``;`` or at a ``,`` that is not nested inside
    ``()``/``[]``/``{}``. The bracket depth restarts at zero after every
    split. Pieces keep the order of their first occurrence across ``preds``.
    """
    pieces = []

    def _emit(piece):
        piece = piece.strip()
        # Empty pieces (from "a;", "a;;b", "a,,b") would otherwise shift the
        # last-N window taken by the caller and cause false negatives.
        if piece and piece not in pieces:
            pieces.append(piece)

    for text in preds:
        depth = 0
        start = 0
        for i, ch in enumerate(text):
            if ch == ";" or (ch == "," and depth == 0):
                _emit(text[start:i])
                start = i + 1
                depth = 0
            elif ch in "([{":
                depth += 1
            elif ch in ")]}":
                depth -= 1
        _emit(text[start:])
    return pieces


def eval_agieval_gaokao_math_cloze(item, pred_key="prediction", prec=1e-3):
    """Score an AGIEval Gaokao math cloze prediction.

    Each prediction string is split into answer pieces (see
    ``_split_cloze_predictions``). The last ``len(answer)`` unique pieces must
    match the reference answers position by position, according to
    ``is_correct``.

    When ``pred_key`` is ``"program_output"`` and holds a bare string, it is
    wrapped in a list in place. The caller's ``item`` is otherwise left
    unchanged: per-pair comparisons use shallow copies, so ``item["answer"]``
    keeps its list value.

    Raises:
        AssertionError: If the prediction or the answer is not a list.
    """
    if pred_key == "program_output" and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    for key in [pred_key, "answer"]:
        assert isinstance(item[key], list), f"{key} = `{item[key]}` is not a list"

    ans = item["answer"]
    # With an empty answer list the slice start is -0, keeping all pieces.
    pred = _split_cloze_predictions(item[pred_key])[-len(ans) :]
    if len(pred) != len(ans):
        return False
    return all(
        is_correct({**item, pred_key: p, "answer": a}, pred_key=pred_key, prec=prec)
        for p, a in zip(pred, ans)
    )
|
|
|
|
def eval_agieval_gaokao_mathqa(item, pred_key="prediction", prec=1e-3):
    """Score an AGIEval Gaokao multiple-choice (A-D) prediction.

    The prediction pieces are joined with spaces. Among the letters ``A``,
    ``B``, ``C`` and ``D`` that appear, the one whose *first* occurrence is
    furthest right is taken as the chosen option (``None`` if none appear).
    The result is compared with ``item["answer"]``. A bare string
    ``program_output`` is wrapped in a list in place. ``prec`` is unused.
    """
    if pred_key == "program_output" and isinstance(item[pred_key], str):
        item[pred_key] = [item[pred_key]]
    text = " ".join(item[pred_key])

    # Positions of distinct present letters never tie; absent letters all sit at -1.
    position, letter = max((text.find(option), option) for option in "ABCD")
    chosen = letter if position >= 0 else None
    return chosen == item["answer"]
|
|
|
|
def eval_math_sat(item, pred_key="prediction", prec=1e-3):
    """Case-insensitive exact match between two string fields of ``item``.

    Raises ``AssertionError`` if either the prediction or the answer is not
    a ``str``. ``prec`` is accepted for interface compatibility only.
    """
    for field in (pred_key, "answer"):
        value = item[field]
        assert isinstance(value, str), f"{field} = `{value}` is not a str"
    predicted, expected = (item[field].lower() for field in (pred_key, "answer"))
    return predicted == expected
|
|
|
|
def eval_mmlu_stem(item, pred_key="prediction", prec=1e-3):
    """Score an MMLU-STEM prediction; identical to ``eval_math_sat``."""
    # MMLU-STEM answers are option letters, so SAT's case-insensitive match applies.
    return eval_math_sat(item, pred_key, prec)
|
|
|
|
def eval_ocwcourses(item, pred_key="prediction", prec=1e-3):
    """Score an OCW Courses prediction, returning ``1`` if correct else ``0``.

    The comparison strategy is selected from the reference answer:

    * answers accepted by ``float`` use ``normalize_numeric`` and
      ``numeric_equality``;
    * otherwise, answers containing ``=`` use ``normalize_symbolic_equation``
      and plain equality of the normalised forms;
    * everything else uses ``SymbolicMathMixin``'s TeX normalisation and
      equivalence check.

    An empty prediction is replaced by ``"[invalidanswer]"`` and scores 0, as
    does any prediction that normalises to ``"[invalidanswer]"``. Both fields
    must be ``str`` (checked with ``assert``). ``prec`` is unused.
    """
    INVALID_ANSWER = "[invalidanswer]"
    for field in (pred_key, "answer"):
        value = item[field]
        assert isinstance(value, str), f"{field} = `{value}` is not a str"
    pred, ans = item[pred_key], item["answer"]

    try:
        float(ans)
    except ValueError:
        if "=" in ans:
            normalize_fn = normalize_symbolic_equation

            def is_equiv(x, y):
                return x == y

        else:
            normalize_fn = SymbolicMathMixin().normalize_tex
            is_equiv = SymbolicMathMixin().is_tex_equiv
    else:
        normalize_fn, is_equiv = normalize_numeric, numeric_equality

    # The reference is normalised before the prediction, as in the original order.
    correct_answer = normalize_fn(ans)
    unnormalized_answer = pred or INVALID_ANSWER
    model_answer = normalize_fn(unnormalized_answer)

    if unnormalized_answer == INVALID_ANSWER or model_answer == INVALID_ANSWER:
        return 0
    return 1 if is_equiv(model_answer, correct_answer) else 0
|
|
|
|
def eval_minif2f_isabelle(item, pred_key="prediction", prec=1e-3):
    """Placeholder scorer for miniF2F (Isabelle): accepts every prediction.

    None of the arguments are inspected; the result is always ``True``.
    """
    # NOTE(review): no check is performed here. Presumably proofs are verified
    # by a separate checker; confirm before treating this score as meaningful.
    return True
|
|