optimum/habana/accelerate/utils/dataclasses.py
from dataclasses import dataclass

from accelerate.utils import KwargsHandler


@dataclass
class GaudiTERecipeKwargs(KwargsHandler):
"""
Use this object in your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision training with `transformer-engine`.
Adapted from: https://github.com/huggingface/accelerate/blob/v0.27.2/src/accelerate/utils/dataclasses.py#L180
Args:
margin (`int`, *optional*, defaults to 0):
The margin to use for the scaling factor computation.
interval (`int`, *optional*, defaults to 16):
The interval to use for how often the scaling factor is recomputed.
fp8_format (`str`, *optional*, defaults to "HYBRID"):
The format to use for the FP8 recipe. Must be one of `E5M2` or `HYBRID`.
amax_history_len (`int`, *optional*, defaults to 1):
The length of the history to use for the scaling factor computation
amax_compute_algo (`str`, *optional*, defaults to "most_recent"):
The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.
reduce_amax (`bool`, *optional*, defaults to "False"):
By default, if `torch.distributed` is initialized, the `amax` value for FP8
tensors is reduced across the `fp8_group` (specified in the `fp8_autocast`
call). This keeps the amaxes and scaling factors synced across the given
distributed group. If set to `False`, this reduction is skipped and every
HPU maintains local amaxes and scaling factors. To ensure results are
numerically identical across checkpointing boundaries in this case, all
ranks must checkpoint in order to store the local tensors.
"""

    margin: int = 0
    interval: int = 16
    fp8_format: str = "HYBRID"
    amax_compute_algo: str = "most_recent"
    amax_history_len: int = 1
    reduce_amax: bool = False

    def __post_init__(self):
        self.fp8_format = self.fp8_format.upper()
        assert self.fp8_format in ("E5M2", "HYBRID"), "Only E5M2 and HYBRID FP8 formats are currently supported."
        assert self.amax_compute_algo in (
            "max",
            "most_recent",
        ), "Only max and most_recent `amax_compute_algo` modes are currently supported."