in src/alignment/model_utils.py [0:0]
def get_quantization_config(model_args: ModelArguments) -> dict | None:
    """Build the bitsandbytes quantization settings requested by ``model_args``.

    4-bit loading takes precedence: ``load_in_8bit`` is only consulted when
    ``load_in_4bit`` is falsy.

    Args:
        model_args: Model arguments. This function reads ``load_in_4bit``,
            ``load_in_8bit``, ``torch_dtype``, ``bnb_4bit_quant_type``,
            ``use_bnb_nested_quant`` and ``bnb_4bit_quant_storage``.

    Returns:
        The value of ``BitsAndBytesConfig(...).to_dict()`` for the selected
        mode (a serialized dictionary, not a ``BitsAndBytesConfig`` instance),
        or ``None`` when neither 4-bit nor 8-bit loading is requested.

    Raises:
        AttributeError: If ``model_args.torch_dtype`` is a name (other than
            ``"auto"``) that is not an attribute of ``torch``.
    """
    if model_args.load_in_4bit:
        # float16 is the fallback compute dtype when no explicit dtype name
        # is configured ("auto" and None both mean "not specified" here).
        compute_dtype = torch.float16
        if model_args.torch_dtype not in {"auto", None}:
            # torch_dtype is looked up by name on the torch module,
            # e.g. "bfloat16" -> torch.bfloat16.
            compute_dtype = getattr(torch, model_args.torch_dtype)

        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
            bnb_4bit_quant_storage=model_args.bnb_4bit_quant_storage,
        ).to_dict()

    if model_args.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True).to_dict()

    # No quantization requested.
    return None