lm_eval/loggers/wandb_logger.py:
import copy
import json
import logging
from typing import Any, Dict, List, Tuple
import numpy as np
import pandas as pd
from packaging.version import Version
from lm_eval.loggers.utils import _handle_non_serializable, remove_none_pattern
logger = logging.getLogger(__name__)
def get_wandb_printer() -> Any:
    """Returns a wandb printer instance for pretty-printing to stdout."""
from wandb.sdk.lib.printer import get_printer
from wandb.sdk.wandb_settings import Settings
printer = get_printer(Settings()._jupyter)
return printer
class WandbLogger:
def __init__(self, **kwargs) -> None:
"""Attaches to wandb logger if already initialized. Otherwise, passes kwargs to wandb.init()
Args:
kwargs Optional[Any]: Arguments for configuration.
Parse and log the results returned from evaluator.simple_evaluate() with:
wandb_logger.post_init(results)
wandb_logger.log_eval_result()
wandb_logger.log_eval_samples(results["samples"])
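        Example (a sketch; the wandb.init() kwargs shown are illustrative only):
            results = evaluator.simple_evaluate(...)
            wandb_logger = WandbLogger(project="lm-eval-harness", job_type="eval")
            wandb_logger.post_init(results)
            wandb_logger.log_eval_result()
            wandb_logger.log_eval_samples(results["samples"])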
"""
try:
import wandb
            # wandb>=0.13.6 is required; older versions fail the assert and fall through
            # to the warning below.
            assert Version(wandb.__version__) >= Version("0.13.6")
except Exception as e:
logger.warning(
"To use the wandb reporting functionality please install wandb>=0.13.6.\n"
"To install the latest version of wandb run `pip install wandb --upgrade`\n"
f"{e}"
)
self.wandb_args: Dict[str, Any] = kwargs
# initialize a W&B run
if wandb.run is None:
self.run = wandb.init(**self.wandb_args)
else:
self.run = wandb.run
self.printer = get_wandb_printer()
def post_init(self, results: Dict[str, Any]) -> None:
self.results: Dict[str, Any] = copy.deepcopy(results)
self.task_names: List[str] = list(results.get("results", {}).keys())
self.group_names: List[str] = list(results.get("groups", {}).keys())
def _get_config(self) -> Dict[str, Any]:
"""Get configuration parameters."""
self.task_configs = self.results.get("configs", {})
cli_configs = self.results.get("config", {})
configs = {
"task_configs": self.task_configs,
"cli_configs": cli_configs,
}
return configs
def _sanitize_results_dict(self) -> Tuple[Dict[str, str], Dict[str, Any]]:
"""Sanitize the results dictionary."""
_results = copy.deepcopy(self.results.get("results", dict()))
# Remove None from the metric string name
tmp_results = copy.deepcopy(_results)
for task_name in self.task_names:
task_result = tmp_results.get(task_name, dict())
for metric_name, metric_value in task_result.items():
_metric_name, removed = remove_none_pattern(metric_name)
if removed:
_results[task_name][_metric_name] = metric_value
_results[task_name].pop(metric_name)
# remove string valued keys from the results dict
wandb_summary = {}
for task in self.task_names:
task_result = _results.get(task, dict())
for metric_name, metric_value in task_result.items():
if isinstance(metric_value, str):
wandb_summary[f"{task}/{metric_name}"] = metric_value
for summary_metric, summary_value in wandb_summary.items():
_task, _summary_metric = summary_metric.split("/")
_results[_task].pop(_summary_metric)
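        # Flatten the remaining numeric metrics into "<task>/<metric>" keys so the dict
        # can be passed directly to wandb.log().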
tmp_results = copy.deepcopy(_results)
for task_name, task_results in tmp_results.items():
for metric_name, metric_value in task_results.items():
_results[f"{task_name}/{metric_name}"] = metric_value
_results[task_name].pop(metric_name)
for task in self.task_names:
_results.pop(task)
return wandb_summary, _results
def _log_results_as_table(self) -> None:
"""Generate and log evaluation results as a table to W&B."""
columns = [
"Version",
"Filter",
"num_fewshot",
"Metric",
"Value",
"Stderr",
]
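        # Each table row holds: task (or group) name, version, filter, n-shot, metric,
        # value, and stderr.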
def make_table(columns: List[str], key: str = "results"):
import wandb
table = wandb.Table(columns=columns)
results = copy.deepcopy(self.results)
for k, dic in results.get(key).items():
if k in self.group_names and not key == "groups":
continue
version = results.get("versions").get(k)
if version == "N/A":
version = None
n = results.get("n-shot").get(k)
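                # Metric keys are formatted as "<metric>,<filter>"; stderr entries do not
                # get their own rows but are attached to the corresponding metric's row.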
                for mf, v in dic.items():
m, _, f = mf.partition(",")
if m.endswith("_stderr"):
continue
if m == "alias":
continue
if m + "_stderr" + "," + f in dic:
se = dic[m + "_stderr" + "," + f]
if se != "N/A":
se = "%.4f" % se
table.add_data(*[k, version, f, n, m, str(v), str(se)])
else:
table.add_data(*[k, version, f, n, m, str(v), ""])
return table
# log the complete eval result to W&B Table
table = make_table(["Tasks"] + columns, "results")
self.run.log({"evaluation/eval_results": table})
if "groups" in self.results.keys():
table = make_table(["Groups"] + columns, "groups")
self.run.log({"evaluation/group_eval_results": table})
def _log_results_as_artifact(self) -> None:
"""Log results as JSON artifact to W&B."""
import wandb
dumped = json.dumps(
self.results, indent=2, default=_handle_non_serializable, ensure_ascii=False
)
artifact = wandb.Artifact("results", type="eval_results")
with artifact.new_file("results.json", mode="w", encoding="utf-8") as f:
f.write(dumped)
self.run.log_artifact(artifact)
def log_eval_result(self) -> None:
"""Log evaluation results to W&B."""
# Log configs to wandb
configs = self._get_config()
self.run.config.update(configs)
wandb_summary, self.wandb_results = self._sanitize_results_dict()
# update wandb.run.summary with items that were removed
self.run.summary.update(wandb_summary)
# Log the evaluation metrics to wandb
self.run.log(self.wandb_results)
# Log the evaluation metrics as W&B Table
self._log_results_as_table()
# Log the results dict as json to W&B Artifacts
self._log_results_as_artifact()
def _generate_dataset(
self, data: List[Dict[str, Any]], config: Dict[str, Any]
) -> pd.DataFrame:
"""Generate a dataset from evaluation data.
Args:
data (List[Dict[str, Any]]): The data to generate a dataset for.
config (Dict[str, Any]): The configuration of the task.
Returns:
pd.DataFrame: A dataframe that is ready to be uploaded to W&B.
"""
ids = [x["doc_id"] for x in data]
labels = [x["target"] for x in data]
instance = [""] * len(ids)
resps = [""] * len(ids)
filtered_resps = [""] * len(ids)
model_outputs = {}
metrics_list = config["metric_list"]
metrics = {}
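        # Perplexity-style metrics arrive as (loglikelihood, num_words_or_bytes) tuples
        # per sample; split them into separate columns. Other metrics are kept as-is.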
for metric in metrics_list:
metric = metric.get("metric")
if metric in ["word_perplexity", "byte_perplexity", "bits_per_byte"]:
metrics[f"{metric}_loglikelihood"] = [x[metric][0] for x in data]
if metric in ["byte_perplexity", "bits_per_byte"]:
metrics[f"{metric}_bytes"] = [x[metric][1] for x in data]
else:
metrics[f"{metric}_words"] = [x[metric][1] for x in data]
else:
metrics[metric] = [x[metric] for x in data]
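        # Build human-readable input/response columns depending on the task's output type.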
if config["output_type"] == "loglikelihood":
instance = [x["arguments"][0][0] for x in data]
labels = [x["arguments"][0][1] for x in data]
resps = [
f'log probability of continuation is {x["resps"][0][0][0]} '
+ "\n\n"
+ "continuation will {} generated with greedy sampling".format(
"not be" if not x["resps"][0][0][1] else "be"
)
for x in data
]
filtered_resps = [
f'log probability of continuation is {x["filtered_resps"][0][0]} '
+ "\n\n"
+ "continuation will {} generated with greedy sampling".format(
"not be" if not x["filtered_resps"][0][1] else "be"
)
for x in data
]
elif config["output_type"] == "multiple_choice":
instance = [x["arguments"][0][0] for x in data]
choices = [
"\n".join([f"{idx}. {y[1]}" for idx, y in enumerate(x["arguments"])])
for x in data
]
resps = [np.argmax([n[0][0] for n in x["resps"]]) for x in data]
filtered_resps = [
np.argmax([n[0] for n in x["filtered_resps"]]) for x in data
]
elif config["output_type"] == "loglikelihood_rolling":
instance = [x["arguments"][0][0] for x in data]
resps = [x["resps"][0][0] for x in data]
filtered_resps = [x["filtered_resps"][0] for x in data]
elif config["output_type"] == "generate_until":
instance = [x["arguments"][0][0] for x in data]
resps = [x["resps"][0][0] for x in data]
filtered_resps = [x["filtered_resps"][0] for x in data]
model_outputs["raw_predictions"] = resps
model_outputs["filtered_predictions"] = filtered_resps
df_data = {
"id": ids,
"data": instance,
}
if config["output_type"] == "multiple_choice":
df_data["choices"] = choices
tmp_data = {
"input_len": [len(x) for x in instance],
"labels": labels,
"output_type": config["output_type"],
}
df_data.update(tmp_data)
df_data.update(model_outputs)
df_data.update(metrics)
return pd.DataFrame(df_data)
def _log_samples_as_artifact(
self, data: List[Dict[str, Any]], task_name: str
) -> None:
import wandb
# log the samples as an artifact
dumped = json.dumps(
data,
indent=2,
default=_handle_non_serializable,
ensure_ascii=False,
)
artifact = wandb.Artifact(f"{task_name}", type="samples_by_task")
with artifact.new_file(
f"{task_name}_eval_samples.json", mode="w", encoding="utf-8"
) as f:
f.write(dumped)
self.run.log_artifact(artifact)
# artifact.wait()
def log_eval_samples(self, samples: Dict[str, List[Dict[str, Any]]]) -> None:
"""Log evaluation samples to W&B.
Args:
samples (Dict[str, List[Dict[str, Any]]]): Evaluation samples for each task.
"""
task_names: List[str] = [
x for x in self.task_names if x not in self.group_names
]
ungrouped_tasks = []
tasks_by_groups = {}
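        # Split tasks into standalone tasks and tasks that declare a "group" in their
        # config; grouped tasks are concatenated into a single table per group.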
for task_name in task_names:
group_names = self.task_configs[task_name].get("group", None)
if group_names:
if isinstance(group_names, str):
group_names = [group_names]
for group_name in group_names:
if not tasks_by_groups.get(group_name):
tasks_by_groups[group_name] = [task_name]
else:
tasks_by_groups[group_name].append(task_name)
else:
ungrouped_tasks.append(task_name)
for task_name in ungrouped_tasks:
eval_preds = samples[task_name]
# log the samples as a W&B Table
df = self._generate_dataset(eval_preds, self.task_configs.get(task_name))
self.run.log({f"{task_name}_eval_results": df})
# log the samples as a json file as W&B Artifact
self._log_samples_as_artifact(eval_preds, task_name)
for group, grouped_tasks in tasks_by_groups.items():
grouped_df = pd.DataFrame()
for task_name in grouped_tasks:
eval_preds = samples[task_name]
df = self._generate_dataset(
eval_preds, self.task_configs.get(task_name)
)
df["group"] = group
df["task"] = task_name
grouped_df = pd.concat([grouped_df, df], ignore_index=True)
# log the samples as a json file as W&B Artifact
self._log_samples_as_artifact(eval_preds, task_name)
self.run.log({f"{group}_eval_results": grouped_df})