# def main() — from scripts/ft_gemma3n_image_vt.py

def main():
    """Fine-tune Gemma-3n (E2B-it) on the intersection dataset.

    Loads the processor/model, builds train/val dataloaders, freezes most
    layers, then trains with gradient accumulation, logging validation loss
    and saving an inference snapshot every 50 iterations.
    """
    model_id = "google/gemma-3n-E2B-it"
    processor = Gemma3nProcessor.from_pretrained(model_id)

    # load the dataset
    dataset_id = "ariG23498/intersection-dataset"
    train_dataset = load_dataset(dataset_id, split="train")
    val_dataset = load_dataset(dataset_id, split="validation")

    # create data loaders; collate_fn needs the processor to build model inputs
    partial_collate_fn = partial(collate_fn, processor=processor)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=2,
        shuffle=True,
        num_workers=8,
        drop_last=True,
        collate_fn=partial_collate_fn,
        pin_memory=True,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=2,
        shuffle=False,
        num_workers=8,
        drop_last=True,
        collate_fn=partial_collate_fn,
    )

    # load the model and optimizer
    model = Gemma3nForConditionalGeneration.from_pretrained(model_id).to("cuda")

    # baseline prediction before any fine-tuning
    run_inference(val_dataset, processor, model, "pred_before.png")

    model = freeze_layers(model)

    # only the unfrozen parameters go to the optimizer
    params_to_train = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)

    # Start Training
    accumulation_steps = 8
    optimizer.zero_grad()
    for idx, batch in tqdm(
        enumerate(train_dataloader), total=len(train_dataloader)
    ):
        outputs = model(**batch.to(model.device))
        # scale so accumulated gradients average over the accumulation window
        loss = outputs.loss / accumulation_steps
        if idx % 50 == 0:
            # periodic validation; eval() disables dropout etc. for a fair loss
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                count = 0
                for val_batch in val_dataloader:
                    val_loss = val_loss + model(**val_batch.to(model.device)).loss
                    count = count + 1
                val_loss = val_loss / count
            print(
                f"Iter: {idx} Loss: {loss.item():.4f} Val Loss: {val_loss.item():.4f}"
            )
            run_inference(val_dataset, processor, model, f"infer_{idx}.png")
            model.train()

        loss.backward()
        # step only after a FULL accumulation window; the original `idx % 8 == 0`
        # stepped at idx=0 after a single micro-batch and hard-coded the 8
        if (idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()