# Excerpt: main() from scripts/ft_gemma3n_audio_vt.py


def main():
    """Fine-tune Gemma-3n (audio) on a French split of the Emilia dataset.

    Loads the processor and model, builds train/val dataloaders, freezes
    all but the trainable layers, then runs a simple gradient-accumulation
    training loop with periodic validation and sample inference dumps.
    """
    model_id = "google/gemma-3n-E2B-it"
    processor = Gemma3nProcessor.from_pretrained(model_id)

    # Load and split the dataset (fixed seed keeps the split reproducible).
    ds_full = load_dataset("AdrienB134/Emilia-dataset-french-split", split="fr")
    split_ds = ds_full.train_test_split(test_size=0.1, seed=42)
    train_dataset = split_ds["train"].select(range(10000))
    val_dataset = split_ds["test"].select(range(100))

    # Create data loaders; the collate_fn needs the processor bound in.
    partial_collate_fn = partial(collate_fn, processor=processor)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=8,
        drop_last=True,
        collate_fn=partial_collate_fn,
        pin_memory=True,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=8,
        drop_last=True,
        collate_fn=partial_collate_fn,
    )

    # Load the model on GPU in bfloat16.
    model = Gemma3nForConditionalGeneration.from_pretrained(model_id).to(
        "cuda", dtype=torch.bfloat16
    )

    # Baseline prediction before any fine-tuning, for comparison.
    run_inference(val_dataset, processor, model, "pred_before.png")

    model = freeze_layers(model)

    # Optimize only the parameters left trainable by freeze_layers.
    params_to_train = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)

    # Training loop with gradient accumulation.
    accumulation_steps = 8
    optimizer.zero_grad()
    for idx, batch in tqdm(enumerate(train_dataloader)):
        # NOTE(review): assumes batch.to(..., dtype=bfloat16) casts only
        # floating-point tensors and leaves integer ids intact — confirm
        # against the collate_fn's return type.
        outputs = model(**batch.to(model.device, dtype=torch.bfloat16))
        # Scale so the accumulated gradient averages over the window.
        loss = outputs.loss / accumulation_steps

        if idx % 100 == 0:
            # Periodic validation; eval() disables dropout and the like,
            # then restore train mode afterwards.
            model.eval()
            val_loss = 0.0
            count = 0
            with torch.no_grad():
                for val_batch in tqdm(val_dataloader, desc="Validation"):
                    val_loss = (
                        val_loss
                        + model(**val_batch.to(model.device, dtype=torch.bfloat16)).loss
                    )
                    count = count + 1
            model.train()
            val_loss = val_loss / max(count, 1)  # guard an empty loader
            # Report the true training loss, not the accumulation-scaled one.
            print(
                f"Iter: {idx} Loss: {outputs.loss.item():.4f} Val Loss: {val_loss.item():.4f}"
            )
            run_inference(val_dataset, processor, model, f"infer_{idx}.png")

        loss.backward()
        # BUG FIX: step once per `accumulation_steps` micro-batches. The
        # original hard-coded 8 and tested `idx % 8 == 0`, which stepped at
        # idx == 0 (after a single micro-batch) and mid-window thereafter.
        if (idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()