in src/accelerate/utils/dataclasses.py [0:0]
def set_mixed_precision(self, mixed_precision):
    """
    Write the requested mixed-precision mode into `self.deepspeed_config`, mutating it in place.

    Args:
        mixed_precision: Mode string. The values this method distinguishes are `"no"`,
            `"fp16"`, `"bf16"` and `"fp8"`.

    Behavior:
        - `"fp16"` adds an enabled `fp16` section (with `auto_cast`) if the config has none.
        - `"bf16"` and `"fp8"` add an enabled `bf16` section if the config has none.
        - `"fp8"` with `self.enable_msamp` truthy also adds an `msamp` section (using
          `self.msamp_opt_level`) if none exists.
        - Any still-missing `fp16` / `bf16` section is added as `{"enabled": False}`.
        - Finally `self.fill_match` is called for `fp16.enabled` and then `bf16.enabled`,
          with `must_match=False` and the requested on/off values as keyword arguments.
          NOTE(review): presumably this only resolves `"auto"` entries; confirm in `fill_match`.

    Raises:
        ValueError: If the mode is not `"no"` and the config's other half-precision section
            (`bf16` when requesting fp16, otherwise `fp16`) has `enabled` equal to true
            (bool or string, case-insensitive).
    """
    config = self.deepspeed_config
    wants_fp16 = mixed_precision == "fp16"
    # When training in fp8, we still rely on bf16 autocast for the core mixed precision
    wants_bf16 = mixed_precision in ("bf16", "fp8")
    fill_values = {"fp16.enabled": wants_fp16, "bf16.enabled": wants_bf16}

    # Seed a section for the requested dtype, but never overwrite one the user already wrote.
    if wants_fp16 or wants_bf16:
        own_dtype = "fp16" if wants_fp16 else "bf16"
        if own_dtype not in config:
            config[own_dtype] = {"enabled": True, "auto_cast": True} if wants_fp16 else {"enabled": True}

    if mixed_precision == "fp8" and self.enable_msamp and "msamp" not in config:
        config["msamp"] = {"enabled": True, "opt_level": self.msamp_opt_level}

    if mixed_precision != "no":
        # Requesting one half-precision dtype while the config enables the other is a conflict.
        # The config may already have been modified above by the time this raises.
        other_dtype = "bf16" if wants_fp16 else "fp16"
        other_enabled = config.get(other_dtype, {}).get("enabled", "False")
        if str(other_enabled).lower() == "true":
            raise ValueError(
                f"`--mixed_precision` arg cannot be set to `{mixed_precision}` when `{other_dtype}` is set in the DeepSpeed config file."
            )

    # Guarantee both sections exist; whatever was not requested defaults to disabled.
    for dtype in ("fp16", "bf16"):
        config.setdefault(dtype, {"enabled": False})
    for key in fill_values:
        self.fill_match(key, must_match=False, **fill_values)