# sagemaker/28_train_llms_with_qlora/scripts/run_clm.py
def create_peft_config(model, gradient_checkpointing=True):
    from peft import (
        get_peft_model,
        LoraConfig,
        TaskType,
        prepare_model_for_kbit_training,
    )

    peft_config = LoraConfig(
        r=64,
        lora_alpha=16,
        # module names below match Falcon/GPT-NeoX-style attention and MLP layers
        target_modules=[
            "query_key_value",
            "dense",
            "dense_h_to_4h",
            "dense_4h_to_h",
        ],
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    # prepare the 4-bit quantized model for training: freezes base weights,
    # casts layer norms to fp32, and makes embedding outputs require grads.
    # Pass the flag through explicitly, since the function's default would
    # enable gradient checkpointing even when gradient_checkpointing=False.
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=gradient_checkpointing
    )
    if gradient_checkpointing:
        model.gradient_checkpointing_enable()

    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model
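

# Usage sketch (not part of the original script): load a 4-bit quantized base
# model with bitsandbytes and wrap it with the QLoRA config above. The model id
# is a placeholder; the target_modules in create_peft_config assume a
# Falcon/GPT-NeoX-style architecture, so a different base model would need
# different module names.
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
base_model = AutoModelForCausalLM.from_pretrained(
    "tiiuae/falcon-7b",  # placeholder model id
    quantization_config=bnb_config,
    device_map="auto",
)
model = create_peft_config(base_model, gradient_checkpointing=True)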