scripts/transformers/utils.py (13 lines of code) (raw):

import json
from typing import Tuple

from datasets import Dataset


def get_label_mappings(dataset: Dataset) -> Tuple[int, dict, dict]:
    """Return the label mappings of the dataset.

    The mapping is built from the (``label``, ``label_text``) pairs that
    actually co-occur in each row. It deliberately does not zip the two
    columns' independently computed unique values, so a label id can never
    be paired with the wrong name.

    Args:
        dataset: Dataset with a ``label`` column of label ids and a
            ``label_text`` column holding the matching label names.

    Returns:
        A ``(num_labels, label2id, id2label)`` tuple:

        * ``num_labels`` is the number of distinct label ids.
        * ``label2id`` maps each label name to its id.
        * ``id2label`` is the inverse mapping.

        Both dicts are ordered by ascending label id. ``num_labels`` counts
        distinct ids. It does not equal ``max(id) + 1`` when the ids are not
        contiguous.

    Raises:
        ValueError: If a label id appears with more than one name, or a
            label name appears with more than one id.
    """
    # Reads both columns in full. This costs more than ``Dataset.unique``, but
    # it is the only way to know which name really belongs to which id.
    id2name: dict = {}
    for label_id, label_name in zip(dataset["label"], dataset["label_text"]):
        known_name = id2name.setdefault(label_id, label_name)
        if known_name != label_name:
            raise ValueError(
                f"Label id {label_id!r} is mapped to both "
                f"{known_name!r} and {label_name!r}."
            )

    id2label = {label_id: id2name[label_id] for label_id in sorted(id2name)}
    label2id = {label: label_id for label_id, label in id2label.items()}
    # If two ids share one name, the inversion above collapsed them.
    if len(label2id) != len(id2label):
        raise ValueError(
            "Several label ids share the same name in the 'label_text' column."
        )

    num_labels = len(id2label)
    return num_labels, label2id, id2label


def save_metrics(metrics: dict, metrics_filepath) -> None:
    """Write ``metrics`` to ``metrics_filepath`` as JSON.

    Any existing file at that path is overwritten.

    Args:
        metrics: JSON-serialisable mapping of metric names to values.
        metrics_filepath: Destination path; anything accepted by ``open``.
    """
    with open(metrics_filepath, "w", encoding="utf-8") as f:
        json.dump(metrics, f)