def __post_init__()

in src/accelerate/utils/dataclasses.py


    def __post_init__(self):
        """
        Safety checker that arguments are correct; also replaces some None arguments with their default values.
        """
        if not isinstance(self.load_in_8bit, bool):
            raise ValueError("load_in_8bit must be a boolean")

        if not isinstance(self.load_in_4bit, bool):
            raise ValueError("load_in_4bit must be a boolean")

        if self.load_in_4bit and self.load_in_8bit:
            raise ValueError("load_in_4bit and load_in_8bit can't be both True")

        if not self.load_in_4bit and not self.load_in_8bit:
            raise ValueError("load_in_4bit and load_in_8bit can't be both False")

        if not isinstance(self.llm_int8_threshold, (int, float)):
            raise ValueError("llm_int8_threshold must be a float or an int")

        if not isinstance(self.bnb_4bit_quant_type, str):
            raise ValueError("bnb_4bit_quant_type must be a string")
        elif self.bnb_4bit_quant_type not in ["fp4", "nf4"]:
            raise ValueError(f"bnb_4bit_quant_type must be in ['fp4','nf4'] but found {self.bnb_4bit_quant_type}")

        if not isinstance(self.bnb_4bit_use_double_quant, bool):
            raise ValueError("bnb_4bit_use_double_quant must be a boolean")

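        # Normalize a string bnb_4bit_compute_dtype ("fp32"/"fp16"/"bf16") to the
        # corresponding torch.dtype.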
        if isinstance(self.bnb_4bit_compute_dtype, str):
            if self.bnb_4bit_compute_dtype == "fp32":
                self.bnb_4bit_compute_dtype = torch.float32
            elif self.bnb_4bit_compute_dtype == "fp16":
                self.bnb_4bit_compute_dtype = torch.float16
            elif self.bnb_4bit_compute_dtype == "bf16":
                self.bnb_4bit_compute_dtype = torch.bfloat16
            else:
                raise ValueError(
                    f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}"
                )
        elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
            raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")

        if self.skip_modules is not None and not isinstance(self.skip_modules, list):
            raise ValueError("skip_modules must be a list of strings")

        if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list):
            raise ValueError("keep_in_fp_32_modules must be a list of strings")

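        # Record the quantized storage dtype implied by the selected mode.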
        if self.load_in_4bit:
            self.target_dtype = CustomDtype.INT4

        if self.load_in_8bit:
            self.target_dtype = torch.int8

        if self.load_in_4bit and self.llm_int8_threshold != 6.0:
            warnings.warn("llm_int8_threshold can only be used for a model loaded in 8bit")

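        # Normalize torch_dtype the same way, then fall back to a per-mode default
        # when it was left unset.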
        if isinstance(self.torch_dtype, str):
            if self.torch_dtype == "fp32":
                self.torch_dtype = torch.float32
            elif self.torch_dtype == "fp16":
                self.torch_dtype = torch.float16
            elif self.torch_dtype == "bf16":
                self.torch_dtype = torch.bfloat16
            else:
                raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}")
        if self.load_in_8bit and self.torch_dtype is None:
            self.torch_dtype = torch.float16

        if self.load_in_4bit and self.torch_dtype is None:
            self.torch_dtype = self.bnb_4bit_compute_dtype

        if not isinstance(self.torch_dtype, torch.dtype):
            raise ValueError("torch_dtype must be a torch.dtype")