drop_eval.py (220 lines of code) (raw):

"""
DROP: A Reading Comprehension Benchmark Requiring Discrete Reasoning Over Paragraphs
Dheeru Dua, Yizhong Wang, Pradeep Dasigi, Gabriel Stanovsky, Sameer Singh, Matt Gardner
https://arxiv.org/abs/1903.00161
"""

import gzip
import json
import random
import re
import string
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import numpy as np
from scipy.optimize import linear_sum_assignment

from . import common
from .common import ANSWER_PATTERN, HTML_JINJA
from .types import Eval, EvalResult, SamplerBase, SingleEvalResult

# From here through _normalize_answer was originally copied from:
# https://worksheets.codalab.org/rest/bundles/0x6b567e1cf2e041ec80d7098f031c5c9e/contents/blob/
# Then cleaned up and modified a bit.
#
# The rest was originally copied from
# https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/eval/drop_eval.py

# Characters stripped from non-numeric tokens during normalization.
EXCLUDE = set(string.punctuation)

# Compiled once at import time instead of on every normalization call.
_ARTICLES_PATTERN = re.compile(r"\b(a|an|the)\b", re.UNICODE)


def _remove_articles(text: str) -> str:
    """Replace each standalone English article (a, an, the) with a space."""
    return _ARTICLES_PATTERN.sub(" ", text)


def _white_space_fix(text: str) -> str:
    """Collapse every run of whitespace to one space and trim both ends."""
    return " ".join(text.split())


def _remove_punc(text: str) -> str:
    """Strip punctuation from ``text`` unless the whole string parses as a number.

    Numeric strings are returned untouched so that signs and decimal points
    survive (e.g. ``"-3.5"`` stays ``"-3.5"``).
    """
    if _is_number(text):
        return text
    return "".join(ch for ch in text if ch not in EXCLUDE)


def _lower(text: str) -> str:
    """Return ``text`` lower-cased."""
    return text.lower()


def _tokenize(text: str) -> List[str]:
    """Split ``text`` on single spaces and hyphens (empty tokens are kept)."""
    return re.split(" |-", text)


