# optimum_benchmark/hub_utils.py

import time
from ast import literal_eval
from dataclasses import asdict, dataclass
from json import dump, load
from logging import getLogger
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Any, Dict, Optional, Union

import pandas as pd
from flatten_dict import flatten, unflatten
from huggingface_hub import HfApi
from huggingface_hub.utils import HfHubHTTPError
from transformers.utils.hub import http_user_agent
from typing_extensions import Self

LOGGER = getLogger("hub_utils")

HF_API = HfApi(user_agent=http_user_agent())


class classproperty:
    """Descriptor that exposes a method as a read-only property of the class
    itself, so e.g. ``cls.default_filename`` works without an instance."""

    def __init__(self, fget):
        self.fget = fget

    def __get__(self, obj, owner):
        # Always bind to the owner class, never to the instance.
        return self.fget(owner)


@dataclass
class PushToHubMixin:
    """
    A Mixin to push artifacts to the Hugging Face Hub
    """

    # DICTIONARY/JSON API
    def to_dict(self, flat=False) -> Dict[str, Any]:
        """Serialize the dataclass to a dict; ``flat=True`` dot-flattens nested keys."""
        data = asdict(self)

        if flat:
            data = flatten(data, reducer="dot")

        return data

    @classmethod
    def from_dict(cls, data: Dict[str, Any]) -> "PushToHubMixin":
        """Instantiate from a (nested) dict as produced by :meth:`to_dict`."""
        return cls(**data)

    def save_json(self, path: Union[str, Path], flat: bool = False) -> None:
        """Write the dict representation of ``self`` to *path* as indented JSON."""
        with open(path, "w") as f:
            dump(self.to_dict(flat=flat), f, indent=4)

    @classmethod
    def from_json(cls, path: Union[str, Path]) -> Self:
        """Load an instance from a JSON file written by :meth:`save_json`."""
        with open(path, "r") as f:
            data = load(f)
        return cls.from_dict(data)

    # DATAFRAME/CSV API
    def to_dataframe(self) -> pd.DataFrame:
        """Return a single-row DataFrame whose columns are the dot-flattened keys."""
        flat_dict_data = self.to_dict(flat=True)
        return pd.DataFrame.from_dict(flat_dict_data, orient="index").T

    @classmethod
    def from_dataframe(cls, df: pd.DataFrame) -> Self:
        """Rebuild an instance from the first row of *df* (inverse of :meth:`to_dataframe`)."""
        data = df.to_dict(orient="records")[0]

        for k, v in data.items():
            if isinstance(v, str) and v.startswith("[") and v.endswith("]"):
                # lists get stringified by the CSV round-trip; literal_eval is a
                # safe replacement for eval (it only parses Python literals and
                # cannot execute arbitrary code from a crafted CSV cell)
                data[k] = literal_eval(v)
            if v != v:
                # NaN is the only value that is not equal to itself: map it to None
                data[k] = None

        data = unflatten(data, splitter="dot")
        return cls.from_dict(data)

    def save_csv(self, path: Union[str, Path]) -> None:
        """Write the single-row DataFrame representation to *path* as CSV."""
        self.to_dataframe().to_csv(path, index=False)

    @classmethod
    def from_csv(cls, path: Union[str, Path]) -> Self:
        """Load an instance from a CSV file written by :meth:`save_csv`."""
        return cls.from_dataframe(pd.read_csv(path))

    # HUGGING FACE HUB API
    def push_to_hub(
        self, repo_id: str, filename: Optional[str] = None, subfolder: Optional[str] = None, **kwargs
    ) -> None:
        """Upload the JSON representation of ``self`` to *repo_id* on the Hub.

        ``token``, ``private``, ``exist_ok`` and ``repo_type`` kwargs control
        repo creation; any remaining kwargs are forwarded to ``upload_file``.
        """
        filename = str(filename or self.default_filename)
        subfolder = str(subfolder or self.default_subfolder)

        token = kwargs.pop("token", None)
        private = kwargs.pop("private", False)
        exist_ok = kwargs.pop("exist_ok", True)
        repo_type = kwargs.pop("repo_type", "dataset")

        HF_API.create_repo(repo_id, token=token, private=private, exist_ok=exist_ok, repo_type=repo_type)

        with TemporaryDirectory() as tmpdir:
            path_in_repo = (Path(subfolder) / filename).as_posix()
            path_or_fileobj = Path(tmpdir) / filename
            self.save_json(path_or_fileobj)

            HF_API.upload_file(
                repo_id=repo_id,
                path_in_repo=path_in_repo,
                path_or_fileobj=path_or_fileobj,
                repo_type=repo_type,
                token=token,
                **kwargs,
            )

    @classmethod
    def from_pretrained(
        cls, repo_id: str, filename: Optional[str] = None, subfolder: Optional[str] = None, **kwargs
    ) -> Self:
        """Download the JSON artifact from *repo_id* on the Hub and load it.

        Retries once after 15 seconds when the Hub answers with an HTTP 429
        rate-limit error; any other ``HfHubHTTPError`` is re-raised.
        """
        filename = str(filename or cls.default_filename)
        subfolder = str(subfolder or cls.default_subfolder)

        repo_type = kwargs.pop("repo_type", "dataset")

        try:
            resolved_file = HF_API.hf_hub_download(
                repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, **kwargs
            )
        except HfHubHTTPError as e:
            LOGGER.warning("Error while downloading from Hugging Face Hub")
            if "Client Error: Too Many Requests for url" in str(e):
                LOGGER.warning("Client Error: Too Many Requests for url. Retrying in 15 seconds.")
                time.sleep(15)
                resolved_file = HF_API.hf_hub_download(
                    repo_id=repo_id, filename=filename, subfolder=subfolder, repo_type=repo_type, **kwargs
                )
            else:
                # bare raise preserves the original traceback
                raise

        return cls.from_json(resolved_file)

    @classproperty
    def default_filename(self) -> str:
        # Fallback file name used when callers do not pass one explicitly.
        return "file.json"

    @classproperty
    def default_subfolder(self) -> str:
        # Fallback subfolder within the Hub repo.
        return "benchmarks"