# peft_fine_tune.py (304 lines of code) (raw):

"""LoRA fine-tuning of a causal LM on RAFT ``twitter_complaints`` with Accelerate.

The script trains for ``num_epochs``, times every forward + backward pass,
reports GPU/CPU memory consumption for the train and eval phases, evaluates
exact-match accuracy on the training split and finally saves the adapter.
"""

import gc
import os
import sys
import threading
import time

import numpy as np
import psutil
import torch
from accelerate import Accelerator
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    default_data_collator,
    get_linear_schedule_with_warmup,
    set_seed,
)


def levenshtein_distance(str1, str2):
    """Return the edit distance (insert / delete / substitute) between two strings.

    Uses a single rolling row of the DP table, so the cost is
    O(len(str1) * len(str2)) time and O(len(str2)) extra space.
    """
    if str1 == str2:
        return 0
    num_rows = len(str1) + 1
    num_cols = len(str2) + 1
    # Row 0 of the DP table: distance from the empty prefix of str1.
    dp_matrix = list(range(num_cols))
    for i in range(1, num_rows):
        # `prev` always holds the diagonal cell (row i-1, column j-1).
        prev = dp_matrix[0]
        dp_matrix[0] = i
        for j in range(1, num_cols):
            temp = dp_matrix[j]
            if str1[i - 1] == str2[j - 1]:
                dp_matrix[j] = prev
            else:
                # substitute (prev), delete (dp_matrix[j]), insert (dp_matrix[j - 1])
                dp_matrix[j] = min(prev, dp_matrix[j], dp_matrix[j - 1]) + 1
            prev = temp
    return dp_matrix[num_cols - 1]


def get_closest_label(eval_pred, classes):
    """Return the entry of ``classes`` with the smallest edit distance to ``eval_pred``.

    ``eval_pred`` is stripped of surrounding whitespace before comparison.
    Ties are resolved in favour of the label that appears first in ``classes``.

    Raises:
        ValueError: if ``classes`` is empty.
    """
    if not classes:
        raise ValueError("classes must contain at least one label")
    # Strip once instead of on every comparison.
    cleaned_pred = eval_pred.strip()
    return min(classes, key=lambda class_label: levenshtein_distance(cleaned_pred, class_label))


def b2mb(x):
    """Convert a byte count to whole mebibytes (truncating toward zero)."""
    return int(x / 2**20)


class TorchTracemalloc:
    """Context manager that records GPU and CPU memory usage of the enclosed code.

    After the ``with`` block exits the following attributes are available
    (all in MiB unless noted):

    * ``begin`` / ``end`` / ``peak``: raw CUDA byte counts,
    * ``used``: CUDA memory still allocated at exit relative to entry,
    * ``peaked``: CUDA peak relative to entry,
    * ``cpu_begin`` / ``cpu_end`` / ``cpu_peak``: raw RSS byte counts,
    * ``cpu_used`` / ``cpu_peaked``: RSS deltas relative to entry.
    """

    def __enter__(self):
        gc.collect()
        torch.cuda.empty_cache()
        # Reset the peak gauge so `max_memory_allocated()` only reflects this block.
        torch.cuda.reset_peak_memory_stats()
        self.begin = torch.cuda.memory_allocated()
        self.process = psutil.Process()

        self.cpu_begin = self.cpu_mem_used()
        # Initialise here (not in the thread) so `__exit__` can always read it,
        # even if the monitor thread has not been scheduled yet.
        self.cpu_peak = -1
        self.peak_monitoring = True
        self._peak_monitor_thread = threading.Thread(target=self.peak_monitor_func)
        self._peak_monitor_thread.daemon = True
        self._peak_monitor_thread.start()
        return self

    def cpu_mem_used(self):
        """Get resident set size memory for the current process."""
        return self.process.memory_info().rss

    def peak_monitor_func(self):
        """Poll RSS in a tight loop and keep the maximum in ``self.cpu_peak``."""
        while True:
            self.cpu_peak = max(self.cpu_mem_used(), self.cpu_peak)

            # Deliberately no sleep here: sleeping would miss short-lived peaks.

            if not self.peak_monitoring:
                break

    def __exit__(self, *exc):
        # Stop the monitor and wait for it so `cpu_peak` is final when read below.
        self.peak_monitoring = False
        self._peak_monitor_thread.join()

        gc.collect()
        torch.cuda.empty_cache()
        self.end = torch.cuda.memory_allocated()
        self.peak = torch.cuda.max_memory_allocated()
        self.used = b2mb(self.end - self.begin)
        self.peaked = b2mb(self.peak - self.begin)

        self.cpu_end = self.cpu_mem_used()
        self.cpu_used = b2mb(self.cpu_end - self.cpu_begin)
        self.cpu_peaked = b2mb(self.cpu_peak - self.cpu_begin)


