def set_mixed_precision()

in src/accelerate/utils/dataclasses.py [0:0]


    def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=False):
        """
        Sets the mixed precision policy for FSDP.

        Args:
            mixed_precision (`str` or `torch.dtype`):
                One of `"fp8"`, `"fp16"`, `"bf16"`, `"fp32"`, or the matching `torch.dtype`. Note that `"fp8"` is
                mapped to `torch.bfloat16`. Values that are neither `str` nor `torch.dtype` are used as the dtype
                unchanged.
            buffer_autocast (`bool`, defaults to `False`):
                If `True`, the buffer dtype is `torch.float32` instead of the resolved dtype. Only the FSDP1 policy
                receives a `buffer_dtype`, so this has no effect for FSDP2.
            override (`bool`, defaults to `False`):
                If `True`, always build a fresh policy from `mixed_precision`, replacing any existing
                `self.mixed_precision_policy`.

        When `override` is `False`:
            - if `self.mixed_precision_policy` is `None`, a policy is built from `mixed_precision`;
            - if it is a `dict`, it is validated and converted into the policy class (`mixed_precision` is only
              validated, not applied);
            - otherwise it is left untouched.

        Raises:
            `ValueError`: if `mixed_precision` is unsupported, if `self.fsdp_version` is neither 1 nor 2, or if a
            `dict` policy has missing/unexpected keys or values that are not one of the supported dtypes.
        """
        mixed_precision_mapping = {
            "fp8": torch.bfloat16,
            "fp16": torch.float16,
            "bf16": torch.bfloat16,
            "fp32": torch.float32,
        }
        dtype = mixed_precision
        if isinstance(mixed_precision, str):
            dtype = mixed_precision_mapping.get(mixed_precision, None)
            if dtype is None:
                raise ValueError(
                    f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.keys())}"
                )
        elif isinstance(mixed_precision, torch.dtype) and mixed_precision not in mixed_precision_mapping.values():
            raise ValueError(
                f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.values())}"
            )

        # Only the FSDP1 policy below gets a `buffer_dtype`; `buffer_autocast` keeps buffers in full precision.
        buffer_type = torch.float32 if buffer_autocast else dtype

        if self.fsdp_version == 1:
            from torch.distributed.fsdp import MixedPrecision
        elif self.fsdp_version == 2:
            from torch.distributed.fsdp import MixedPrecisionPolicy as MixedPrecision
        else:
            # Fail loudly here; otherwise `MixedPrecision` stays unbound and the first use below raises an
            # opaque UnboundLocalError.
            raise ValueError(f"Invalid fsdp_version: {self.fsdp_version}. Must be 1 or 2.")

        if override or self.mixed_precision_policy is None:
            dtype_args = {"param_dtype": dtype, "reduce_dtype": dtype}
            if self.fsdp_version == 1:
                dtype_args["buffer_dtype"] = buffer_type
            else:
                dtype_args["output_dtype"] = dtype
            # TODO(s1ro1): `cast_forward_inputs` for FSDP2?
            self.mixed_precision_policy = MixedPrecision(**dtype_args)
        elif isinstance(self.mixed_precision_policy, dict):
            # Validate the user-provided dict before handing it to the policy constructor.
            valid_keys = ["param_dtype", "reduce_dtype"] + (
                ["buffer_dtype"] if self.fsdp_version == 1 else ["output_dtype"]
            )
            missing_keys = [k for k in valid_keys if k not in self.mixed_precision_policy]
            # Extra dtype-valued keys (e.g. `output_dtype` under FSDP1) would otherwise slip past the value check
            # and fail inside the constructor with a less helpful TypeError.
            unexpected_keys = [k for k in self.mixed_precision_policy if k not in valid_keys]
            invalid_values = [
                k for k, v in self.mixed_precision_policy.items() if v not in mixed_precision_mapping.values()
            ]
            if missing_keys or unexpected_keys or invalid_values:
                raise ValueError(
                    f"Invalid mixed precision policy: {self.mixed_precision_policy}. "
                    f"Must be a `dict` with keys {valid_keys}. "
                    f"Values must be one of {list(mixed_precision_mapping.values())}. "
                    f"Missing keys: {missing_keys}; unexpected keys: {unexpected_keys}; "
                    f"keys with invalid values: {invalid_values}."
                )
            self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy)