in scripts/run_prompt_creation.py [0:0]
def get_quantization_config(model_args: ModelArguments) -> Union[BitsAndBytesConfig, None]:
    """Build the bitsandbytes quantization config requested by ``model_args``.

    The 4-bit flag is checked first, so it wins when both ``load_in_4bit``
    and ``load_in_8bit`` are set.

    Args:
        model_args: Object exposing ``load_in_4bit`` and ``load_in_8bit``.
            The 4-bit path additionally reads ``torch_dtype``,
            ``bnb_4bit_quant_type`` and ``use_bnb_nested_quant``.

    Returns:
        A ``BitsAndBytesConfig`` for 4-bit or 8-bit loading, or ``None`` when
        neither flag is set.
    """
    if model_args.load_in_4bit:
        requested_dtype = model_args.torch_dtype
        # "auto" / None carry no concrete dtype name, so fall back to float16;
        # anything else is resolved as an attribute of the torch module.
        dtype_for_compute = (
            torch.float16
            if requested_dtype in {"auto", None}
            else getattr(torch, requested_dtype)
        )
        return BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=dtype_for_compute,
            bnb_4bit_quant_type=model_args.bnb_4bit_quant_type,
            bnb_4bit_use_double_quant=model_args.use_bnb_nested_quant,
        )

    if model_args.load_in_8bit:
        return BitsAndBytesConfig(load_in_8bit=True)

    # Neither quantization mode requested.
    return None