notebooks/text-generation/scripts/sft_finetuning_qwen3.py
import torch
import torch_xla.core.xla_model as xm
from peft import LoraConfig
from transformers import AutoTokenizer

from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer
# NB: the training modeling classes live under optimum.neuron.models.training in
# recent optimum-neuron releases; adjust this import to match your installed version.
from optimum.neuron.models.training import Qwen3ForCausalLM


def training_function(script_args, training_args):
    # Qwen3 ships without a pad token, so reuse the EOS token for padding.
    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id)
    tokenizer.pad_token = tokenizer.eos_token

    # get_dataset is a helper defined elsewhere in the script (not shown here).
    dataset = get_dataset(tokenizer)
    # Train in bf16 when requested, otherwise fall back to fp32.
    dtype = torch.bfloat16 if training_args.bf16 else torch.float32
    model = Qwen3ForCausalLM.from_pretrained(
        script_args.model_id,
        training_args.trn_config,  # Neuron distributed-training config (TP/PP sharding)
        torch_dtype=dtype,
        use_flash_attention_2=script_args.use_flash_attention_2,
    )
    # Apply LoRA adapters to the embeddings plus every attention and MLP projection.
    config = LoraConfig(
        r=64,
        lora_alpha=128,
        lora_dropout=0.05,
        target_modules=["embed_tokens", "q_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj", "gate_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )
    args = training_args.to_dict()
    # Pack several short examples into each sequence to reduce padding waste.
    packing = True
    # Note: max_seq_length must be a multiple of 2048 to use the flash attention 2 kernel.
    sft_config = NeuronSFTConfig(
        max_seq_length=8192,
        packing=packing,
        **args,
    )
    def formatting_function(examples):
        # Render each conversation with the model's chat template; the trainer
        # handles tokenization afterwards since tokenize=False here.
        return tokenizer.apply_chat_template(examples["messages"], tokenize=False, add_generation_prompt=False)
    trainer = NeuronSFTTrainer(
        args=sft_config,
        model=model,
        peft_config=config,
        tokenizer=tokenizer,
        train_dataset=dataset,
        formatting_func=formatting_function,
    )
    # Start training
    train_result = trainer.train()
    trainer.save_model()  # Saves the tokenizer too for easy upload

    metrics = train_result.metrics
    # master_print logs from the main worker only, avoiding duplicate output
    # across Neuron data-parallel ranks.
    xm.master_print(f"Model trained and saved to {training_args.output_dir}")
    xm.master_print(metrics)
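

# ---------------------------------------------------------------------------
# Usage sketch (an assumption, not part of the original file): how a script
# like this is typically wired up. ScriptArguments mirrors the two attributes
# accessed above (model_id, use_flash_attention_2); the default model id and
# the dataset in the get_dataset sketch below are illustrative choices only.
# ---------------------------------------------------------------------------
from dataclasses import dataclass, field

from datasets import load_dataset
from transformers import HfArgumentParser

from optimum.neuron import NeuronTrainingArguments


def get_dataset(tokenizer):
    # Hypothetical implementation: formatting_function above reads
    # examples["messages"], so any chat-schema SFT dataset fits. The tokenizer
    # argument is unused in this sketch but could drive length filtering.
    return load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft")


@dataclass
class ScriptArguments:
    model_id: str = field(default="Qwen/Qwen3-8B")
    use_flash_attention_2: bool = field(default=True)


if __name__ == "__main__":
    # NeuronTrainingArguments extends TrainingArguments with Trainium options
    # (e.g. tensor parallelism) and exposes the trn_config consumed above.
    parser = HfArgumentParser((ScriptArguments, NeuronTrainingArguments))
    script_args, training_args = parser.parse_args_into_dataclasses()
    training_function(script_args, training_args)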