evals/elsuite/utils.py
import copy
import re
import string
from collections import Counter, defaultdict
from typing import Optional, Union
from evals import CompletionFn
from evals.prompt.base import (
OpenAICreateChatPrompt,
OpenAICreatePrompt,
Prompt,
chat_prompt_to_text_prompt,
is_chat_prompt,
)
def get_answer(text, answer_prompt, ignore_case=False):
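    """Return the substring of `text` that follows the last occurrence of
    `answer_prompt`, or None if the prompt never appears. For example,
    get_answer("I think the answer is: Paris", "answer is: ") returns "Paris".
    """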
if ignore_case:
idx = text.lower().rfind(answer_prompt.lower())
else:
idx = text.rfind(answer_prompt)
if idx == -1:
return None
    return text[idx + len(answer_prompt) :]
def get_consensus(answers):
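    """Return the most frequent answer, breaking ties by first appearance.
    None votes are zeroed out, so None is returned only when no other answer exists."""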
counts = defaultdict(int)
for answer in answers:
counts[answer] += 1
counts[None] = 0
return max(counts, key=counts.get)
def normalize(s: str) -> str:
"""Lower text and remove punctuation, articles and extra whitespace."""
s = s.lower()
exclude = set(string.punctuation)
s = "".join(char for char in s if char not in exclude)
s = re.sub(r"\b(a|an|the)\b", " ", s)
s = " ".join(s.split())
return s
def fuzzy_match(s1: str, s2: str) -> bool:
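    """Fuzzy string match: True if either normalized string contains the other.
    Empty strings (after normalization) only match other empty strings."""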
s1 = normalize(s1)
s2 = normalize(s2)
if s1 == "" or s2 == "":
return s1 == s2
return s1 in s2 or s2 in s1
def get_scores_from_text(text: str) -> dict:
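    """Extract scores from '## <criterion>' sections, mapping each criterion
    to the single digit found before '/5' in the section body."""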
pattern = r"## (.+?)\n.+?(\d)/5"
matches = re.findall(pattern, text, re.DOTALL)
return {k: int(v) for k, v in dict(matches).items()}
def get_yesno_from_text(text: str) -> dict:
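    """Extract verdicts from '## <criterion>' sections, mapping each criterion
    to the first 'y' or 'n' that appears after the header line."""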
pattern = r"## (.+?)\n.+?([yn])"
matches = re.findall(pattern, text, re.DOTALL)
return {k: v for k, v in dict(matches).items()}
def get_letter_from_data(data: str) -> str:
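    """Return whichever of 'y' or 'n' occurs last in `data` (defaults to 'y'
    when neither is present)."""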
last_y = (data.rfind("y"), "y")
last_n = (data.rfind("n"), "n")
char = max(last_y, last_n)[1]
return char
def f1_score(prediction: str, answers: list[str]) -> float:
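    """Return the best token-level F1 score of `prediction` against any of the
    reference `answers`."""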
def _f1_score(prediction: str, ground_truth: str):
prediction_tokens = normalize(prediction).split()
ground_truth_tokens = normalize(ground_truth).split()
common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
num_same = sum(common.values())
if num_same == 0:
            return 0.0
precision = 1.0 * num_same / len(prediction_tokens)
recall = 1.0 * num_same / len(ground_truth_tokens)
f1 = (2 * precision * recall) / (precision + recall)
return f1
return max([_f1_score(prediction, answer) for answer in answers])
def scrub_formatting_from_prompt(prompt):
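    """Double curly braces in a prompt so that later `.format()` calls treat
    them as literal braces rather than placeholders."""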
    # Deep-copy so scrubbing a chat prompt does not mutate the caller's message dicts.
    scrubbed_prompt = copy.deepcopy(prompt)
if is_chat_prompt(prompt):
for i, msg in enumerate(scrubbed_prompt):
if "content" in msg:
scrubbed_prompt[i]["content"] = msg["content"].replace("{", "{{").replace("}", "}}")
else:
scrubbed_prompt = scrubbed_prompt.replace("{", "{{").replace("}", "}}")
return scrubbed_prompt
def format_necessary(template: str, allow_missing: bool = False, **kwargs: dict[str, str]) -> str:
"""Format a template string with only necessary kwargs."""
keys = [k[1] for k in string.Formatter().parse(template) if k[1]]
if allow_missing:
assert (
len([k for k in keys if k in kwargs]) > 0
), f"Required: {keys}, got: {sorted(kwargs)}, no inputs are used.\nTemplate:\n{template}"
cur_keys = {k: kwargs.get(k, "{" + k + "}") for k in keys}
else:
assert all(
k in kwargs for k in keys
), f"Required: {keys}, got: {sorted(kwargs)}.\nTemplate:\n{template}"
cur_keys = {k: kwargs[k] for k in keys}
return template.format(**cur_keys)
def format_prompt(
prompt: OpenAICreatePrompt, allow_missing: bool = False, **kwargs: dict[str, str]
) -> OpenAICreatePrompt:
"""Format a prompt with only necessary kwargs."""
    # If any input kwarg is a chat prompt, convert it to a text prompt.
kwargs = {
k: chat_prompt_to_text_prompt(v, for_completion=False) if is_chat_prompt(v) else v
for k, v in kwargs.items()
}
if is_chat_prompt(prompt):
new_prompt = []
for msg in prompt:
formatted_msg = copy.copy(msg)
if "content" in formatted_msg:
formatted_msg["content"] = format_necessary(
formatted_msg["content"], allow_missing=allow_missing, **kwargs
)
new_prompt.append(formatted_msg)
prompt = new_prompt
else:
# Prompt is a string
prompt = format_necessary(prompt, allow_missing=allow_missing, **kwargs)
return prompt
class PromptFn:
"""
    Wraps calls to a completion_fn with a prompt template and any applicable keyword args.
    Many of the arguments passed are specific to the OpenAI Completion API and may be ignored by other completion functions.
"""
def __init__(
self,
prompt: Union[OpenAICreatePrompt, OpenAICreateChatPrompt, Prompt],
completion_fn: CompletionFn,
max_tokens: int,
        temperature: float = 0.0,
        n_samples: Optional[int] = None,
        completion_kwargs: Optional[dict] = None,
):
self.prompt = prompt
self.max_tokens = max_tokens
self.completion_fn = completion_fn
self.temperature = temperature
        self.completion_kwargs = completion_kwargs or {}
self.n_samples = n_samples
def __call__(self, **kwargs):
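        """Format the stored prompt with `kwargs`, call the completion function,
        and return a (completion_text, formatted_prompt) tuple."""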
        # If any input kwarg is a chat prompt, convert it to a text prompt.
kwargs = {
k: chat_prompt_to_text_prompt(v, for_completion=False) if is_chat_prompt(v) else v
for k, v in kwargs.items()
}
if is_chat_prompt(self.prompt):
prompt = []
for msg in self.prompt:
formatted_msg = copy.copy(msg)
if "content" in formatted_msg:
formatted_msg["content"] = format_necessary(formatted_msg["content"], **kwargs)
prompt.append(formatted_msg)
else:
# Prompt is a string
prompt = format_necessary(self.prompt, **kwargs)
result = self.completion_fn(
prompt=prompt,
max_tokens=self.max_tokens,
temperature=self.temperature,
top_p=1,
frequency_penalty=0,
presence_penalty=0,
n=(1 if self.n_samples is None else self.n_samples),
**self.completion_kwargs,
)
sampled = result.get_completions()[0]
return sampled, prompt