optimum_benchmark/scenarios/energy_star/config.py (76 lines of code) (raw):
from dataclasses import dataclass, field
from logging import getLogger
from typing import Any, Dict, Union
from ...system_utils import is_rocm_system
from ..config import ScenarioConfig
# Logger shared by this module; the name is fixed rather than derived from
# `__name__`.
LOGGER = getLogger(name="energy_star")

# Baseline input shapes. `EnergyStarConfig.__post_init__` merges the
# user-supplied `input_shapes` on top of these, so user values win.
INPUT_SHAPES: Dict[str, Any] = dict(batch_size=1)
@dataclass
class EnergyStarConfig(ScenarioConfig):
    """Configuration of the ``energy_star`` scenario.

    The fields are grouped into dataset selection, per-modality column
    names / preprocessing switches, measurement toggles, and keyword
    arguments forwarded to the model or pipeline.

    ``__post_init__`` normalizes the instance after construction:

    * ``input_shapes`` is completed with the module-level ``INPUT_SHAPES``
      defaults (user-provided keys take precedence).
    * ``generate_kwargs`` is made deterministic with respect to
      ``min_new_tokens`` / ``max_new_tokens`` (see ``__post_init__``).

    Raises:
        ValueError: if ``min_new_tokens`` and ``max_new_tokens`` are both set
            to different values, or if energy measurement is requested on a
            ROCm system.
    """

    name: str = "energy_star"
    _target_: str = "optimum_benchmark.scenarios.energy_star.scenario.EnergyStarScenario"

    # dataset options
    dataset_name: str = field(default="", metadata={"help": "Name of the dataset on the HF Hub."})
    dataset_config: str = field(default="", metadata={"help": "Name of the config of the dataset."})
    dataset_split: str = field(default="train", metadata={"help": "Dataset split to use."})
    num_samples: int = field(default=-1, metadata={"help": "Number of samples to select in the dataset. -1 means all."})
    input_shapes: Dict[str, Any] = field(
        default_factory=dict,
        metadata={"help": "Input shapes for the model. Missing keys will be filled with default values."},
    )

    # text dataset options
    text_column_name: str = field(default="text", metadata={"help": "Name of the column with the text input."})
    truncation: Union[bool, str] = field(default=True, metadata={"help": "To truncate the inputs."})
    max_length: int = field(
        default=-1, metadata={"help": "Maximum length to use by one of the truncation/padding parameters"}
    )
    dataset_prefix1: str = field(default="", metadata={"help": "Prefix to add to text2textgeneration input."})
    dataset_prefix2: str = field(default="", metadata={"help": "Prefix to add to text2textgeneration input."})
    t5_task: str = field(default="", metadata={"help": "Task for categorizing text2textgeneration tasks."})

    # image dataset options
    image_column_name: str = field(default="image", metadata={"help": "Name of the column with the image input."})
    resize: Union[bool, str] = field(default=False, metadata={"help": "To resize the input images."})

    # qa dataset options
    question_column_name: str = field(default="question", metadata={"help": "Name of the column with the question."})
    context_column_name: str = field(default="context", metadata={"help": "Name of the column with the context."})

    # sts dataset options
    sentence1_column_name: str = field(
        default="sentence1", metadata={"help": "Name of the column with the first sentence."}
    )
    sentence2_column_name: str = field(
        default="sentence2", metadata={"help": "Name of the column with the second sentence."}
    )

    # asr dataset options
    audio_column_name: str = field(default="audio", metadata={"help": "Name of the column with the audio."})

    # scenario options
    energy: bool = field(default=True, metadata={"help": "Whether to measure energy."})
    memory: bool = field(default=False, metadata={"help": "Whether to measure memory."})
    latency: bool = field(default=False, metadata={"help": "Whether to measure latency."})
    warmup_runs: int = field(default=10, metadata={"help": "Number of warmup runs to perform before scenarioing"})

    # methods kwargs
    forward_kwargs: Dict[str, Any] = field(
        default_factory=dict, metadata={"help": "Keyword arguments to pass to the forward method of the model."}
    )
    generate_kwargs: Dict[str, Any] = field(
        default_factory=dict, metadata={"help": "Keyword arguments to pass to the generate method of the model."}
    )
    call_kwargs: Dict[str, Any] = field(
        default_factory=dict, metadata={"help": "Keyword arguments to pass to the __call__ method of the pipeline."}
    )

    def __post_init__(self):
        """Fill in defaults and validate the configuration.

        Raises:
            ValueError: if ``generate_kwargs`` sets ``min_new_tokens`` and
                ``max_new_tokens`` to different values, or if ``energy`` is
                enabled on a ROCm system.
        """
        super().__post_init__()

        # Defaults first, user-provided shapes second, so user values win.
        self.input_shapes = {**INPUT_SHAPES, **self.input_shapes}

        # Keep `min_new_tokens` and `max_new_tokens` equal: reject a mismatch,
        # and when only one is given mirror it onto the other (mutating
        # `generate_kwargs` in place).
        if (
            "max_new_tokens" in self.generate_kwargs
            and "min_new_tokens" in self.generate_kwargs
            and self.generate_kwargs["max_new_tokens"] != self.generate_kwargs["min_new_tokens"]
        ):
            raise ValueError(
                "Setting `min_new_tokens` and `max_new_tokens` to different values results in non-deterministic behavior."
            )
        elif "max_new_tokens" in self.generate_kwargs and "min_new_tokens" not in self.generate_kwargs:
            LOGGER.warning(
                "Setting `max_new_tokens` without `min_new_tokens` results in non-deterministic behavior. "
                "Setting `min_new_tokens` to `max_new_tokens`."
            )
            self.generate_kwargs["min_new_tokens"] = self.generate_kwargs["max_new_tokens"]
        elif "min_new_tokens" in self.generate_kwargs and "max_new_tokens" not in self.generate_kwargs:
            LOGGER.warning(
                "Setting `min_new_tokens` without `max_new_tokens` results in non-deterministic behavior. "
                "Setting `max_new_tokens` to `min_new_tokens`."
            )
            self.generate_kwargs["max_new_tokens"] = self.generate_kwargs["min_new_tokens"]

        # The ROCm limitation concerns energy measurement only, so it must not
        # block runs where the user disabled `energy` (e.g. memory/latency only).
        if self.energy and is_rocm_system():
            raise ValueError("Energy measurement through codecarbon is not yet available on ROCm-powered devices.")