def _print_memory_report(accelerator, tracemalloc, phase):
    """Print the GPU and CPU memory figures collected by ``tracemalloc`` for ``phase``."""
    accelerator.print(f"GPU Memory before entering the {phase} : {b2mb(tracemalloc.begin)}")
    accelerator.print(f"GPU Memory consumed at the end of the {phase} (end-begin): {tracemalloc.used}")
    accelerator.print(f"GPU Peak Memory consumed during the {phase} (max-begin): {tracemalloc.peaked}")
    accelerator.print(
        f"GPU Total Peak Memory consumed during the {phase} (max): {tracemalloc.peaked + b2mb(tracemalloc.begin)}"
    )

    accelerator.print(f"CPU Memory before entering the {phase} : {b2mb(tracemalloc.cpu_begin)}")
    accelerator.print(f"CPU Memory consumed at the end of the {phase} (end-begin): {tracemalloc.cpu_used}")
    accelerator.print(f"CPU Peak Memory consumed during the {phase} (max-begin): {tracemalloc.cpu_peaked}")
    accelerator.print(
        f"CPU Total Peak Memory consumed during the {phase} (max): "
        f"{tracemalloc.cpu_peaked + b2mb(tracemalloc.cpu_begin)}"
    )


def main():
    """Fine-tune the model with LoRA, benchmark it, evaluate it and save the adapter."""
    accelerator = Accelerator()
    model_name_or_path = "meta-llama/Meta-Llama-3-8B-Instruct"
    dataset_name = "twitter_complaints"
    peft_config = LoraConfig(task_type=TaskType.CAUSAL_LM, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1)
    text_column = "Tweet text"
    label_column = "text_label"
    lr = 3e-3
    num_epochs = 1
    batch_size = 2
    seed = 42
    max_length = 2048
    do_test = False
    set_seed(seed)

    dataset = load_dataset("ought/raft", dataset_name)
    classes = [k.replace("_", " ") for k in dataset["train"].features["Label"].names]
    dataset = dataset.map(
        lambda x: {"text_label": [classes[label] for label in x["Label"]]},
        batched=True,
        num_proc=1,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
    tokenizer.pad_token_id = tokenizer.eos_token_id

    def left_pad(token_ids, fill_value):
        """Left-pad ``token_ids`` with ``fill_value`` up to ``max_length`` and return a tensor.

        NOTE: sequences longer than ``max_length`` keep only their first
        ``max_length`` entries, i.e. the tail (where the label sits) is cut.
        """
        padded = [fill_value] * (max_length - len(token_ids)) + token_ids
        return torch.tensor(padded[:max_length])

    def preprocess_function(examples):
        """Build left-padded prompt+label training examples with prompt tokens masked out."""
        inputs = [f"{text_column} : {x} Label : " for x in examples[text_column]]
        targets = [str(x) for x in examples[label_column]]
        model_inputs = tokenizer(inputs)
        # Don't add the BOS token to the targets because they are concatenated with the inputs.
        labels = tokenizer(targets, add_special_tokens=False)
        for i in range(len(inputs)):
            prompt_ids = model_inputs["input_ids"][i]
            target_ids = labels["input_ids"][i] + [tokenizer.eos_token_id]
            full_ids = prompt_ids + target_ids
            model_inputs["input_ids"][i] = left_pad(full_ids, tokenizer.pad_token_id)
            model_inputs["attention_mask"][i] = left_pad([1] * len(full_ids), 0)
            # -100 on prompt and padding positions so only the label contributes to the loss.
            labels["input_ids"][i] = left_pad([-100] * len(prompt_ids) + target_ids, -100)
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    def test_preprocess_function(examples):
        """Build left-padded prompt-only examples for generation (no labels)."""
        inputs = [f"{text_column} : {x} Label : " for x in examples[text_column]]
        model_inputs = tokenizer(inputs)
        for i in range(len(inputs)):
            model_inputs["input_ids"][i] = left_pad(model_inputs["input_ids"][i], tokenizer.pad_token_id)
            model_inputs["attention_mask"][i] = left_pad(model_inputs["attention_mask"][i], 0)
        return model_inputs

    with accelerator.main_process_first():
        processed_datasets = dataset.map(
            preprocess_function,
            batched=True,
            num_proc=1,
            remove_columns=dataset["train"].column_names,
            load_from_cache_file=True,
            desc="Running tokenizer on dataset",
        )
    accelerator.wait_for_everyone()

    train_dataset = processed_datasets["train"]
    accelerator.print("train_dataset", train_dataset)

    with accelerator.main_process_first():
        processed_datasets = dataset.map(
            test_preprocess_function,
            batched=True,
            num_proc=1,
            remove_columns=dataset["train"].column_names,
            load_from_cache_file=False,
            desc="Running tokenizer on dataset",
        )
    # Evaluation re-uses the *training* split, tokenised without labels.
    eval_dataset = processed_datasets["train"]
    test_dataset = processed_datasets["test"]

    train_dataloader = DataLoader(
        train_dataset, shuffle=True, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True
    )
    eval_dataloader = DataLoader(
        eval_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True
    )
    test_dataloader = DataLoader(
        test_dataset, collate_fn=default_data_collator, batch_size=batch_size, pin_memory=True
    )

    accelerator.print(next(iter(train_dataloader)))

    # creating model
    model = AutoModelForCausalLM.from_pretrained(model_name_or_path)
    model = get_peft_model(model, peft_config)
    model.print_trainable_parameters()

    # optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # lr scheduler
    lr_scheduler = get_linear_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=(len(train_dataloader) * num_epochs),
    )

    model, train_dataloader, eval_dataloader, test_dataloader, optimizer, lr_scheduler = accelerator.prepare(
        model, train_dataloader, eval_dataloader, test_dataloader, optimizer, lr_scheduler
    )
    accelerator.print(model)

    is_ds_zero_3 = False
    if getattr(accelerator.state, "deepspeed_plugin", None):
        is_ds_zero_3 = accelerator.state.deepspeed_plugin.zero_stage == 3

    # `synced_gpus=True` makes generate() keep all ranks stepping together, which
    # relies on collective communication. Enable it for ZeRO-3 and for any
    # multi-process run; leave it off in a plain single-process run where no
    # process group is expected to exist.
    synced_gpus = is_ds_zero_3 or accelerator.num_processes > 1

    # Per-step forward + backward wall-clock times (first step of each epoch excluded as warm-up).
    latencies = []

    for epoch in range(num_epochs):
        with TorchTracemalloc() as tracemalloc:
            model.train()
            total_loss = 0
            for step, batch in enumerate(tqdm(train_dataloader)):
                accelerator.print(f"batch {batch['input_ids'].shape}")

                # Just a hack to dispatch on SDPA FA backend, by dropping the mask.
                batch["attention_mask"][:] = 1

                torch.cuda.synchronize()
                start = time.perf_counter()
                # NOTE(review): newer PyTorch releases offer `torch.nn.attention.sdpa_kernel`
                # as the replacement for this context manager - confirm before switching.
                with torch.backends.cuda.sdp_kernel(
                    enable_flash=True, enable_math=False, enable_mem_efficient=False
                ):
                    outputs = model(**batch)
                loss = outputs.loss
                total_loss += loss.detach().float()
                accelerator.backward(loss)
                torch.cuda.synchronize()
                end = time.perf_counter()

                accelerator.print(f"Forward + backward: {end - start:.4f} s")
                if step != 0:
                    latencies.append(end - start)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

        if latencies:
            # The label says "Median", so report the median (robust to slow outlier steps).
            accelerator.print(f"Median forward + backward latency: {np.median(latencies)} s")
        else:
            accelerator.print("Median forward + backward latency: n/a (fewer than two steps were run)")

        # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
        _print_memory_report(accelerator, tracemalloc, "train")

        train_epoch_loss = total_loss / len(train_dataloader)
        train_ppl = torch.exp(train_epoch_loss)
        accelerator.print(f"{epoch=}: {train_ppl=} {train_epoch_loss=}")

        model.eval()
        eval_preds = []
        with TorchTracemalloc() as tracemalloc:
            for batch in tqdm(eval_dataloader):
                batch = {k: v for k, v in batch.items() if k != "labels"}
                with torch.no_grad():
                    outputs = accelerator.unwrap_model(model).generate(
                        **batch, synced_gpus=synced_gpus, max_new_tokens=10
                    )
                outputs = accelerator.pad_across_processes(outputs, dim=1, pad_index=tokenizer.pad_token_id)
                preds = accelerator.gather_for_metrics(outputs)
                # Every prompt was left-padded to exactly `max_length`, so the
                # generated continuation starts at column `max_length`.
                preds = preds[:, max_length:].detach().cpu().numpy()
                eval_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True))

        # Printing the GPU memory usage details such as allocated memory, peak memory, and total memory usage
        _print_memory_report(accelerator, tracemalloc, "eval")

        # Materialise the reference column once instead of on every access.
        true_labels = dataset["train"][label_column]
        assert len(eval_preds) == len(true_labels), f"{len(eval_preds)} != {len(true_labels)}"
        correct = sum(pred.strip() == true.strip() for pred, true in zip(eval_preds, true_labels))
        total = len(true_labels)
        accuracy = correct / total * 100
        accelerator.print(f"{accuracy=}")
        accelerator.print(f"{eval_preds[:10]=}")
        accelerator.print(f"{dataset['train'][label_column][:10]=}")

    if do_test:
        model.eval()
        test_preds = []
        for batch in tqdm(test_dataloader):
            batch = {k: v for k, v in batch.items() if k != "labels"}
            with torch.no_grad():
                outputs = accelerator.unwrap_model(model).generate(
                    **batch, synced_gpus=synced_gpus, max_new_tokens=10
                )
            outputs = accelerator.pad_across_processes(outputs, dim=1, pad_index=tokenizer.pad_token_id)
            preds = accelerator.gather(outputs)
            preds = preds[:, max_length:].detach().cpu().numpy()
            test_preds.extend(tokenizer.batch_decode(preds, skip_special_tokens=True))

        # Snap every free-form generation to the nearest known class label.
        test_preds_cleaned = [get_closest_label(pred, classes) for pred in test_preds]

        test_df = dataset["test"].to_pandas()
        assert len(test_preds_cleaned) == len(test_df), f"{len(test_preds_cleaned)} != {len(test_df)}"
        test_df[label_column] = test_preds_cleaned
        test_df["text_labels_orig"] = test_preds
        accelerator.print(test_df[[text_column, label_column]].sample(20))

        pred_df = test_df[["ID", label_column]]
        pred_df.columns = ["ID", "Label"]

        os.makedirs(f"data/{dataset_name}", exist_ok=True)
        pred_df.to_csv(f"data/{dataset_name}/predictions.csv", index=False)

    accelerator.wait_for_everyone()
    # Option1: Pushing the model to Hugging Face Hub
    # model.push_to_hub(
    #     f"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}".replace("/", "_"),
    #     token = "hf_..."
    # )
    # token (`bool` or `str`, *optional*):
    #     `token` is to be used for HTTP Bearer authorization when accessing remote files. If `True`, will use the
    #     token generated when running `huggingface-cli login` (stored in `~/.huggingface`). Will default to `True`
    #     if `repo_url` is not specified.
    #     Or you can get your token from https://huggingface.co/settings/token
    # Option2: Saving the model locally
    peft_model_id = f"{dataset_name}_{model_name_or_path}_{peft_config.peft_type}_{peft_config.task_type}".replace(
        "/", "_"
    )
    # Unwrap first: the object returned by `accelerator.prepare` may be a
    # distributed wrapper rather than the PEFT model itself.
    accelerator.unwrap_model(model).save_pretrained(peft_model_id)
    accelerator.wait_for_everyone()


if __name__ == "__main__":
    main()