in scripts/ft_gemma3n_image_vt.py

import torch
from functools import partial

from datasets import load_dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import Gemma3nForConditionalGeneration, Gemma3nProcessor

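# The two helpers below are referenced by main() but defined elsewhere in the
# script. These are minimal sketches, assuming the dataset yields an "image"
# and a text "label" column and assuming a vision-only freezing strategy;
# they are illustrations, not the original implementations.
def collate_fn(examples, processor):
    # Build one chat-formatted conversation per example: the user turn
    # carries the image plus a prompt, the assistant turn carries the label.
    messages = [
        [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": example["image"]},
                    # Hypothetical prompt; the original script defines its own.
                    {"type": "text", "text": "Where do the two lines intersect?"},
                ],
            },
            {
                "role": "assistant",
                "content": [{"type": "text", "text": example["label"]}],
            },
        ]
        for example in examples
    ]
    batch = processor.apply_chat_template(
        messages,
        add_generation_prompt=False,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        padding=True,
    )
    # Standard causal-LM labels: copy input_ids and mask padding out of the
    # loss. The real collator may additionally mask the image placeholder
    # tokens and the prompt turns.
    labels = batch["input_ids"].clone()
    labels[batch["attention_mask"] == 0] = -100
    batch["labels"] = labels
    return batch


def freeze_layers(model):
    # Assumed strategy: freeze everything, then unfreeze the vision
    # components so only they are fine-tuned. The real script may pick a
    # different subset of layers.
    for param in model.parameters():
        param.requires_grad = False
    for name, param in model.named_parameters():
        if "vision" in name:
            param.requires_grad = True
    return model
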
def main():
    model_id = "google/gemma-3n-E2B-it"
    processor = Gemma3nProcessor.from_pretrained(model_id)

    # load the train and validation splits of the dataset
    dataset_id = "ariG23498/intersection-dataset"
    train_dataset = load_dataset(dataset_id, split="train")
    val_dataset = load_dataset(dataset_id, split="validation")
    # create the data loaders, binding the processor to the collator
    partial_collate_fn = partial(collate_fn, processor=processor)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=True,
        num_workers=8,
        drop_last=True,
        collate_fn=partial_collate_fn,
        pin_memory=True,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=2,
        shuffle=False,
        num_workers=8,
        drop_last=True,
        collate_fn=partial_collate_fn,
    )
    # load the model and record a baseline prediction before fine-tuning
    model = Gemma3nForConditionalGeneration.from_pretrained(model_id).to("cuda")
    run_inference(val_dataset, processor, model, "pred_before.png")

    # freeze most of the network; only the parameters left trainable by
    # freeze_layers are handed to the optimizer
    model = freeze_layers(model)
    params_to_train = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)
    # start training with gradient accumulation
    accumulation_steps = 8
    model.train()
    for idx, batch in enumerate(tqdm(train_dataloader)):
        outputs = model(**batch.to(model.device))
        # scale the loss so the gradients accumulated over
        # `accumulation_steps` micro-batches match one large batch
        loss = outputs.loss / accumulation_steps
        # periodically evaluate on the validation set and save a sample prediction
        if idx % 50 == 0:
            model.eval()
            val_loss = 0.0
            count = 0
            with torch.no_grad():
                for val_batch in val_dataloader:
                    val_loss += model(**val_batch.to(model.device)).loss.item()
                    count += 1
            val_loss = val_loss / count
            print(
                f"Iter: {idx} Loss: {outputs.loss.item():.4f} Val Loss: {val_loss:.4f}"
            )
            run_inference(val_dataset, processor, model, f"infer_{idx}.png")
            model.train()
        loss.backward()
        # step only after a full accumulation window; (idx + 1) avoids an
        # optimizer step right after the very first micro-batch
        if (idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
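

# A minimal sketch of run_inference, assuming it generates a prediction for
# one validation sample and saves the image alongside the printed prediction
# as a qualitative check. The prompt text and dataset fields are assumptions,
# and the real helper may instead draw the predicted point onto the image.
def run_inference(dataset, processor, model, file_name):
    sample = dataset[0]
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": sample["image"]},
                # Hypothetical prompt; the original script defines its own.
                {"type": "text", "text": "Where do the two lines intersect?"},
            ],
        }
    ]
    inputs = processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    with torch.inference_mode():
        generated = model.generate(**inputs, max_new_tokens=32, do_sample=False)
    # decode only the newly generated tokens
    prediction = processor.batch_decode(
        generated[:, inputs["input_ids"].shape[-1]:], skip_special_tokens=True
    )[0]
    print(f"{file_name}: {prediction}")
    sample["image"].save(file_name)


if __name__ == "__main__":
    main()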