notebooks/ner_finetuning_distilbert_lora_v3.ipynb (621 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "id": "22bbbb75-8c65-440e-a059-c63c2fa91996", "metadata": {}, "source": [ "purpose of this notebook is to finetune the \"distilbert/distilbert-base-uncased\" model\n", "Handles city, state and city-state separately" ] }, { "cell_type": "code", "execution_count": null, "id": "e7436463-26a4-4ebb-abd1-771ee134220b", "metadata": { "tags": [] }, "outputs": [], "source": [ "import pandas as pd\n", "from datasets import Dataset, DatasetDict\n", "import evaluate\n", "import numpy as np\n", "from transformers import AutoTokenizer\n", "from transformers import AutoModelForTokenClassification, TrainingArguments, Trainer\n", "from peft import get_peft_model, LoraConfig, TaskType\n", "from transformers import TrainingArguments, Trainer" ] }, { "cell_type": "code", "execution_count": null, "id": "6740344d-7d09-457a-bb68-a64f2b532103", "metadata": {}, "outputs": [], "source": [ "from torch import cuda\n", "device = 'cuda' if cuda.is_available() else 'mps'\n", "print(device)" ] }, { "cell_type": "code", "execution_count": null, "id": "f6093b56-0180-4b67-8b04-b562979979ce", "metadata": { "tags": [] }, "outputs": [], "source": [ "# full_dataset = Dataset.from_parquet(\"data/combined_ner_examples.parquet\")\n", "# full_dataset = Dataset.from_parquet(\"data/combined_ner_examples_v2.parquet\")\n", "# full_dataset = Dataset.from_parquet(\"data/combined_ner_examples_v3.parquet\")\n", "full_dataset = Dataset.from_parquet(\"data/synthetic_loc_dataset.parquet\")\n" ] }, { "cell_type": "code", "execution_count": null, "id": "0f06a123-bdd7-4655-ae16-92bdc655924f", "metadata": { "tags": [] }, "outputs": [], "source": [ "full_dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "4e5b0e6e-e169-4e32-aeae-00c58949ff32", "metadata": { "tags": [] }, "outputs": [], "source": [ "val_set_size = 5000\n", "val_start = len(full_dataset) - val_set_size\n", "val_start" ] }, { "cell_type": "code", "execution_count": null, "id": "dd3d3524-3a60-4ce2-863b-08a6dadd2fcc", "metadata": { "tags": [] }, "outputs": [], "source": [ "\n", "\n", "# Split the dataset into train and validation sets\n", "train_dataset = full_dataset.select(range(val_start)) # Select training rows\n", "val_dataset = full_dataset.select(range(val_start, len(full_dataset))) # Select last 1000 rows for validation\n", "\n", "# Combine them into a DatasetDict\n", "dataset = DatasetDict({\n", " 'train': train_dataset,\n", " 'validation': val_dataset\n", "})" ] }, { "cell_type": "code", "execution_count": null, "id": "c4ffc3cd-886d-4f7f-a8c8-8f5413628646", "metadata": { "tags": [] }, "outputs": [], "source": [ "dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "340175ec-d224-4c3b-aaa5-16037c61fff0", "metadata": { "tags": [] }, "outputs": [], "source": [ "dataset['train'].to_pandas()['tokens'].apply(len).hist(bins=20);" ] }, { "cell_type": "code", "execution_count": null, "id": "d2046667-d9ef-44ab-a444-dc95752478aa", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "8f20d0f6-fa8e-47d7-a60b-2e673e25685a", "metadata": { "tags": [] }, "outputs": [], "source": [ "from transformers import AutoTokenizer\n", "\n", "# Load the tokenizer for distilbert-based NER\n", "tokenizer = AutoTokenizer.from_pretrained(\"distilbert/distilbert-base-uncased\")\n", "\n", "# Function to tokenize the input and align labels with tokens\n", "def tokenize_and_align_labels(example):\n", " # Tokenize 'tokens' while keeping track of word boundaries\n", " tokenized_inputs = tokenizer(\n", " example['tokens'], \n", " is_split_into_words=True, \n", " truncation=True, \n", " padding='max_length',\n", " max_length=64,\n", " )\n", " \n", " # Get the word_ids (mapping from tokens to original words)\n", " word_ids = tokenized_inputs.word_ids()\n", " aligned_labels = []\n", "\n", " previous_word_idx = None\n", " for word_idx in word_ids:\n", " if word_idx is None:\n", " aligned_labels.append(-100) # Special tokens ([CLS], [SEP], etc.)\n", " elif word_idx != previous_word_idx:\n", " aligned_labels.append(example['ner_tags'][word_idx]) # Assign the label to the first token of each word\n", " else:\n", " aligned_labels.append(-100) # Subword tokens get label -100\n", "\n", " previous_word_idx = word_idx\n", "\n", " tokenized_inputs[\"labels\"] = aligned_labels\n", " return tokenized_inputs\n", "\n", "# Apply the function to the dataset\n", "tokenized_dataset = dataset.map(tokenize_and_align_labels)\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "106935d0-9093-4ff9-8e0f-afe4aa2d792c", "metadata": { "tags": [] }, "outputs": [], "source": [ "tokenized_dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "331d6dca-55d0-473e-94ca-90a152a9e9b7", "metadata": { "tags": [] }, "outputs": [], "source": [ "tokenized_dataset['validation'][0]" ] }, { "cell_type": "code", "execution_count": null, "id": "1221fd4e-3f16-4ed4-9d9f-c359604545ee", "metadata": {}, "outputs": [], "source": [ "def postprocess_predictions_and_labels(predictions, references):\n", " true_predictions = []\n", " true_labels = []\n", " cmp_count = 0\n", "\n", " for prediction, reference in zip(predictions, references):\n", " # Only keep labels that are not -100\n", " true_labels_example = [label for label in reference if label != -100]\n", " \n", " # Align predictions: Remove predictions for which the corresponding reference label is -100\n", " true_predictions_example = [pred for pred, ref in zip(prediction, reference) if ref != -100]\n", "\n", " # Ensure the length of predictions and labels matches\n", " if len(true_predictions_example) == len(true_labels_example):\n", " true_labels.append(true_labels_example)\n", " true_predictions.append(true_predictions_example)\n", " cmp_count += 1\n", " else:\n", " # Log or handle the error (example-level mismatch)\n", " # print(f\"Skipping example due to mismatch: predictions ({len(true_predictions_example)}), labels ({len(true_labels_example)})\")\n", " continue # Skip this example\n", "\n", " # Flatten the lists (convert from list of lists to a single list)\n", " true_predictions = [pred for sublist in true_predictions for pred in sublist]\n", " true_labels = [label for sublist in true_labels for label in sublist]\n", " print(f\"cmp_count = {cmp_count} out of {len(predictions)}\")\n", "\n", " return true_predictions, true_labels\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6f62dee2-077d-451f-af48-60eaf50e5edd", "metadata": {}, "outputs": [], "source": [ "def compute_metrics(p):\n", " logits, labels = p\n", " predictions = np.argmax(logits, axis=1)\n", " \n", " # Post-process the predictions and labels to remove -100 values\n", " true_predictions, true_labels = postprocess_predictions_and_labels(predictions, labels)\n", "\n", " # Combine metrics\n", " accuracy_metric = evaluate.load(\"accuracy\")\n", " precision_metric = evaluate.load(\"precision\")\n", " recall_metric = evaluate.load(\"recall\")\n", " f1_metric = evaluate.load(\"f1\")\n", "\n", " # Calculate metrics\n", " accuracy = accuracy_metric.compute(predictions=true_predictions, references=true_labels)\n", " precision = precision_metric.compute(predictions=true_predictions, references=true_labels, average=\"weighted\")\n", " recall = recall_metric.compute(predictions=true_predictions, references=true_labels, average=\"weighted\")\n", " f1 = f1_metric.compute(predictions=true_predictions, references=true_labels, average=\"weighted\")\n", "\n", " return {\n", " \"accuracy\": accuracy[\"accuracy\"],\n", " \"precision\": precision[\"precision\"],\n", " \"recall\": recall[\"recall\"],\n", " \"f1\": f1[\"f1\"]\n", " }" ] }, { "cell_type": "code", "execution_count": null, "id": "1f7acd16-763a-4055-b492-3007b5057da1", "metadata": {}, "outputs": [], "source": [ "# Define the NER label mappings\n", "# id2label = {\n", "# 0: \"O\", # Outside any entity\n", "# 1: \"B-PER\", # Beginning of a person entity\n", "# 2: \"I-PER\", # Inside a person entity\n", "# 3: \"B-ORG\", # Beginning of an organization entity\n", "# 4: \"I-ORG\", # Inside an organization entity\n", "# 5: \"B-LOC\", # Beginning of a location entity\n", "# 6: \"I-LOC\", # Inside a location entity\n", "# 7: \"B-MISC\", # Beginning of a miscellaneous entity\n", "# 8: \"I-MISC\" # Inside a miscellaneous entity\n", "# }\n", "\n", "id2label = {\n", " 0: \"O\", # Outside any named entity\n", " 1: \"B-PER\", # Beginning of a person entity\n", " 2: \"I-PER\", # Inside a person entity\n", " 3: \"B-ORG\", # Beginning of an organization entity\n", " 4: \"I-ORG\", # Inside an organization entity\n", " 5: \"B-CITY\", # Beginning of a city entity\n", " 6: \"I-CITY\", # Inside a city entity\n", " 7: \"B-STATE\", # Beginning of a state entity\n", " 8: \"I-STATE\", # Inside a state entity\n", " 9: \"B-CITYSTATE\", # Beginning of a city_state entity\n", " 10: \"I-CITYSTATE\", # Inside a city_state entity\n", "}\n", "\n", "label2id = {v: k for k, v in id2label.items()}\n", "\n", "# Load the pre-trained model\n", "model = AutoModelForTokenClassification.from_pretrained(\"distilbert/distilbert-base-uncased\", \n", " num_labels=11, \n", " id2label=id2label, \n", " label2id=label2id)\n", "\n", "# Define the LoRA configuration\n", "lora_config = LoraConfig(\n", " task_type=TaskType.TOKEN_CLS, # Task type is token classification (NER)\n", " r=8, # Low-rank dimension (you can experiment with this)\n", " lora_alpha=32, # Scaling factor for LoRA\n", " lora_dropout=0.1, # Dropout rate for LoRA\n", " target_modules=['q_lin'] # LoRA is applied to query layer\n", ")\n", "\n", "# Apply LoRA to the model\n", "lora_model = get_peft_model(model, lora_config)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "343c2425-f8a2-4d13-a1cc-22ed6a4717a5", "metadata": { "tags": [] }, "outputs": [], "source": [ "# lora_model" ] }, { "cell_type": "code", "execution_count": null, "id": "8c185799-1fd7-4319-a2a5-89ba1ff52df1", "metadata": { "tags": [] }, "outputs": [], "source": [ "\n", "\n", "# Define the training arguments\n", "training_args = TrainingArguments(\n", " output_dir=\"./results\", # Output directory\n", " evaluation_strategy=\"epoch\", # Evaluate at the end of every epoch\n", " learning_rate=2e-5, # Learning rate\n", " per_device_train_batch_size=16, # Batch size for training\n", " per_device_eval_batch_size=16, # Batch size for evaluation\n", " num_train_epochs=6, # Number of training epochs\n", " weight_decay=0.01, # Weight decay\n", " logging_dir='./logs', # Directory for logging\n", ")\n", "\n", "# Initialize the Trainer\n", "trainer = Trainer(\n", " model=lora_model, # LoRA-wrapped model\n", " args=training_args, # Training arguments\n", " train_dataset=tokenized_dataset['train'], # Training dataset\n", " eval_dataset=tokenized_dataset[\"validation\"], # Validation dataset (if available)\n", " tokenizer=tokenizer, # Tokenizer\n", " compute_metrics=compute_metrics, # model perfomance evaluation metric\n", ")\n", "\n", "# Fine-tune the model\n", "trainer.train()\n" ] }, { "cell_type": "code", "execution_count": null, "id": "82880e92-4d83-442a-9a1c-3a2386b1c942", "metadata": { "tags": [] }, "outputs": [], "source": [ "import torch\n", "from transformers import AutoTokenizer\n", "\n", "# Your text list\n", "text_list = [\n", " 'New York', 'Los Angeles', 'Chicago', 'Philadelphia', 'Dallas',\n", " 'Fort Worth', 'Houston', 'Atlanta', 'Boston', 'Manchester',\n", " 'Washington, D.C.', 'Hagerstown', 'San Francisco', 'Oakland',\n", " 'San Jose', \n", " # 'san jose',\n", " 'weather in san jose',\n", " 'weather in Boston',\n", " 'Weather in Boston',\n", " 'weather Boston',\n", " 'Weather Boston',\n", " 'weather',\n", " 'Weather',\n", " 'Boston weather',\n", " 'Boston Weather',\n", " # 'I love Pizzahut',\n", " # 'I like Starbucks',\n", " 'sushi restaurants in Sunnyvale, CA',\n", " 'sushi restaurants in Sunnyvale, California',\n", " 'ramen in sf',\n", " 'sushi sf',\n", " 'sushi sfo',\n", " 'sushi sfo, CA',\n", " 'ramen sfo',\n", " 'sfo sushi'\n", " 'phx ramen',\n", "]\n", "\n", "model = trainer.model\n", "\n", "# Function to make predictions and group entities\n", "def predict_ner(text_list):\n", " model.eval() # Set the model to evaluation mode\n", "\n", " for text in text_list:\n", " # Tokenize the input text\n", " inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=True)\n", " \n", " # Move inputs to the same device as the model\n", " inputs = {k: v.to(model.device) for k, v in inputs.items()}\n", " \n", " # Perform inference\n", " with torch.no_grad():\n", " outputs = model(**inputs)\n", " \n", " # Get predictions (logits -> predicted labels)\n", " predictions = torch.argmax(outputs.logits, dim=-1).cpu().numpy()[0]\n", " \n", " # Map the predictions to labels and tokens\n", " tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].cpu().numpy())\n", " ner_labels = [model.config.id2label[pred] for pred in predictions]\n", "\n", " # Group tokens back into entities\n", " current_entity = []\n", " current_label = None\n", " entities = []\n", "\n", " for token, label in zip(tokens, ner_labels):\n", " print(token, label)\n", " # Ignore special tokens like [CLS], [SEP]\n", " if token in [\"[CLS]\", \"[SEP]\"]:\n", " continue\n", " # Handle subword tokens (tokens starting with ##)\n", " if token.startswith(\"##\"):\n", " if current_entity:\n", " current_entity[-1] += token[2:] # Append the subword without \"##\"\n", " elif label.startswith(\"B-\") or (label.startswith(\"I-\") and label != current_label):\n", " # New entity starts, append the old one\n", " if current_entity:\n", " entities.append(\" \".join(current_entity))\n", " current_entity = []\n", " current_entity.append(token)\n", " current_label = label\n", " elif label.startswith(\"I-\") and label == current_label:\n", " # Continue current entity\n", " current_entity.append(token)\n", " else:\n", " # Non-entity token or 'O'\n", " if current_entity:\n", " entities.append(\" \".join(current_entity))\n", " current_entity = []\n", " current_label = None\n", "\n", " # Append any remaining entity\n", " if current_entity:\n", " entities.append(\" \".join(current_entity))\n", "\n", " # Clean up tokens (remove subword tokens and punctuation issues, etc.)\n", " clean_entities = []\n", " for entity in entities:\n", " entity = entity.replace(\" ##\", \" \")\n", " entity = entity.replace(\" .\", \".\") # Handle punctuation\n", " entity = entity.replace(\" ,\", \",\")\n", " clean_entities.append(entity)\n", "\n", " # Print the result for comparison\n", " print(f\"Input: {text}\")\n", " print(f\"Predicted entities: {' '.join(clean_entities)}\")\n", " print()\n", "\n", "# Run predictions on the text list\n", "predict_ner(text_list)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "fdf1c664-80e1-47a1-a33f-98e2dd509623", "metadata": { "tags": [] }, "outputs": [], "source": [ "from transformers import AutoModelForTokenClassification, AutoTokenizer\n", "from peft import PeftModel, PeftConfig\n", "\n", "# Load the base model (DistilBERT NER model)\n", "base_model = AutoModelForTokenClassification.from_pretrained(\"distilbert/distilbert-base-uncased\",\n", " num_labels=11,\n", " id2label=id2label,\n", " label2id=label2id)\n", "\n", "# Load the tokenizer\n", "tokenizer = AutoTokenizer.from_pretrained(\"distilbert/distilbert-base-uncased\")\n", "\n", "# Load the LoRA-adapted model\n", "peft_config = PeftConfig.from_pretrained(\"results/checkpoint-73128\")\n", "lora_model = PeftModel.from_pretrained(base_model, \"results/checkpoint-73128\")\n", "\n", "# Merge the LoRA weights with the base model\n", "merged_model = lora_model.merge_and_unload() # This merges LoRA into the base model\n" ] }, { "cell_type": "code", "execution_count": null, "id": "d9d63ba4-2659-44b6-bf40-e33cc1516545", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Save the merged model and tokenizer\n", "save_dir = \"tmp/merged_distilbert_uncased_ner\"\n", "merged_model.save_pretrained(save_dir)\n", "tokenizer.save_pretrained(save_dir)" ] }, { "cell_type": "code", "execution_count": null, "id": "efae28df-7596-4142-9aa8-9e8e5d291f9c", "metadata": { "tags": [] }, "outputs": [], "source": [ "# !huggingface-cli whoami" ] }, { "cell_type": "code", "execution_count": null, "id": "b73fbf99-9165-4afe-a34a-a7a112427371", "metadata": {}, "outputs": [], "source": [ "# !huggingface-cli login" ] }, { "cell_type": "code", "execution_count": null, "id": "0c889aa6-4b3d-479b-8dc7-23abf7a3164e", "metadata": { "tags": [] }, "outputs": [], "source": [ "# Upload the merged model\n", "merged_model_dir = \"tmp/merged_distilbert_uncased_ner\"\n", "merged_repo_id = \"Mozilla/distilbert-uncased-NER-LoRA\" \n", "\n", "merged_model.push_to_hub(merged_repo_id)\n", "tokenizer.push_to_hub(merged_repo_id)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "012f08c0-2488-4efb-93b2-2d89a8ff6cc4", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "environment": { "kernel": "my_env", "name": ".m124", "type": "gcloud", "uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/:m124" }, "kernelspec": { "display_name": "Python (my_env) (Local)", "language": "python", "name": "my_env" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.14" } }, "nbformat": 4, "nbformat_minor": 5 }