in src/accelerate/utils/dataclasses.py [0:0]
def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=False):
    """
    Sets the mixed precision policy for FSDP.

    Builds, or validates and converts, `self.mixed_precision_policy` into the FSDP mixed-precision object for
    `self.fsdp_version`:
    - FSDP1 uses `torch.distributed.fsdp.MixedPrecision`.
    - FSDP2 uses `torch.distributed.fsdp.MixedPrecisionPolicy`.

    Args:
        mixed_precision (`str` or `torch.dtype`):
            One of `"fp8"`, `"fp16"`, `"bf16"`, `"fp32"`, or one of the corresponding `torch.dtype` values.
            `"fp8"` is mapped to `torch.bfloat16` for the FSDP policy. Values that are neither a `str` nor a
            `torch.dtype` are used as the dtype unchanged.
        buffer_autocast (`bool`, *optional*, defaults to `False`):
            FSDP1 only. When `True`, `buffer_dtype` is `torch.float32`; otherwise it matches `param_dtype`.
        override (`bool`, *optional*, defaults to `False`):
            When `True`, always replace `self.mixed_precision_policy` with a policy built from `mixed_precision`.

    When `override` is `False`, behaviour depends on the current `self.mixed_precision_policy`:
        - `None`: a new policy is built from `mixed_precision`.
        - a `dict`: it is validated and converted into the FSDP policy object. `mixed_precision` is not applied.
        - anything else (e.g. an already-built policy object): left untouched.

    Raises:
        `ValueError`: if `mixed_precision` is an unknown string or an unsupported `torch.dtype`, or if a `dict`
        policy is missing a required key or holds a value that is not one of the supported dtypes.
    """
    mixed_precision_mapping = {
        # fp8 compute is not expressed through the FSDP policy; sharded params/grads use bf16 here.
        "fp8": torch.bfloat16,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
        "fp32": torch.float32,
    }
    dtype = mixed_precision
    if isinstance(mixed_precision, str):
        dtype = mixed_precision_mapping.get(mixed_precision, None)
        if dtype is None:
            raise ValueError(
                f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.keys())}"
            )
    elif isinstance(mixed_precision, torch.dtype) and mixed_precision not in mixed_precision_mapping.values():
        raise ValueError(
            f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.values())}"
        )
    # Buffers only exist in the FSDP1 policy; `buffer_autocast` keeps them in full precision.
    buffer_type = torch.float32 if buffer_autocast else dtype
    # NOTE(review): if `fsdp_version` is neither 1 nor 2, `MixedPrecision` stays unbound and the branches
    # below that construct a policy raise `UnboundLocalError`.
    if self.fsdp_version == 1:
        from torch.distributed.fsdp import MixedPrecision
    elif self.fsdp_version == 2:
        from torch.distributed.fsdp import MixedPrecisionPolicy as MixedPrecision
    if override or self.mixed_precision_policy is None:
        dtype_args = {"param_dtype": dtype, "reduce_dtype": dtype}
        if self.fsdp_version == 1:
            dtype_args["buffer_dtype"] = buffer_type
        else:
            # FSDP2's policy has no buffer dtype; it exposes `output_dtype` instead.
            dtype_args["output_dtype"] = dtype
        # TODO(s1ro1): `cast_forward_inputs` for FSDP2?
        self.mixed_precision_policy = MixedPrecision(**dtype_args)
    elif isinstance(self.mixed_precision_policy, dict):
        # Check for incompatible types
        valid_keys = ["param_dtype", "reduce_dtype"] + (
            ["buffer_dtype"] if self.fsdp_version == 1 else ["output_dtype"]
        )
        missing_keys = [k for k in valid_keys if k not in self.mixed_precision_policy]
        invalid_values = [
            k for k, v in self.mixed_precision_policy.items() if v not in mixed_precision_mapping.values()
        ]
        if missing_keys or invalid_values:
            # Tell the user exactly what is wrong, not just what a valid policy looks like.
            problems = []
            if missing_keys:
                problems.append(f"missing keys: {missing_keys}")
            if invalid_values:
                problems.append(f"keys with unsupported values: {invalid_values}")
            raise ValueError(
                f"Invalid mixed precision policy: {self.mixed_precision_policy} ({'; '.join(problems)}). "
                f"Must be a `dict` with keys {valid_keys}. "
                f"Values must be one of {list(mixed_precision_mapping.values())}"
            )
        self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy)