def create_peft_config()

in sagemaker/28_train_llms_with_qlora/scripts/run_clm.py [0:0]


def create_peft_config(
    model,
    gradient_checkpointing=True,
    *,
    r=64,
    lora_alpha=16,
    lora_dropout=0.1,
    target_modules=None,
):
    """Prepare a k-bit quantized causal LM for training and wrap it with LoRA.

    Args:
        model: The base model to adapt. Presumably loaded in 4-bit (QLoRA) or
            8-bit; TODO confirm against the caller.
        gradient_checkpointing: Whether to enable gradient checkpointing. It is
            forwarded to ``prepare_model_for_kbit_training`` so that ``False``
            actually keeps checkpointing off. Previously the helper's own
            default (``True``) enabled it regardless of this flag.
        r: LoRA rank passed to ``LoraConfig``.
        lora_alpha: LoRA alpha passed to ``LoraConfig``.
        lora_dropout: Dropout probability for the LoRA layers.
        target_modules: Module names to inject LoRA adapters into. Defaults to
            ``["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"]``.
            These names appear to match Falcon-style layer naming; verify them
            for other architectures.

    Returns:
        The model returned by ``get_peft_model``. It has LoRA adapters attached
        and only those adapter parameters are trainable.
    """
    from peft import (
        get_peft_model,
        LoraConfig,
        TaskType,
        prepare_model_for_kbit_training,
    )

    # Use None instead of a mutable list default; the fallback reproduces the
    # original hard-coded module list.
    if target_modules is None:
        target_modules = [
            "query_key_value",
            "dense",
            "dense_h_to_4h",
            "dense_4h_to_h",
        ]

    peft_config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        target_modules=target_modules,
        lora_dropout=lora_dropout,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    # Prepare the k-bit (4- or 8-bit) quantized model for training. Forward the
    # flag explicitly: the helper defaults to use_gradient_checkpointing=True,
    # which would otherwise enable checkpointing even when the caller asked
    # for it to be off.
    model = prepare_model_for_kbit_training(
        model, use_gradient_checkpointing=gradient_checkpointing
    )
    if gradient_checkpointing:
        # Also enable it directly, in case the helper skipped it (for example,
        # when it does not detect a k-bit model).
        model.gradient_checkpointing_enable()
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()
    return model