in scripts/ft_gemma3n_audio_vt.py [0:0]
def main():
    """Fine-tune Gemma-3n (E2B-it) on a French audio dataset.

    Loads the pretrained model and processor, freezes most layers via
    `freeze_layers`, trains on 10k examples with gradient accumulation,
    runs periodic validation, and saves inference snapshots as PNGs.

    Requires a CUDA device; all heavy objects (model, datasets) are
    created inside this function.
    """
    model_id = "google/gemma-3n-E2B-it"
    processor = Gemma3nProcessor.from_pretrained(model_id)

    # Load and split the dataset (fixed seed for a reproducible split).
    ds_full = load_dataset("AdrienB134/Emilia-dataset-french-split", split="fr")
    split_ds = ds_full.train_test_split(test_size=0.1, seed=42)
    train_dataset = split_ds["train"].select(range(10000))
    val_dataset = split_ds["test"].select(range(100))

    # Create data loaders; collate_fn needs the processor bound in.
    partial_collate_fn = partial(collate_fn, processor=processor)
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=1,
        shuffle=True,
        num_workers=8,
        drop_last=True,
        collate_fn=partial_collate_fn,
        pin_memory=True,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=8,
        drop_last=True,
        collate_fn=partial_collate_fn,
    )

    # Load the model in bfloat16 and snapshot predictions before training.
    model = Gemma3nForConditionalGeneration.from_pretrained(model_id).to(
        "cuda", dtype=torch.bfloat16
    )
    run_inference(val_dataset, processor, model, "pred_before.png")
    model = freeze_layers(model)
    params_to_train = filter(lambda p: p.requires_grad, model.parameters())
    optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)

    # Start training.
    accumulation_steps = 8
    model.train()
    for idx, batch in tqdm(enumerate(train_dataloader)):
        # NOTE(review): BatchFeature.to(dtype=...) casts only floating-point
        # tensors in recent transformers versions — confirm input_ids stay int.
        outputs = model(**batch.to(model.device, dtype=torch.bfloat16))
        # Scale so the accumulated gradient averages over the micro-batches.
        loss = outputs.loss / accumulation_steps

        if idx % 100 == 0:
            # Periodic validation: switch to eval mode so dropout and other
            # train-only layers are disabled, then restore train mode after.
            model.eval()
            val_loss = 0.0
            count = 0
            with torch.no_grad():
                for val_batch in tqdm(val_dataloader, desc="Validation"):
                    val_loss = val_loss + model(
                        **val_batch.to(model.device, dtype=torch.bfloat16)
                    ).loss
                    count = count + 1
            # count > 0 here: val split has 100 examples with batch_size=1.
            val_loss = val_loss / count
            print(
                f"Iter: {idx} Loss: {loss.item():.4f} Val Loss: {val_loss.item():.4f}"
            )
            run_inference(val_dataset, processor, model, f"infer_{idx}.png")
            model.train()

        loss.backward()
        # Step once per full accumulation window. The original hard-coded 8
        # (ignoring `accumulation_steps`) and, via `idx % 8 == 0`, stepped at
        # idx == 0 after only ONE micro-batch; (idx + 1) accumulates a full
        # window before the first optimizer update.
        if (idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()