use-cases/model-fine-tuning-pipeline/fine-tuning/pytorch/src/fine_tune.py (261 lines of code) (raw):

# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import datasets import logging import logging.config import glob import os import signal import sys import torch import transformers import random from accelerate import Accelerator from datasets import Dataset, load_dataset, load_from_disk from peft import LoraConfig, PeftModel from transformers import AutoModelForCausalLM, AutoTokenizer from trl import DataCollatorForCompletionOnlyLM, SFTConfig, SFTTrainer import torch.distributed as dist def graceful_shutdown(signal_number, stack_frame): signal_name = signal.Signals(signal_number).name logger.info(f"Received {signal_name}({signal_number}), shutting down...") # TODO: Add logic to handled checkpointing if required sys.exit(0) def formatting_prompts_func(example): output_texts = [] for i in range(len(example["prompt"])): text = f"""{example["prompt"][i]}\n{EOS_TOKEN}""" output_texts.append(text) return {"prompts": output_texts} def get_current_node_id_and_rank(): if dist.is_initialized(): logger.info("Distributed training enabled.") logger.info("Calculating node id") global_rank = dist.get_rank() # Get the process's rank logger.info(f"global_rank: {global_rank}") gpu_per_node = torch.cuda.device_count() logger.info(f"gpu_per_node: {gpu_per_node}") total_gpus = accelerator.state.num_processes logger.info(f"total_gpus: {total_gpus}") total_nodes = int(total_gpus / gpu_per_node) logger.info(f"total_nodes: {total_nodes}") node_id = global_rank // gpu_per_node else: logger.info("Distributed training enabled.") node_id = 0 global_rank = 0 logger.info(f"node_id: {node_id}") return (node_id, global_rank, gpu_per_node) if __name__ == "__main__": # Configure logging logging.config.fileConfig("logging.conf") logger = logging.getLogger("finetune") if "LOG_LEVEL" in os.environ: new_log_level = os.environ["LOG_LEVEL"].upper() logger.info( f"Log level set to '{new_log_level}' via LOG_LEVEL environment variable" ) logging.getLogger().setLevel(new_log_level) logger.setLevel(new_log_level) datasets.disable_progress_bar() transformers.utils.logging.disable_progress_bar() logger.info("Configure signal handlers") signal.signal(signal.SIGINT, graceful_shutdown) signal.signal(signal.SIGTERM, graceful_shutdown) accelerator = Accelerator() if "MLFLOW_ENABLE" in os.environ and os.environ["MLFLOW_ENABLE"] == "true": import mlflow remote_server_uri = os.environ["MLFLOW_TRACKING_URI"] mlflow.set_tracking_uri(remote_server_uri) experiment_name = os.environ["EXPERIMENT"] if accelerator.is_main_process: # Only the main process sets the experiment # Check if the experiment already exists experiment = mlflow.get_experiment_by_name(experiment_name) if experiment is None: # Experiment doesn't exist, create it try: experiment_id = mlflow.create_experiment(experiment_name) experiment = mlflow.get_experiment(experiment_id) print( f"Created new experiment: {experiment.name} (ID: {experiment.experiment_id})" ) except Exception as ex: logger.error(f"Create experiment failed: {ex}") else: # Experiment already exists, use it logger.info( f"Using existing experiment: {experiment.name} (ID: {experiment.experiment_id})" ) # Barrier to ensure all processes wait until the experiment is created accelerator.wait_for_everyone() experiment = mlflow.get_experiment_by_name(experiment_name) # Get the node ID node, rank, gpu_per_node = get_current_node_id_and_rank() node_id = "node_" + str(node) logger.info(f"Training at: {node_id} process: {rank}") process_id = "process_" + str(rank) # Set MLflow experiment and node ID (within the appropriate run context) mlflow.set_experiment(experiment_name) mlflow.set_system_metrics_node_id(node_id) # Check and create/reuse runs # Only one process per node creates a run to capture system metrics if (rank % gpu_per_node) == 0: existing_run = mlflow.search_runs( experiment_ids=[experiment.experiment_id], filter_string=f"tags.mlflow.runName = '{process_id}'", ) # No existing run with this name if len(existing_run) == 0: run = mlflow.start_run( run_name=node_id, experiment_id=experiment.experiment_id, ) has_active_ml_flow_run = True else: logger.info(f"Run with name '{process_id}' already exists") client = mlflow.MlflowClient() data = client.get_run(mlflow.active_run().info.run_id).data logger.info(f"Active run details: '{data}'") # logging of model parameters, metrics, and artifacts mlflow.autolog() # The bucket which contains the training data training_data_bucket = os.environ["TRAINING_DATASET_BUCKET"] training_data_path = os.environ["TRAINING_DATASET_PATH"] # The model that you want to train from the Hugging Face hub model_name = os.environ["MODEL_NAME"] # Fine-tuned model name new_model = os.environ["NEW_MODEL"] # The root path of where the fine-tuned model will be saved save_model_path = os.environ["MODEL_PATH"] # Load tokenizer tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True) tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "right" # Fix weird overflow issue with fp16 training EOS_TOKEN = tokenizer.eos_token training_dataset = load_from_disk( f"gs://{training_data_bucket}/{training_data_path}" ) logger.info("Data Formatting Started") input_data = training_dataset.map(formatting_prompts_func, batched=True) logger.info("Data Formatting Completed") INPUT_OUTPUT_DELIMITER = "<start_of_turn>model" collator = DataCollatorForCompletionOnlyLM( INPUT_OUTPUT_DELIMITER, tokenizer=tokenizer ) ################################################################################ # QLoRA parameters ################################################################################ # LoRA attention dimension lora_r = int(os.getenv("LORA_R", "8")) # Alpha parameter for LoRA scaling lora_alpha = int(os.getenv("LORA_ALPHA", "16")) # Dropout probability for LoRA layers lora_dropout = float(os.getenv("LORA_DROPOUT", "0.1")) ################################################################################ # TrainingArguments parameters ################################################################################ # Output directory where the model predictions and checkpoints will be stored # output_dir = "./results" # Number of training epochs num_train_epochs = int(os.getenv("EPOCHS", "1")) # Enable fp16/bf16 training (set bf16 to True with an A100) fp16 = False bf16 = False # Batch size per GPU for training per_device_train_batch_size = int(os.getenv("TRAIN_BATCH_SIZE", "1")) # Batch size per GPU for evaluation per_device_eval_batch_size = 1 # Number of update steps to accumulate the gradients for gradient_accumulation_steps = int(os.getenv("GRADIENT_ACCUMULATION_STEPS", "1")) # Enable gradient checkpointing gradient_checkpointing = True # Maximum gradient normal (gradient clipping) max_grad_norm = float(os.getenv("MAX_GRAD_NORM", "0.3")) # Initial learning rate (AdamW optimizer) learning_rate = float(os.getenv("LEARNING_RATE", "2e-4")) # Weight decay to apply to all layers except bias/LayerNorm weights weight_decay = float(os.getenv("WEIGHT_DECAY", "0.001")) # Optimizer to use optim = "paged_adamw_32bit" # Learning rate schedule lr_scheduler_type = "cosine" # Number of training steps (overrides num_train_epochs) max_steps = -1 # Ratio of steps for a linear warmup (from 0 to learning rate) warmup_ratio = float(os.getenv("WARMUP_RATIO", "0.03")) # Group sequences into batches with same length # Saves memory and speeds up training considerably group_by_length = True # Save strategy: steps, epoch, no save_strategy = os.getenv("CHECKPOINT_SAVE_STRATEGY", "steps") # Save total limit of checkpoints save_total_limit = int(os.getenv("CHECKPOINT_SAVE_TOTAL_LIMIT", "5")) # Save checkpoint every X updates steps save_steps = int(os.getenv("CHECKPOINT_SAVE_STEPS", "1000")) # Log every X updates steps logging_steps = 50 ################################################################################ # SFT parameters ################################################################################ # Maximum sequence length to use max_seq_length = int(os.getenv("MAX_SEQ_LENGTH", "512")) # Pack multiple short examples in the same input sequence to increase efficiency packing = False # Load base model logger.info("Loading base model started") model = AutoModelForCausalLM.from_pretrained( attn_implementation="eager", pretrained_model_name_or_path=model_name, torch_dtype=torch.bfloat16, ) model.config.use_cache = False model.config.pretraining_tp = 1 logger.info("Loading base model completed") logger.info("Configuring fine tuning started") # Load LoRA configuration peft_config = LoraConfig( lora_alpha=lora_alpha, lora_dropout=lora_dropout, r=lora_r, bias="none", task_type="CAUSAL_LM", target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], ) # Set training parameters training_arguments = SFTConfig( bf16=bf16, dataset_kwargs={ "add_special_tokens": False, # We template with special tokens "append_concat_token": False, # No need to add additional separator token }, dataset_text_field="prompts", disable_tqdm=True, fp16=fp16, gradient_accumulation_steps=gradient_accumulation_steps, gradient_checkpointing=gradient_checkpointing, gradient_checkpointing_kwargs={"use_reentrant": False}, group_by_length=group_by_length, log_on_each_node=False, logging_steps=logging_steps, learning_rate=learning_rate, lr_scheduler_type=lr_scheduler_type, max_grad_norm=max_grad_norm, max_seq_length=max_seq_length, max_steps=max_steps, num_train_epochs=num_train_epochs, optim=optim, output_dir=save_model_path, packing=packing, per_device_train_batch_size=per_device_train_batch_size, save_strategy=save_strategy, save_steps=save_steps, save_total_limit=save_total_limit, warmup_ratio=warmup_ratio, weight_decay=weight_decay, ) logger.info("Configuring fine tuning completed") logger.info("Creating trainer started") trainer = SFTTrainer( args=training_arguments, data_collator=collator, model=model, peft_config=peft_config, tokenizer=tokenizer, train_dataset=input_data, ) logger.info("Creating trainer completed") logger.info("Fine tuning started") # Check for existing checkpoints checkpoints_present = glob.glob(f"{save_model_path}/checkpoint-*") != [] trainer.train(resume_from_checkpoint=checkpoints_present) logger.info("Fine tuning completed") if "MLFLOW_ENABLE" in os.environ and os.environ["MLFLOW_ENABLE"] == "true": if accelerator.is_main_process: # register the model only at main process mv = mlflow.register_model( model_uri=f"gs://{training_data_bucket}/{save_model_path}", name=new_model, ) logger.info(f"Name: {mv.name}") logger.info(f"Version: {mv.version}") logger.info("Saving new model started") trainer.model.save_pretrained(new_model) logger.info("Saving new model completed") logger.info("Merging the new model with base model started") # Reload model in FP16 and merge it with LoRA weights base_model = AutoModelForCausalLM.from_pretrained( low_cpu_mem_usage=True, pretrained_model_name_or_path=model_name, return_dict=True, torch_dtype=torch.bfloat16, ) model = PeftModel.from_pretrained( model=base_model, model_id=new_model, ) model = model.merge_and_unload() logger.info("Merging the new model with base model completed") logger.info("Accelerate unwrap model started") unwrapped_model = accelerator.unwrap_model(model) logger.info("Accelerate unwrap model completed") logger.info("Save unwrapped model started") unwrapped_model.save_pretrained( is_main_process=accelerator.is_main_process, save_directory=save_model_path, save_function=accelerator.save, ) logger.info("Save unwrapped model completed") logger.info("Save new tokenizer started") if accelerator.is_main_process: tokenizer.save_pretrained(save_model_path) logger.info("Save new tokenizer completed") # Barrier to ensure all processes wait until 'unwrapped model save' is completed accelerator.wait_for_everyone() logger.info("Script completed")