def _normalize_answer(text: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace.

    Each space/hyphen-separated token is normalized independently; tokens that
    parse as numbers are canonicalized via ``str(float(token))``. Tokens that
    end up empty are dropped and the rest are joined with single spaces.
    """
    parts = [
        _white_space_fix(_remove_articles(_normalize_number(_remove_punc(_lower(token)))))
        for token in _tokenize(text)
    ]
    return " ".join(part for part in parts if part.strip()).strip()


def _is_number(text: str) -> bool:
    """Return True when ``float(text)`` succeeds."""
    try:
        float(text)
    except ValueError:
        return False
    return True


def _normalize_number(text: str) -> str:
    """Canonicalize numeric strings (``"3"`` -> ``"3.0"``); pass others through."""
    if _is_number(text):
        return str(float(text))
    return text


def _answer_to_bags(
    answer: Union[str, List[str], Tuple[str, ...]]
) -> Tuple[List[str], List[Set[str]]]:
    """Normalize an answer into spans and per-span token bags.

    Args:
        answer: A single answer string, or a list/tuple of answer spans.

    Returns:
        ``(normalized_spans, token_bags)`` where ``normalized_spans[i]`` is the
        normalized text of span ``i`` and ``token_bags[i]`` is the set of its
        whitespace-separated tokens.
    """
    raw_spans = answer if isinstance(answer, (list, tuple)) else [answer]
    normalized_spans: List[str] = [_normalize_answer(span) for span in raw_spans]
    token_bags: List[Set[str]] = [set(span.split()) for span in normalized_spans]
    return normalized_spans, token_bags


def _align_bags(predicted: List[Set[str]], gold: List[Set[str]]) -> np.ndarray:
    """
    Takes gold and predicted answer sets and first finds the optimal 1-1 alignment
    between them and gets maximum metric values over all the answers.

    Returns:
        A 1-D array of length ``max(len(gold), len(predicted))``. Entries for
        aligned gold bags hold their F1 (0-100); all other entries stay 0, so
        surplus or missing spans lower the mean taken by the caller.
    """
    scores = np.zeros([len(gold), len(predicted)])
    for gold_index, gold_item in enumerate(gold):
        for pred_index, pred_item in enumerate(predicted):
            # A pair only earns credit when the gold numbers (if any) are matched.
            if _match_numbers_if_present(gold_item, pred_item):
                scores[gold_index, pred_index] = _compute_f1(pred_item, gold_item)
    # linear_sum_assignment minimizes cost, so negate to maximize total F1.
    row_ind, col_ind = linear_sum_assignment(-scores)

    max_scores = np.zeros([max(len(gold), len(predicted))])
    for row, column in zip(row_ind, col_ind):
        max_scores[row] = max(max_scores[row], scores[row, column])
    return max_scores


def _compute_f1(predicted_bag: Set[str], gold_bag: Set[str]) -> float:
    """Return the token-level F1 between two bags, scaled to 0-100.

    An empty predicted bag counts as precision 1.0 and an empty gold bag as
    recall 1.0; when both precision and recall are 0 the result is 0.0.
    """
    overlap = len(gold_bag.intersection(predicted_bag))
    precision = overlap / float(len(predicted_bag)) if predicted_bag else 1.0
    recall = overlap / float(len(gold_bag)) if gold_bag else 1.0
    if precision == 0.0 and recall == 0.0:
        return 0.0
    return (2 * precision * recall) / (precision + recall) * 100


def _match_numbers_if_present(gold_bag: Set[str], predicted_bag: Set[str]) -> bool:
    """Return True if gold has no numeric tokens, or shares one with predicted."""
    gold_numbers = {word for word in gold_bag if _is_number(word)}
    predicted_numbers = {word for word in predicted_bag if _is_number(word)}
    return (not gold_numbers) or bool(gold_numbers.intersection(predicted_numbers))


def get_drop_metrics(
    predicted: Union[str, List[str], Tuple[str, ...]], gold: Union[str, List[str], Tuple[str, ...]]
) -> Tuple[float, float]:
    """
    Takes a predicted answer and a gold answer (that are both either a string or a list of
    strings), and returns exact match and the DROP F1 metric for the prediction.  If you are
    writing a script for evaluating objects in memory (say, the output of predictions during
    validation, or while training), this is the function you want to call, after using
    :func:`answer_json_to_strings` when reading the gold answer from the released data file.

    Returns:
        ``(exact_match, f1)``: ``exact_match`` is 1.0 when the normalized spans
        agree as sets and in count, else 0.0; ``f1`` is the mean aligned bag F1
        (0-100), rounded to two decimals.
    """
    predicted_spans, predicted_token_bags = _answer_to_bags(predicted)
    gold_spans, gold_token_bags = _answer_to_bags(gold)

    if set(predicted_spans) == set(gold_spans) and len(predicted_spans) == len(gold_spans):
        exact_match = 1.0
    else:
        exact_match = 0.0

    f1_per_bag = _align_bags(predicted_token_bags, gold_token_bags)
    # Round on the NumPy scalar first (unchanged numeric result), then hand
    # back a builtin float to match the declared return type.
    f1 = float(round(np.mean(f1_per_bag), 2))
    return exact_match, f1


def answer_json_to_strings(answer: Dict[str, Any]) -> Tuple[Tuple[str, ...], str]:
    """
    Takes an answer JSON blob from the DROP data release and converts it into strings used for
    evaluation.

    Returns:
        ``(answer_strings, answer_type)`` with ``answer_type`` one of
        ``"number"``, ``"span"``, ``"spans"`` or ``"date"``.

    Raises:
        ValueError: If the blob has no truthy ``number`` or ``spans`` entry and
            no ``date`` key.
    """
    if answer.get("number"):
        return (str(answer["number"]),), "number"
    if answer.get("spans"):
        spans = answer["spans"]
        return tuple(spans), "span" if len(spans) == 1 else "spans"
    if "date" in answer:
        date = answer["date"]
        return (f"{date['day']} {date['month']} {date['year']}".strip(),), "date"
    raise ValueError(
        f"Answer type not found, should be one of number, spans or date at: {json.dumps(answer)}"
    )


def answer_json_to_string(answer_json: Dict[str, Any]) -> str:
    """Return the JSON-encoded result of :func:`answer_json_to_strings`."""
    return json.dumps(answer_json_to_strings(answer_json))


def normalize(s: str) -> str:
    """Lower text and remove punctuation, articles and extra whitespace."""
    s = s.lower()
    s = "".join(char for char in s if char not in EXCLUDE)
    s = _ARTICLES_PATTERN.sub(" ", s)
    return " ".join(s.split())


def fuzzy_match(s1: str, s2: str) -> bool:
    """Return True if one normalized string contains the other.

    If either string normalizes to empty, they match only when both are empty.
    """
    s1 = normalize(s1)
    s2 = normalize(s2)

    if s1 == "" or s2 == "":
        return s1 == s2

    return s1 in s2 or s2 in s1


def drop_metric(sample: str, reference: list[str]) -> Tuple[float, float]:
    """Score ``sample`` against every non-blank reference and keep the best.

    Args:
        sample: The predicted answer text.
        reference: Candidate gold answers; blank entries are skipped.

    Returns:
        ``(best_exact_match, best_f1)``. Exact match and F1 are maximized
        independently. When there is no non-blank reference, ``(0.0, 0.0)`` is
        returned instead of raising on an empty ``max()``.
    """
    em_scores = []
    f1_scores = []
    for answer in reference:
        if answer.strip() != "":
            em, f1 = get_drop_metrics(sample, answer)
            em_scores.append(em)
            f1_scores.append(f1)
    return (max(em_scores, default=0.0), max(f1_scores, default=0.0))


def _load_jsonl_gz(url: str) -> List[Any]:
    """Open ``url`` via ``common.url_to_fileobj`` and parse one JSON value per gzip line."""
    with gzip.GzipFile(fileobj=common.url_to_fileobj(url, binary=True), mode="rb") as f:
        return [json.loads(line) for line in f]


class DropEval(Eval):
    """Few-shot DROP evaluation.

    Each test example is prompted together with randomly drawn training
    examples; the sampler's answer is scored by fuzzy match (the main score)
    and by DROP exact-match / F1 (reported as metrics).
    """

    def __init__(self, num_examples: int | None = None, train_samples_per_prompt: int = 3):
        """Load the train and dev splits.

        Args:
            num_examples: If truthy, evaluate on a seeded random subset of this
                many dev examples; otherwise use the whole dev split.
            train_samples_per_prompt: Number of training examples placed in
                each prompt as demonstrations.
        """
        self.seed = 42
        self._num_examples = num_examples
        self._train_samples_per_prompt = train_samples_per_prompt
        self.train_jsonl = (
            "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_train.jsonl.gz"
        )
        self.test_jsonl = (
            "https://openaipublic.blob.core.windows.net/simple-evals/drop_v0_dev.jsonl.gz"
        )
        self.train_samples = _load_jsonl_gz(self.train_jsonl)
        self.test_samples = _load_jsonl_gz(self.test_jsonl)
        if self._num_examples:
            self.test_samples = random.Random(self.seed).sample(
                self.test_samples, self._num_examples
            )

    def __call__(self, sampler: SamplerBase) -> EvalResult:
        """Run every test example through ``sampler`` and aggregate the results."""
        # NOTE(review): this rng is shared by all fn() calls. If
        # common.map_with_progress runs fn concurrently, the demonstrations
        # paired with each example depend on scheduling order -- confirm.
        rng = random.Random(self.seed)

        def fn(example: dict[str, str]) -> SingleEvalResult:
            stuffing = rng.sample(self.train_samples, self._train_samples_per_prompt)

            # NOTE(review): the line breaks inside the triple-quoted prompt
            # literals below were restored from a whitespace-collapsed copy of
            # this file; confirm them against the canonical prompt.
            prompt = """You will be asked to read a passage and answer a question. Some examples of passages and Q&A are provided below."""
            prompt += "\n\n# Examples"
            samples = stuffing + [example]
            for i, sample in enumerate(samples):
                # The example under test always comes last, after the demonstrations.
                is_test = i == len(stuffing)
                prompt += "\n# Your Task\n" if is_test else ""
                prompt += f"""
---
{sample["context"]} """

                if not is_test:
                    # Demonstrations include their reference completion.
                    prompt += sample["completion"] + "\n"
                else:
                    prompt += """\n
Think step by step, then write a line of the form "Answer: $ANSWER" at the end of your response.
                    """

            # "ref_text" holds the acceptable answers separated by "|".
            correct_answers = example["ref_text"].split("|")

            prompt_messages = [sampler._pack_message(content=prompt, role="user")]
            response_text = sampler(prompt_messages)
            match = re.search(ANSWER_PATTERN, response_text)
            # Fall back to the whole response when no answer pattern is found.
            extracted_answer = match.group(1) if match else response_text
            em_score, f1_score = drop_metric(extracted_answer, correct_answers)
            matches = [
                fuzzy_match(extracted_answer, correct_answer)
                for correct_answer in correct_answers
            ]
            # One copy of the extracted answer per reference it fuzzily matched.
            extracted_answers = [extracted_answer for matched in matches if matched]
            score = any(matches)
            html = common.jinja_env.from_string(HTML_JINJA).render(
                prompt_messages=prompt_messages,
                next_message=dict(content=extracted_answer, role="assistant"),
                score=score,
                correct_answer=correct_answers,
                extracted_answer=extracted_answers,
            )
            convo = prompt_messages + [dict(content=extracted_answer, role="assistant")]
            return SingleEvalResult(
                html=html,
                score=score,
                convo=convo,
                metrics={"em_score": em_score, "f1_score": f1_score},
            )

        results = common.map_with_progress(fn, self.test_samples)
        return common.aggregate_results(results)