in trl/trainer/sft_trainer.py [0:0]
def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
    """Wrap ``model`` with the PEFT adapter described by ``peft_config``.

    A model counts as quantized when its ``is_loaded_in_4bit`` or ``is_loaded_in_8bit``
    attribute is truthy. Quantized models go through ``self._prepare_model_for_kbit_training``.
    The exception is a 4-bit model whose first ``Params4bit`` parameter lives on the ``cpu``
    or ``meta`` device: it is treated as sharded QLoRA (e.g. FSDP / DeepSpeed ZeRO-3) and
    skips that step. Non-quantized or sharded models instead get
    ``self._enable_gradient_checkpointing`` when ``args.gradient_checkpointing`` is set.

    Args:
        model: Model to adapt. Returned unchanged if it is already a ``PeftModel``.
        peft_config: Adapter configuration; must be a ``PeftConfig`` instance.
        args: Training configuration; ``gradient_checkpointing`` and ``bf16`` are read.

    Returns:
        The PEFT-wrapped model, or ``model`` itself if it was already a ``PeftModel``.

    Raises:
        ImportError: If the ``peft`` library is not available.
        ValueError: If ``peft_config`` is not a ``PeftConfig``.
    """
    # --- argument validation -------------------------------------------------
    if not is_peft_available():
        raise ImportError("To use PeftModel, you need to install the `peft` library.")
    if not isinstance(peft_config, PeftConfig):
        raise ValueError(
            f"Expected PeftConfig object but got {type(peft_config)}. If you want to use the PeftModel, you need "
            "to pass a PeftConfig object to the SFTTrainer."
        )
    if isinstance(model, PeftModel):
        # Already adapted: nothing to do.
        return model

    # --- quantization / sharding detection -----------------------------------
    loaded_in_4bit = getattr(model, "is_loaded_in_4bit", False)
    quantized = loaded_in_4bit or getattr(model, "is_loaded_in_8bit", False)

    is_sharded_qlora = False
    if loaded_in_4bit:
        # Only the first bitsandbytes 4-bit parameter is inspected; weights parked on
        # cpu/meta are taken as a sign that a sharding backend manages them.
        first_4bit_param = next(
            (p for _, p in model.named_parameters() if p.__class__.__name__ == "Params4bit"),
            None,
        )
        if first_4bit_param is not None:
            is_sharded_qlora = first_4bit_param.data.device.type in {"cpu", "meta"}

    # --- pre-wrap preparation ------------------------------------------------
    needs_kbit_prep = quantized and not is_sharded_qlora
    if needs_kbit_prep:
        model = self._prepare_model_for_kbit_training(model, args)
        # Presumably kbit preparation takes care of gradient checkpointing itself.
        # NOTE(review): `dataclasses.replace` returns a new config bound only to the local
        # name `args`; afterwards this method reads just `args.bf16`, so the caller's
        # `gradient_checkpointing` flag is NOT changed by this line — confirm intent.
        args = dataclasses.replace(args, gradient_checkpointing=False)
    elif args.gradient_checkpointing:
        model = self._enable_gradient_checkpointing(model, args)

    # --- adapter wrapping ----------------------------------------------------
    peft_kwargs = {}
    if (
        version.parse(peft.__version__) >= version.parse("0.12")  # `autocast_adapter_dtype` exists from peft 0.12
        and getattr(model, "is_loaded_in_4bit", False)
        and is_sharded_qlora
    ):
        # Sharded QLoRA: turn off peft's adapter-dtype autocasting.
        peft_kwargs["autocast_adapter_dtype"] = False
    model = get_peft_model(model, peft_config, **peft_kwargs)

    # --- optional bf16 cast for non-sharded 4-bit models ---------------------
    cast_to_bf16 = args.bf16 and getattr(model, "is_loaded_in_4bit", False) and not is_sharded_qlora
    if cast_to_bf16:
        peft_module_casting_to_bf16(model)
    return model