def training_function()

in notebooks/text-generation/scripts/sft_finetuning_qwen3.py


# Imports assumed from the surrounding script.
import torch
import torch_xla.core.xla_model as xm
from peft import LoraConfig
from transformers import AutoTokenizer

from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer
from optimum.neuron.models.training import Qwen3ForCausalLM


def training_function(script_args, training_args):
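    # Load the tokenizer and reuse the EOS token as the padding token.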
    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id)
    tokenizer.pad_token = tokenizer.eos_token

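    # Build the chat-formatted training dataset (get_dataset is a helper
    # defined elsewhere in this script).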
    dataset = get_dataset(tokenizer)

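    # Select bf16 when enabled, then load the model via optimum-neuron's
    # training class; trn_config carries the Neuron distributed-training
    # settings (e.g. tensor parallelism).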
    dtype = torch.bfloat16 if training_args.bf16 else torch.float32
    model = Qwen3ForCausalLM.from_pretrained(
        script_args.model_id,
        training_args.trn_config,
        torch_dtype=dtype,
        use_flash_attention_2=script_args.use_flash_attention_2,
    )

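    # LoRA adapter config: rank-64 adapters with dropout on the embedding,
    # attention, and MLP projection layers.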
    config = LoraConfig(
        r=64,
        lora_alpha=128,
        lora_dropout=0.05,
        target_modules=["embed_tokens", "q_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj", "gate_proj"],
        bias="none",
        task_type="CAUSAL_LM",
    )

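    # Fold the base training arguments into an SFT config; packing concatenates
    # short examples into full-length sequences of max_seq_length tokens.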
    args = training_args.to_dict()
    packing = True
    # Note: max_seq_length must be a multiple of 2048 to use the flash attention 2 kernel
    sft_config = NeuronSFTConfig(
        max_seq_length=8192,
        packing=packing,
        **args,
    )

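    # Render each example's chat messages into a single plain-text string using
    # the model's chat template (tokenization happens inside the trainer).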
    def formatting_function(examples):
        return tokenizer.apply_chat_template(examples["messages"], tokenize=False, add_generation_prompt=False)

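    # The Neuron SFT trainer tokenizes the formatted examples and applies the
    # LoRA adapters described by peft_config before training.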
    trainer = NeuronSFTTrainer(
        args=sft_config,
        model=model,
        peft_config=config,
        tokenizer=tokenizer,
        train_dataset=dataset,
        formatting_func=formatting_function,
    )

    # Start training
    train_result = trainer.train()

    trainer.save_model()  # Saves the tokenizer too for easy upload
    metrics = train_result.metrics
    xm.master_print(f"Model trained in {training_args.output_dir}")
    xm.master_print(metrics)
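

For reference, a minimal sketch of how a function like this is typically wired
up as the script's entry point. The ScriptArguments dataclass below is an
illustrative stand-in for the one defined elsewhere in the file (its fields
mirror the attributes used above); HfArgumentParser and NeuronTrainingArguments
are the standard transformers / optimum-neuron utilities.

from dataclasses import dataclass, field

from transformers import HfArgumentParser

from optimum.neuron import NeuronTrainingArguments


@dataclass
class ScriptArguments:
    # Hypothetical defaults, for illustration only.
    model_id: str = field(default="Qwen/Qwen3-8B")
    use_flash_attention_2: bool = field(default=True)


if __name__ == "__main__":
    parser = HfArgumentParser([ScriptArguments, NeuronTrainingArguments])
    script_args, training_args = parser.parse_args_into_dataclasses()
    training_function(script_args, training_args)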