notebooks/text-generation/scripts/sft_finetuning_qwen3.py (92 lines of code) (raw):

from dataclasses import dataclass, field

import torch
import torch_xla.core.xla_model as xm
from datasets import load_dataset
from peft import LoraConfig
from transformers import (
    AutoTokenizer,
    set_seed,
)

from optimum.neuron import NeuronHfArgumentParser as HfArgumentParser
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer, NeuronTrainingArguments
from optimum.neuron.models.training import Qwen3ForCausalLM


def get_dataset(tokenizer):
    """Build the chat-style training dataset from the simple recipes corpus.

    Loads the ``train`` split of ``tengomucho/simple_recipes``, flattens it and
    turns every (name, recipe) pair into a two-turn conversation: a user
    question asking how to make the dish and an assistant answer holding the
    recipe followed by the tokenizer's EOS token.

    Args:
        tokenizer: Tokenizer whose ``eos_token`` is appended to each recipe.

    Returns:
        The mapped dataset. All original columns are removed, so it only
        carries the ``"messages"`` column produced here.
    """
    dataset_id = "tengomucho/simple_recipes"
    recipes = load_dataset(dataset_id, split="train").flatten()

    def preprocess_function(examples):
        """Convert one batch of recipes into a batch of chat conversations."""
        chats = [
            [
                {
                    "role": "user",
                    "content": f"How can I make {name}?",
                },
                # Append the EOS token to the response
                {"role": "assistant", "content": recipe + tokenizer.eos_token},
            ]
            for recipe, name in zip(examples["recipes"], examples["names"])
        ]
        # The chat template is intentionally NOT applied here: the trainer's
        # formatting function renders the messages, so rendering them in this
        # step as well would only compute a string that is thrown away.
        return {"messages": chats}

    return recipes.map(
        preprocess_function,
        batched=True,
        remove_columns=recipes.column_names,
    )


def training_function(script_args, training_args):
    """Fine-tune a Qwen3 causal LM with LoRA using the Neuron SFT trainer.

    Args:
        script_args: Parsed ``ScriptArguments``; provides ``model_id`` and
            ``use_flash_attention_2``.
        training_args: Parsed ``NeuronTrainingArguments``; provides ``bf16``,
            ``trn_config`` and ``output_dir``, and is expanded into the SFT
            configuration.
    """
    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id)
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    dataset = get_dataset(tokenizer)

    dtype = torch.bfloat16 if training_args.bf16 else torch.float32
    model = Qwen3ForCausalLM.from_pretrained(
        script_args.model_id,
        training_args.trn_config,
        torch_dtype=dtype,
        use_flash_attention_2=script_args.use_flash_attention_2,
    )

    config = LoraConfig(
        r=64,
        lora_alpha=128,
        lora_dropout=0.05,
        target_modules=[
            "embed_tokens",
            "q_proj",
            "v_proj",
            "o_proj",
            "k_proj",
            "up_proj",
            "down_proj",
            "gate_proj",
        ],
        bias="none",
        task_type="CAUSAL_LM",
    )

    args = training_args.to_dict()

    packing = True
    # Note: max_seq_length must be a multiple of 2048 to use the flash attention 2 kernel
    sft_config = NeuronSFTConfig(
        max_seq_length=8192,
        packing=packing,
        **args,
    )

    def formatting_function(examples):
        """Render the ``"messages"`` conversations to text with the chat template."""
        return tokenizer.apply_chat_template(
            examples["messages"],
            tokenize=False,
            add_generation_prompt=False,
        )

    trainer = NeuronSFTTrainer(
        args=sft_config,
        model=model,
        peft_config=config,
        tokenizer=tokenizer,
        train_dataset=dataset,
        formatting_func=formatting_function,
    )

    # Start training
    train_result = trainer.train()
    trainer.save_model()  # Saves the tokenizer too for easy upload

    metrics = train_result.metrics
    xm.master_print(f"Model trained in {training_args.output_dir}")
    xm.master_print(metrics)


@dataclass
class ScriptArguments:
    """Command-line options specific to this script (parsed alongside the training arguments)."""

    # Hugging Face hub identifier of the model to fine-tune.
    model_id: str = field(
        default="Qwen/Qwen3-8B",
        metadata={"help": "The model that you want to train from the Hugging Face hub."},
    )
    # Forwarded to the model's ``from_pretrained`` call.
    use_flash_attention_2: bool = field(
        default=True,
        metadata={"help": "Whether to use Flash Attention 2."},
    )


def main():
    """Parse the command line, seed the RNGs and run the fine-tuning."""
    parser = HfArgumentParser([ScriptArguments, NeuronTrainingArguments])
    script_args, training_args = parser.parse_args_into_dataclasses()

    # set seed
    set_seed(training_args.seed)

    # run training function
    training_function(script_args, training_args)


if __name__ == "__main__":
    main()