in src/autotrain/trainers/vlm/utils.py
def get_model(config):
    logger.info("loading model config...")
    model_config = AutoConfig.from_pretrained(
        config.model,
        token=config.token,
        trust_remote_code=ALLOW_REMOTE_CODE,
        # KV caching conflicts with gradient checkpointing, so the cache is only
        # enabled when gradient checkpointing is disabled.
        use_cache=config.disable_gradient_checkpointing,
    )

    logger.info("loading model...")
    if config.peft:
        # Build the bitsandbytes quantization config for k-bit (QLoRA-style) training.
        if config.quantization == "int4":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.float16,
                bnb_4bit_use_double_quant=False,
            )
        elif config.quantization == "int8":
            bnb_config = BitsAndBytesConfig(load_in_8bit=True)
        else:
            bnb_config = None

        model = PaliGemmaForConditionalGeneration.from_pretrained(
            config.model,
            config=model_config,
            token=config.token,
            quantization_config=bnb_config,
            trust_remote_code=ALLOW_REMOTE_CODE,
        )
    else:
        model = PaliGemmaForConditionalGeneration.from_pretrained(
            config.model,
            config=model_config,
            token=config.token,
            trust_remote_code=ALLOW_REMOTE_CODE,
        )

    logger.info(f"model dtype: {model.dtype}")

    if config.peft:
        logger.info("preparing peft model...")
        if config.quantization is not None:
            # Prepare the quantized model for k-bit training (upcasts select layers
            # and optionally enables gradient checkpointing).
            gradient_checkpointing_kwargs = {}
            if not config.disable_gradient_checkpointing:
                if config.quantization in ("int4", "int8"):
                    gradient_checkpointing_kwargs = {"use_reentrant": True}
                else:
                    gradient_checkpointing_kwargs = {"use_reentrant": False}
            model = prepare_model_for_kbit_training(
                model,
                use_gradient_checkpointing=not config.disable_gradient_checkpointing,
                gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
            )
        else:
            # Ensure embedding outputs require grad so gradients can flow to the
            # LoRA adapters while the base weights stay frozen.
            model.enable_input_require_grads()

        # Attach LoRA adapters to the selected target modules.
        peft_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_alpha,
            lora_dropout=config.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=get_target_modules(config),
        )
        model = get_peft_model(model, peft_config)

    # Keep the vision encoder and multimodal projector frozen; only the language
    # model (or its LoRA adapters) is trained.
    for param in model.vision_tower.parameters():
        param.requires_grad = False

    for param in model.multi_modal_projector.parameters():
        param.requires_grad = False

    return model
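
For reference, a minimal sketch of how this helper might be invoked. The config fields shown mirror the attributes the function reads above; the concrete values (checkpoint name, LoRA hyperparameters, target_modules) are illustrative assumptions, not autotrain defaults, and get_target_modules may consume additional config fields not listed here.

# Hypothetical usage sketch; values below are assumptions for illustration only.
from types import SimpleNamespace

example_config = SimpleNamespace(
    model="google/paligemma-3b-pt-224",   # e.g. a PaliGemma checkpoint
    token=None,                           # HF token if the checkpoint is gated
    peft=True,                            # train LoRA adapters instead of full fine-tuning
    quantization="int4",                  # "int4", "int8", or None
    disable_gradient_checkpointing=False,
    lora_r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules="all-linear",          # assumed field read by get_target_modules
)

model = get_model(example_config)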