# src/autotrain/trainers/clm/utils.py
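# Module-level dependencies this function relies on. The third-party imports
# below are standard; the autotrain-internal names (ALLOW_REMOTE_CODE, logger,
# is_unsloth_available, get_target_modules) are assumed to be defined or
# imported elsewhere in this module, so their exact import paths are not
# reproduced here.
import torch
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoConfig, AutoModelForCausalLM, BitsAndBytesConfig
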
def get_model(config, tokenizer):
"""
Loads and configures a language model based on the provided configuration and tokenizer.
Args:
config (Namespace): Configuration object containing model parameters and settings.
- model (str): The model name or path.
- token (str): Token for accessing the model.
- unsloth (bool): Flag to determine if unsloth is used.
- trainer (str): Type of trainer to use.
- target_modules (str): Target modules for unsloth.
- peft (bool): Flag to determine if PEFT (Parameter-Efficient Fine-Tuning) is used.
- quantization (str): Quantization type, either "int4" or "int8".
- mixed_precision (str): Mixed precision type, either "fp16" or "bf16".
- block_size (int): Maximum sequence length.
- lora_r (int): LoRA rank.
- lora_alpha (int): LoRA alpha.
- lora_dropout (float): LoRA dropout rate.
- seed (int): Random seed.
- disable_gradient_checkpointing (bool): Flag to disable gradient checkpointing.
- use_flash_attention_2 (bool): Flag to use flash attention 2.
tokenizer (PreTrainedTokenizer): Tokenizer to use with the model.
Returns:
PreTrainedModel: The configured language model.
Raises:
ImportError: If unsloth is not available when required.
"""
    model_config = AutoConfig.from_pretrained(
        config.model,
        token=config.token,
        trust_remote_code=ALLOW_REMOTE_CODE,
    )
    model_type = model_config.model_type

    unsloth_target_modules = None
    can_use_unsloth = False
    if config.unsloth and is_unsloth_available() and config.trainer in ("default", "sft"):
        can_use_unsloth = True
    if model_type in ("llama", "mistral", "gemma", "qwen2") and config.unsloth:
        if config.target_modules.strip().lower() == "all-linear":
            unsloth_target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]
        else:
            unsloth_target_modules = get_target_modules(config)
    else:
        can_use_unsloth = False
    logger.info(f"Can use unsloth: {can_use_unsloth}")
    if can_use_unsloth:
        from unsloth import FastLanguageModel

        load_in_4bit = False
        load_in_8bit = False
        if config.peft and config.quantization == "int4":
            load_in_4bit = True
        elif config.peft and config.quantization == "int8":
            load_in_8bit = True

        dtype = None
        if config.mixed_precision == "fp16":
            dtype = torch.float16
        elif config.mixed_precision == "bf16":
            dtype = torch.bfloat16

        model, _ = FastLanguageModel.from_pretrained(
            model_name=config.model,
            token=config.token,
            trust_remote_code=ALLOW_REMOTE_CODE,
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            max_seq_length=config.block_size,
            dtype=dtype,
        )
        if config.peft:
            model = FastLanguageModel.get_peft_model(
                model,
                r=config.lora_r,
                target_modules=unsloth_target_modules,
                lora_alpha=config.lora_alpha,
                lora_dropout=config.lora_dropout,
                bias="none",
                use_gradient_checkpointing="unsloth",
                random_state=config.seed,
                max_seq_length=config.block_size,
                use_rslora=False,
                loftq_config=None,
            )
        return model
    else:
        logger.warning("Unsloth not available, continuing without it...")
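    # Standard transformers path. Note that use_cache is tied to gradient
    # checkpointing: the KV cache is only enabled when checkpointing is disabled.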
logger.info("loading model config...")
model_config = AutoConfig.from_pretrained(
config.model,
token=config.token,
trust_remote_code=ALLOW_REMOTE_CODE,
use_cache=config.disable_gradient_checkpointing,
)
logger.info("loading model...")
    if config.peft:
        if config.quantization == "int4":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=False,
            )
        elif config.quantization == "int8":
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        else:
            bnb_config = None

        model = AutoModelForCausalLM.from_pretrained(
            config.model,
            config=model_config,
            token=config.token,
            quantization_config=bnb_config,
            trust_remote_code=ALLOW_REMOTE_CODE,
            use_flash_attention_2=config.use_flash_attention_2,
        )
    else:
        model = AutoModelForCausalLM.from_pretrained(
            config.model,
            config=model_config,
            token=config.token,
            trust_remote_code=ALLOW_REMOTE_CODE,
            use_flash_attention_2=config.use_flash_attention_2,
        )

    logger.info(f"model dtype: {model.dtype}")
    model.resize_token_embeddings(len(tokenizer))
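    # Only the "default" trainer wraps the model with a PEFT adapter here;
    # other trainers receive the (possibly quantized) base model as-is.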
    if config.trainer != "default":
        return model
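    # Default trainer + PEFT: prepare quantized models for k-bit training,
    # then wrap the model with a LoRA adapter.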
    if config.peft:
        logger.info("preparing peft model...")
        if config.quantization is not None:
            gradient_checkpointing_kwargs = {}
            if not config.disable_gradient_checkpointing:
                if config.quantization in ("int4", "int8"):
                    gradient_checkpointing_kwargs = {"use_reentrant": True}
                else:
                    gradient_checkpointing_kwargs = {"use_reentrant": False}
            model = prepare_model_for_kbit_training(
                model,
                use_gradient_checkpointing=not config.disable_gradient_checkpointing,
                gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
            )
        else:
            model.enable_input_require_grads()

        peft_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=get_target_modules(config),
        )
        model = get_peft_model(model, peft_config)

    return model
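

# Usage sketch (illustrative; not part of the original module). It shows how
# get_model might be called with a minimal config. The Namespace fields mirror
# the docstring above; the model name and hyperparameter values are assumptions
# chosen for demonstration only.
#
#   from argparse import Namespace
#   from transformers import AutoTokenizer
#
#   cfg = Namespace(
#       model="meta-llama/Llama-2-7b-hf",
#       token=None,
#       unsloth=False,
#       trainer="default",
#       target_modules="all-linear",
#       peft=True,
#       quantization="int4",
#       mixed_precision="bf16",
#       block_size=1024,
#       lora_r=16,
#       lora_alpha=32,
#       lora_dropout=0.05,
#       seed=42,
#       disable_gradient_checkpointing=False,
#       use_flash_attention_2=False,
#   )
#   tokenizer = AutoTokenizer.from_pretrained(cfg.model, token=cfg.token)
#   model = get_model(cfg, tokenizer)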