in src/accelerate/utils/dataclasses.py [0:0]
def __post_init__(self):
    """
    Validate the quantization arguments and fill in derived/default values.

    Side effects on ``self``:
        - ``bnb_4bit_compute_dtype`` and ``torch_dtype`` given as one of the strings ``"fp32"``, ``"fp16"`` or
          ``"bf16"`` are replaced by the matching ``torch.dtype``.
        - ``target_dtype`` is set to ``CustomDtype.INT4`` when loading in 4bit, ``torch.int8`` when loading in 8bit.
        - ``torch_dtype`` left as ``None`` becomes ``torch.float16`` in 8bit mode, or the (already resolved)
          ``bnb_4bit_compute_dtype`` in 4bit mode.

    A warning is emitted when 4bit loading is requested with an ``llm_int8_threshold`` other than ``6.0``.

    Raises:
        ValueError: if any argument has the wrong type or an unsupported value, or if ``load_in_4bit`` and
            ``load_in_8bit`` are both ``True`` or both ``False``.
    """
    # Single source of truth for the string aliases accepted by both dtype fields.
    dtype_aliases = {
        "fp32": torch.float32,
        "fp16": torch.float16,
        "bf16": torch.bfloat16,
    }

    if not isinstance(self.load_in_8bit, bool):
        raise ValueError("load_in_8bit must be a boolean")
    if not isinstance(self.load_in_4bit, bool):
        raise ValueError("load_in_4bit must be a boolean")

    # Exactly one of the two loading modes has to be selected.
    if self.load_in_4bit and self.load_in_8bit:
        raise ValueError("load_in_4bit and load_in_8bit can't be both True")
    if not self.load_in_4bit and not self.load_in_8bit:
        raise ValueError("load_in_4bit and load_in_8bit can't be both False")

    if not isinstance(self.llm_int8_threshold, (int, float)):
        raise ValueError("llm_int8_threshold must be a float or an int")

    if not isinstance(self.bnb_4bit_quant_type, str):
        raise ValueError("bnb_4bit_quant_type must be a string")
    if self.bnb_4bit_quant_type not in ("fp4", "nf4"):
        raise ValueError(f"bnb_4bit_quant_type must be in ['fp4','nf4'] but found {self.bnb_4bit_quant_type}")

    if not isinstance(self.bnb_4bit_use_double_quant, bool):
        raise ValueError("bnb_4bit_use_double_quant must be a boolean")

    if isinstance(self.bnb_4bit_compute_dtype, str):
        if self.bnb_4bit_compute_dtype not in dtype_aliases:
            raise ValueError(
                f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}"
            )
        self.bnb_4bit_compute_dtype = dtype_aliases[self.bnb_4bit_compute_dtype]
    elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
        raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")

    # Only the container type is checked here; element types are not inspected.
    if self.skip_modules is not None and not isinstance(self.skip_modules, list):
        raise ValueError("skip_modules must be a list of strings")
    if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list):
        raise ValueError("keep_in_fp32_modules must be a list of strings")

    # From here on exactly one of load_in_4bit / load_in_8bit is True (checked above).
    if self.load_in_4bit:
        self.target_dtype = CustomDtype.INT4
        # 6.0 is presumably the field's default, i.e. "not customised" -- confirm against the dataclass definition.
        if self.llm_int8_threshold != 6.0:
            warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit")
    else:
        self.target_dtype = torch.int8

    if isinstance(self.torch_dtype, str):
        if self.torch_dtype not in dtype_aliases:
            raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}")
        self.torch_dtype = dtype_aliases[self.torch_dtype]

    if self.torch_dtype is None:
        # bnb_4bit_compute_dtype has already been resolved to a torch.dtype above.
        self.torch_dtype = torch.float16 if self.load_in_8bit else self.bnb_4bit_compute_dtype

    if not isinstance(self.torch_dtype, torch.dtype):
        raise ValueError("torch_dtype must be a torch.dtype")