# scripts/knowledge_distillation_qwen2.yaml (70 lines of code) (raw):

# Knowledge-distillation config template (Qwen2 student with LoRA adapters,
# Qwen2 teacher).
#
# The `{{...}}` placeholders used below -- student_model_dir,
# teacher_model_dir, model_output_dir, train_path, log_dir, metric_logger --
# must be substituted before this file is loaded. Scalars that begin with a
# placeholder are single-quoted so that the unrendered template is valid YAML
# and the rendered paths are always read as strings.

# Model Arguments
model:
  _component_: torchtune.models.qwen2.lora_qwen2_0_5b
  lora_attn_modules: ['q_proj', 'k_proj', 'v_proj']
  apply_lora_to_mlp: False
  lora_rank: 32
  lora_alpha: 64

teacher_model:
  _component_: torchtune.models.qwen2.qwen2_1_5b

# Tokenizer vocabulary and merges are read from the student model directory.
tokenizer:
  _component_: torchtune.models.qwen2.qwen2_tokenizer
  path: '{{student_model_dir}}/vocab.json'
  merges_file: '{{student_model_dir}}/merges.txt'
  max_seq_len: null

# Checkpointer for `model`: reads from the student directory and writes to
# the model output directory.
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: '{{student_model_dir}}'
  checkpoint_files: [model.safetensors]
  recipe_checkpoint: null
  output_dir: '{{model_output_dir}}'
  model_type: QWEN2

# Checkpointer for `teacher_model`; its output_dir is the relative path
# `teacher-tmp`.
teacher_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: '{{teacher_model_dir}}'
  checkpoint_files: [model.safetensors]
  recipe_checkpoint: null
  output_dir: teacher-tmp
  model_type: QWEN2

resume_from_checkpoint: False

# Dataset and Sampler
dataset:
  _component_: torchtune.datasets.instruct_dataset
  source: json
  data_files: '{{train_path}}'
  # Maps the builder's `input`/`output` fields onto the `instruction` and
  # `output` columns of the JSON training file.
  column_map:
    input: instruction
    output: output
  train_on_input: False
  packed: False
  split: train

seed: null
shuffle: True

# Optimizer and Scheduler
optimizer:
  _component_: torch.optim.AdamW
  weight_decay: 0.01
  lr: 3e-4

lr_scheduler:
  _component_: torchtune.training.get_cosine_schedule_with_warmup
  num_warmup_steps: 100

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss

kd_loss:
  _component_: torchtune.modules.loss.ForwardKLWithChunkedOutputLoss

# NOTE(review): presumably the weighting between `loss` and `kd_loss` --
# confirm against the recipe that consumes this config.
kd_ratio: 0.5

# Training
epochs: 1
max_steps_per_epoch: null
batch_size: 4
gradient_accumulation_steps: 4
compile: False

# Logging
output_dir: '{{log_dir}}/kd_output'

# The logger class name is filled in from the `metric_logger` placeholder.
metric_logger:
  _component_: torchtune.training.metric_logging.{{metric_logger}}
  log_dir: '{{log_dir}}/training_logs'

log_every_n_steps: 1
log_peak_memory_stats: True

# Environment
device: cuda
dtype: bf16
enable_activation_checkpointing: False