def get_model()

in src/autotrain/trainers/clm/utils.py [0:0]


def get_model(config, tokenizer):
    """
    Loads and configures a language model based on the provided configuration and tokenizer.

    Args:
        config (Namespace): Configuration object containing model parameters and settings.
            - model (str): The model name or path.
            - token (str): Token for accessing the model.
            - unsloth (bool): Flag to determine if unsloth is used.
            - trainer (str): Type of trainer to use.
            - target_modules (str): Target modules for unsloth.
            - peft (bool): Flag to determine if PEFT (Parameter-Efficient Fine-Tuning) is used.
            - quantization (str): Quantization type, either "int4" or "int8".
            - mixed_precision (str): Mixed precision type, either "fp16" or "bf16".
            - block_size (int): Maximum sequence length.
            - lora_r (int): LoRA rank.
            - lora_alpha (int): LoRA alpha.
            - lora_dropout (float): LoRA dropout rate.
            - seed (int): Random seed.
            - disable_gradient_checkpointing (bool): Flag to disable gradient checkpointing.
            - use_flash_attention_2 (bool): Flag to use flash attention 2.
        tokenizer (PreTrainedTokenizer): Tokenizer to use with the model.

    Returns:
        PreTrainedModel: The configured language model.

    Raises:
        ImportError: If unsloth is not available when required.
    """
    model_config = AutoConfig.from_pretrained(
        config.model,
        token=config.token,
        trust_remote_code=ALLOW_REMOTE_CODE,
    )
    model_type = model_config.model_type
    unsloth_target_modules = None
    can_use_unloth = False

    if config.unsloth and is_unsloth_available() and config.trainer in ("default", "sft"):
        can_use_unloth = True

    if model_type in ("llama", "mistral", "gemma", "qwen2") and config.unsloth:
        if config.target_modules.strip().lower() == "all-linear":
            unsloth_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        else:
            unsloth_target_modules = get_target_modules(config)
    else:
        can_use_unloth = False

    logger.info(f"Can use unsloth: {can_use_unloth}")
    if can_use_unloth:
        from unsloth import FastLanguageModel

        load_in_4bit = False
        load_in_8bit = False
        if config.peft and config.quantization == "int4":
            load_in_4bit = True
        elif config.peft and config.quantization == "int8":
            load_in_8bit = True

        dtype = None
        if config.mixed_precision == "fp16":
            dtype = torch.float16
        elif config.mixed_precision == "bf16":
            dtype = torch.bfloat16

        model, _ = FastLanguageModel.from_pretrained(
            model_name=config.model,
            token=config.token,
            trust_remote_code=ALLOW_REMOTE_CODE,
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            max_seq_length=config.block_size,
            dtype=dtype,
        )
        if config.peft:
            model = FastLanguageModel.get_peft_model(
                model,
                r=config.lora_r,
                target_modules=unsloth_target_modules,
                lora_alpha=config.lora_alpha,
                lora_dropout=config.lora_dropout,
                bias="none",
                use_gradient_checkpointing="unsloth",
                random_state=config.seed,
                max_seq_length=config.block_size,
                use_rslora=False,
                loftq_config=None,
            )
        return model
    else:
        logger.warning("Unsloth not available, continuing without it...")

    logger.info("loading model config...")
    model_config = AutoConfig.from_pretrained(
        config.model,
        token=config.token,
        trust_remote_code=ALLOW_REMOTE_CODE,
        use_cache=config.disable_gradient_checkpointing,
    )

    logger.info("loading model...")
    if config.peft:
        if config.quantization == "int4":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=False,
            )
        elif config.quantization == "int8":
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        else:
            bnb_config = None

        model = AutoModelForCausalLM.from_pretrained(
            config.model,
            config=model_config,
            token=config.token,
            quantization_config=bnb_config,
            trust_remote_code=ALLOW_REMOTE_CODE,
            use_flash_attention_2=config.use_flash_attention_2,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            config.model,
            config=model_config,
            token=config.token,
            trust_remote_code=ALLOW_REMOTE_CODE,
            use_flash_attention_2=config.use_flash_attention_2,
        )

    logger.info(f"model dtype: {model.dtype}")
    model.resize_token_embeddings(len(tokenizer))

    if config.trainer != "default":
        return model

    if config.peft:
        logger.info("preparing peft model...")
        if config.quantization is not None:
            gradient_checkpointing_kwargs = {}
            if not config.disable_gradient_checkpointing:
                if config.quantization in ("int4", "int8"):
                    gradient_checkpointing_kwargs = {"use_reentrant": True}
                else:
                    gradient_checkpointing_kwargs = {"use_reentrant": False}
            model = prepare_model_for_kbit_training(
                model,
                use_gradient_checkpointing=not config.disable_gradient_checkpointing,
                gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
            )
        else:
            model.enable_input_require_grads()

        peft_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=get_target_modules(config),
        )
        model = get_peft_model(model, peft_config)

    return model