# scripts/ft_gemma3n_image_vt.py

"""Fine-tune Gemma 3n (E2B) on an image question-answering task.

Trains only the attention parameters on the intersection-counting dataset,
periodically reporting validation loss and saving qualitative predictions
as images under ``outputs_fine_tune/``.
"""

import os

# Quiet HF logging and tokenizer fork warnings BEFORE importing transformers.
os.environ["TRANSFORMERS_VERBOSITY"] = "error"
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import random
from functools import partial

import torch
from datasets import load_dataset
from matplotlib import pyplot as plt
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import Gemma3nForConditionalGeneration, Gemma3nProcessor


def collate_fn(examples, processor):
    """Build a tokenized chat batch from (image, label) dataset samples.

    Args:
        examples: Dataset samples, each with a PIL ``image`` and a ``label``.
        processor: Gemma3n processor used to render/tokenize the chat template.

    Returns:
        Batch dict (``input_ids``, ``pixel_values``, ...) plus ``labels`` in
        which pad / image special tokens are masked with -100.
    """
    messages = []
    for sample in examples:
        image = sample["image"].convert("RGB")
        label = str(sample["label"])
        message = [
            {
                "role": "system",
                "content": [
                    {
                        "type": "text",
                        "text": "You are an assistant with great geometry skills.",
                    }
                ],
            },
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {
                        "type": "text",
                        "text": "How many intersection points are there in the image?",
                    },
                ],
            },
            {"role": "assistant", "content": [{"type": "text", "text": label}]},
        ]
        messages.append(message)

    batch = processor.apply_chat_template(
        messages,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    )

    labels = batch["input_ids"].clone()  # Clone input IDs for labels
    # Mask the tokens that we do not want to include in the loss computation.
    # -100 is ignored during categorical cross entropy loss computation.
    labels[labels == processor.tokenizer.pad_token_id] = -100
    labels[labels == processor.tokenizer.image_token_id] = -100
    labels[labels == processor.tokenizer.boi_token_id] = -100
    labels[labels == processor.tokenizer.eoi_token_id] = -100
    batch["labels"] = labels
    return batch


def freeze_layers(model):
    """Freeze every parameter except the attention weights.

    Only parameters whose name contains "attn" remain trainable.
    """
    for name, param in model.named_parameters():
        param.requires_grad = "attn" in name
    return model


def run_inference(val_dataset, processor, model, fname):
    """Generate a prediction for one random validation image and save a plot.

    Args:
        val_dataset: Validation split with ``image`` samples.
        processor: Gemma3n processor for templating/decoding.
        model: The (possibly partially trained) model.
        fname: File name for the output figure in ``outputs_fine_tune/``.
    """
    val_sample = random.choice(val_dataset)
    image = val_sample["image"].convert("RGB")
    message = [
        {
            "role": "system",
            "content": [
                {
                    "type": "text",
                    "text": "You are an assistant with great geometry skills.",
                }
            ],
        },
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {
                    "type": "text",
                    "text": "How many intersection points are there in the image?",
                },
            ],
        },
    ]
    inputs = processor.apply_chat_template(
        message,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype=torch.bfloat16)
    input_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        generation = model.generate(**inputs, max_new_tokens=10, disable_compile=True)
    # Strip the prompt tokens; decode only the newly generated answer.
    generation = generation[0][input_len:]
    decoded = processor.decode(generation, skip_special_tokens=True)

    # BUG FIX: the output directory may not exist on the first call.
    os.makedirs("outputs_fine_tune", exist_ok=True)
    plt.imshow(image)
    plt.axis("off")
    plt.title(f"Pred: {decoded}")
    # BUG FIX: save BEFORE show() — interactive backends clear the figure on
    # show(), which previously produced blank saved images.
    plt.savefig(f"outputs_fine_tune/{fname}")
    plt.show()


def main():
    """Entry point: load data/model, run baseline inference, then fine-tune."""
    model_id = "google/gemma-3n-E2B-it"
    processor = Gemma3nProcessor.from_pretrained(model_id)

    # load the dataset
    dataset_id = "ariG23498/intersection-dataset"
    train_dataset = load_dataset(dataset_id, split="train")
    val_dataset = load_dataset(dataset_id, split="validation")

    # create data loaders
    partial_collate_fn = partial(collate_fn, processor=processor)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=True,
        num_workers=8,
        drop_last=True,
        collate_fn=partial_collate_fn,
        pin_memory=True,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=2,
        shuffle=False,
        num_workers=8,
        drop_last=True,
        collate_fn=partial_collate_fn,
    )

    # load the model and optimizer
    # BUG FIX: run_inference casts its inputs to bfloat16, so the model must
    # be loaded in bfloat16 too (it previously defaulted to float32, which
    # raises a dtype-mismatch error at inference time).
    model = Gemma3nForConditionalGeneration.from_pretrained(
        model_id, torch_dtype=torch.bfloat16
    ).to("cuda")
    run_inference(val_dataset, processor, model, "pred_before.png")

    model = freeze_layers(model)
    params_to_train = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)

    # Start Training
    accumulation_steps = 8
    # BUG FIX: from_pretrained returns the model in eval mode; switch to
    # train mode explicitly before optimizing.
    model.train()
    optimizer.zero_grad()
    for idx, batch in tqdm(enumerate(train_dataloader)):
        # Cast floating tensors (pixel_values) to match the bfloat16 model.
        outputs = model(**batch.to(model.device, dtype=torch.bfloat16))
        loss = outputs.loss / accumulation_steps

        if idx % 50 == 0:
            # Periodic validation pass plus one qualitative sample.
            model.eval()
            val_loss = 0.0
            count = 0
            with torch.no_grad():
                for val_batch in val_dataloader:
                    val_loss = val_loss + model(
                        **val_batch.to(model.device, dtype=torch.bfloat16)
                    ).loss
                    count = count + 1
            val_loss = val_loss / count
            print(
                f"Iter: {idx} Loss: {loss.item():.4f} Val Loss: {val_loss.item():.4f}"
            )
            run_inference(val_dataset, processor, model, f"infer_{idx}.png")
            model.train()

        loss.backward()
        # BUG FIX: use the accumulation_steps constant (was a hard-coded 8)
        # and step after a FULL window — the old `idx % 8 == 0` stepped at
        # idx == 0 with only one micro-batch accumulated.
        if (idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()


if __name__ == "__main__":
    main()