notebooks/ner_finetuning_distilbert_lora_v3.ipynb
{
"cells": [
{
"cell_type": "markdown",
"id": "22bbbb75-8c65-440e-a059-c63c2fa91996",
"metadata": {},
"source": [
"The purpose of this notebook is to fine-tune the \"distilbert/distilbert-base-uncased\" model for token classification (NER) using a LoRA adapter.\n",
"It treats city, state, and city-state locations as separate entity types."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e7436463-26a4-4ebb-abd1-771ee134220b",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"import evaluate\n",
"from datasets import Dataset, DatasetDict\n",
"from transformers import AutoTokenizer, AutoModelForTokenClassification, TrainingArguments, Trainer\n",
"from peft import get_peft_model, LoraConfig, TaskType"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6740344d-7d09-457a-bb68-a64f2b532103",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Prefer CUDA, then Apple MPS, otherwise fall back to CPU\n",
"device = 'cuda' if torch.cuda.is_available() else ('mps' if torch.backends.mps.is_available() else 'cpu')\n",
"print(device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "f6093b56-0180-4b67-8b04-b562979979ce",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# full_dataset = Dataset.from_parquet(\"data/combined_ner_examples.parquet\")\n",
"# full_dataset = Dataset.from_parquet(\"data/combined_ner_examples_v2.parquet\")\n",
"# full_dataset = Dataset.from_parquet(\"data/combined_ner_examples_v3.parquet\")\n",
"full_dataset = Dataset.from_parquet(\"data/synthetic_loc_dataset.parquet\")\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0f06a123-bdd7-4655-ae16-92bdc655924f",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"full_dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4e5b0e6e-e169-4e32-aeae-00c58949ff32",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"val_set_size = 5000\n",
"val_start = len(full_dataset) - val_set_size\n",
"val_start"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dd3d3524-3a60-4ce2-863b-08a6dadd2fcc",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Split the dataset into train and validation sets\n",
"train_dataset = full_dataset.select(range(val_start)) # All rows before val_start for training\n",
"val_dataset = full_dataset.select(range(val_start, len(full_dataset))) # Last val_set_size (5000) rows for validation\n",
"\n",
"# Combine them into a DatasetDict\n",
"dataset = DatasetDict({\n",
" 'train': train_dataset,\n",
" 'validation': val_dataset\n",
"})"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c4ffc3cd-886d-4f7f-a8c8-8f5413628646",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "340175ec-d224-4c3b-aaa5-16037c61fff0",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"dataset['train'].to_pandas()['tokens'].apply(len).hist(bins=20);"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f20d0f6-fa8e-47d7-a60b-2e673e25685a",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from transformers import AutoTokenizer\n",
"\n",
"# Load the tokenizer for distilbert-based NER\n",
"tokenizer = AutoTokenizer.from_pretrained(\"distilbert/distilbert-base-uncased\")\n",
"\n",
"# Function to tokenize the input and align labels with tokens\n",
"def tokenize_and_align_labels(example):\n",
" # Tokenize 'tokens' while keeping track of word boundaries\n",
" tokenized_inputs = tokenizer(\n",
" example['tokens'], \n",
" is_split_into_words=True, \n",
" truncation=True, \n",
" padding='max_length',\n",
" max_length=64,\n",
" )\n",
" \n",
" # Get the word_ids (mapping from tokens to original words)\n",
" word_ids = tokenized_inputs.word_ids()\n",
" aligned_labels = []\n",
"\n",
" previous_word_idx = None\n",
" for word_idx in word_ids:\n",
" if word_idx is None:\n",
" aligned_labels.append(-100) # Special tokens ([CLS], [SEP], etc.)\n",
" elif word_idx != previous_word_idx:\n",
" aligned_labels.append(example['ner_tags'][word_idx]) # Assign the label to the first token of each word\n",
" else:\n",
" aligned_labels.append(-100) # Subword tokens get label -100\n",
"\n",
" previous_word_idx = word_idx\n",
"\n",
" tokenized_inputs[\"labels\"] = aligned_labels\n",
" return tokenized_inputs\n",
"\n",
"# Apply the function to the dataset\n",
"tokenized_dataset = dataset.map(tokenize_and_align_labels)\n",
"\n"
]
},
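{
"cell_type": "markdown",
"id": "7c1a2b3d-4e5f-4a6b-8c9d-0e1f2a3b4c5d",
"metadata": {},
"source": [
"Quick sanity check (illustrative; assumes the cells above have run): decode one tokenized example and print each token next to its aligned label, so the `-100` masking of special and subword tokens is visible."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d2b3c4e-5f6a-4b7c-9d0e-1f2a3b4c5d6e",
"metadata": {},
"outputs": [],
"source": [
"# Illustrative check of the token/label alignment produced above\n",
"sample = tokenized_dataset['validation'][0]\n",
"sample_tokens = tokenizer.convert_ids_to_tokens(sample['input_ids'])\n",
"for tok, lab in zip(sample_tokens, sample['labels']):\n",
"    if tok == tokenizer.pad_token:\n",
"        break  # stop at the padding for readability\n",
"    print(f'{tok:15s} {lab}')"
]
},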
{
"cell_type": "code",
"execution_count": null,
"id": "106935d0-9093-4ff9-8e0f-afe4aa2d792c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"tokenized_dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "331d6dca-55d0-473e-94ca-90a152a9e9b7",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"tokenized_dataset['validation'][0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1221fd4e-3f16-4ed4-9d9f-c359604545ee",
"metadata": {},
"outputs": [],
"source": [
"def postprocess_predictions_and_labels(predictions, references):\n",
" true_predictions = []\n",
" true_labels = []\n",
" cmp_count = 0\n",
"\n",
" for prediction, reference in zip(predictions, references):\n",
" # Only keep labels that are not -100\n",
" true_labels_example = [label for label in reference if label != -100]\n",
" \n",
" # Align predictions: Remove predictions for which the corresponding reference label is -100\n",
" true_predictions_example = [pred for pred, ref in zip(prediction, reference) if ref != -100]\n",
"\n",
" # Ensure the length of predictions and labels matches\n",
" if len(true_predictions_example) == len(true_labels_example):\n",
" true_labels.append(true_labels_example)\n",
" true_predictions.append(true_predictions_example)\n",
" cmp_count += 1\n",
" else:\n",
" # Log or handle the error (example-level mismatch)\n",
" # print(f\"Skipping example due to mismatch: predictions ({len(true_predictions_example)}), labels ({len(true_labels_example)})\")\n",
" continue # Skip this example\n",
"\n",
" # Flatten the lists (convert from list of lists to a single list)\n",
" true_predictions = [pred for sublist in true_predictions for pred in sublist]\n",
" true_labels = [label for sublist in true_labels for label in sublist]\n",
" print(f\"cmp_count = {cmp_count} out of {len(predictions)}\")\n",
"\n",
" return true_predictions, true_labels\n"
]
},
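{
"cell_type": "markdown",
"id": "1a2b3c4d-5e6f-4a7b-8c9d-0e1f2a3b4c5e",
"metadata": {},
"source": [
"A tiny worked example of the post-processing above (hypothetical inputs, only to show the `-100` filtering and the flattening):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2b3c4d5e-6f7a-4b8c-9d0e-1f2a3b4c5d6f",
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical toy inputs, only to illustrate postprocess_predictions_and_labels\n",
"toy_predictions = [[0, 5, 6, 0], [0, 9, 0, 0]]\n",
"toy_references = [[-100, 5, 6, -100], [-100, 9, 0, -100]]\n",
"toy_preds, toy_labels = postprocess_predictions_and_labels(toy_predictions, toy_references)\n",
"print(toy_preds)   # [5, 6, 9, 0]\n",
"print(toy_labels)  # [5, 6, 9, 0]"
]
},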
{
"cell_type": "code",
"execution_count": null,
"id": "6f62dee2-077d-451f-af48-60eaf50e5edd",
"metadata": {},
"outputs": [],
"source": [
"def compute_metrics(p):\n",
"logits, labels = p\n",
"predictions = np.argmax(logits, axis=-1) # token-classification logits have shape (batch, seq_len, num_labels)\n",
" \n",
" # Post-process the predictions and labels to remove -100 values\n",
" true_predictions, true_labels = postprocess_predictions_and_labels(predictions, labels)\n",
"\n",
" # Combine metrics\n",
" accuracy_metric = evaluate.load(\"accuracy\")\n",
" precision_metric = evaluate.load(\"precision\")\n",
" recall_metric = evaluate.load(\"recall\")\n",
" f1_metric = evaluate.load(\"f1\")\n",
"\n",
" # Calculate metrics\n",
" accuracy = accuracy_metric.compute(predictions=true_predictions, references=true_labels)\n",
" precision = precision_metric.compute(predictions=true_predictions, references=true_labels, average=\"weighted\")\n",
" recall = recall_metric.compute(predictions=true_predictions, references=true_labels, average=\"weighted\")\n",
" f1 = f1_metric.compute(predictions=true_predictions, references=true_labels, average=\"weighted\")\n",
"\n",
" return {\n",
" \"accuracy\": accuracy[\"accuracy\"],\n",
" \"precision\": precision[\"precision\"],\n",
" \"recall\": recall[\"recall\"],\n",
" \"f1\": f1[\"f1\"]\n",
" }"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f7acd16-763a-4055-b492-3007b5057da1",
"metadata": {},
"outputs": [],
"source": [
"# Define the NER label mappings\n",
"# id2label = {\n",
"# 0: \"O\", # Outside any entity\n",
"# 1: \"B-PER\", # Beginning of a person entity\n",
"# 2: \"I-PER\", # Inside a person entity\n",
"# 3: \"B-ORG\", # Beginning of an organization entity\n",
"# 4: \"I-ORG\", # Inside an organization entity\n",
"# 5: \"B-LOC\", # Beginning of a location entity\n",
"# 6: \"I-LOC\", # Inside a location entity\n",
"# 7: \"B-MISC\", # Beginning of a miscellaneous entity\n",
"# 8: \"I-MISC\" # Inside a miscellaneous entity\n",
"# }\n",
"\n",
"id2label = {\n",
" 0: \"O\", # Outside any named entity\n",
" 1: \"B-PER\", # Beginning of a person entity\n",
" 2: \"I-PER\", # Inside a person entity\n",
" 3: \"B-ORG\", # Beginning of an organization entity\n",
" 4: \"I-ORG\", # Inside an organization entity\n",
" 5: \"B-CITY\", # Beginning of a city entity\n",
" 6: \"I-CITY\", # Inside a city entity\n",
" 7: \"B-STATE\", # Beginning of a state entity\n",
" 8: \"I-STATE\", # Inside a state entity\n",
" 9: \"B-CITYSTATE\", # Beginning of a city_state entity\n",
" 10: \"I-CITYSTATE\", # Inside a city_state entity\n",
"}\n",
"\n",
"label2id = {v: k for k, v in id2label.items()}\n",
"\n",
"# Load the pre-trained model\n",
"model = AutoModelForTokenClassification.from_pretrained(\"distilbert/distilbert-base-uncased\", \n",
" num_labels=11, \n",
" id2label=id2label, \n",
" label2id=label2id)\n",
"\n",
"# Define the LoRA configuration\n",
"lora_config = LoraConfig(\n",
" task_type=TaskType.TOKEN_CLS, # Task type is token classification (NER)\n",
" r=8, # Low-rank dimension (you can experiment with this)\n",
" lora_alpha=32, # Scaling factor for LoRA\n",
" lora_dropout=0.1, # Dropout rate for LoRA\n",
" target_modules=['q_lin'] # LoRA is applied to query layer\n",
")\n",
"\n",
"# Apply LoRA to the model\n",
"lora_model = get_peft_model(model, lora_config)\n"
]
},
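{
"cell_type": "markdown",
"id": "3c4d5e6f-7a8b-4c9d-8e1f-2a3b4c5d6e7f",
"metadata": {},
"source": [
"As a quick check of the adapter size, PEFT's `print_trainable_parameters` reports how many parameters LoRA actually trains versus the full model:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4d5e6f7a-8b9c-4d0e-a1f2-3b4c5d6e7f8a",
"metadata": {},
"outputs": [],
"source": [
"# Report trainable (LoRA) parameters vs. total parameters\n",
"lora_model.print_trainable_parameters()"
]
},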
{
"cell_type": "code",
"execution_count": null,
"id": "343c2425-f8a2-4d13-a1cc-22ed6a4717a5",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# lora_model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8c185799-1fd7-4319-a2a5-89ba1ff52df1",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"\n",
"\n",
"# Define the training arguments\n",
"training_args = TrainingArguments(\n",
" output_dir=\"./results\", # Output directory\n",
" evaluation_strategy=\"epoch\", # Evaluate at the end of every epoch\n",
" learning_rate=2e-5, # Learning rate\n",
" per_device_train_batch_size=16, # Batch size for training\n",
" per_device_eval_batch_size=16, # Batch size for evaluation\n",
" num_train_epochs=6, # Number of training epochs\n",
" weight_decay=0.01, # Weight decay\n",
" logging_dir='./logs', # Directory for logging\n",
")\n",
"\n",
"# Initialize the Trainer\n",
"trainer = Trainer(\n",
" model=lora_model, # LoRA-wrapped model\n",
" args=training_args, # Training arguments\n",
" train_dataset=tokenized_dataset['train'], # Training dataset\n",
" eval_dataset=tokenized_dataset[\"validation\"], # Validation dataset (if available)\n",
" tokenizer=tokenizer, # Tokenizer\n",
" compute_metrics=compute_metrics, # model perfomance evaluation metric\n",
")\n",
"\n",
"# Fine-tune the model\n",
"trainer.train()\n"
]
},
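{
"cell_type": "markdown",
"id": "5e6f7a8b-9c0d-4e1f-a2b3-4c5d6e7f8a9b",
"metadata": {},
"source": [
"Optionally, run one more evaluation pass on the validation split after training; this reuses the `compute_metrics` function defined above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6f7a8b9c-0d1e-4f2a-b3c4-5d6e7f8a9b0c",
"metadata": {},
"outputs": [],
"source": [
"# Final metrics on the validation split\n",
"eval_results = trainer.evaluate()\n",
"print(eval_results)"
]
},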
{
"cell_type": "code",
"execution_count": null,
"id": "82880e92-4d83-442a-9a1c-3a2386b1c942",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"import torch\n",
"from transformers import AutoTokenizer\n",
"\n",
"# Your text list\n",
"text_list = [\n",
" 'New York', 'Los Angeles', 'Chicago', 'Philadelphia', 'Dallas',\n",
" 'Fort Worth', 'Houston', 'Atlanta', 'Boston', 'Manchester',\n",
" 'Washington, D.C.', 'Hagerstown', 'San Francisco', 'Oakland',\n",
" 'San Jose', \n",
" # 'san jose',\n",
" 'weather in san jose',\n",
" 'weather in Boston',\n",
" 'Weather in Boston',\n",
" 'weather Boston',\n",
" 'Weather Boston',\n",
" 'weather',\n",
" 'Weather',\n",
" 'Boston weather',\n",
" 'Boston Weather',\n",
" # 'I love Pizzahut',\n",
" # 'I like Starbucks',\n",
" 'sushi restaurants in Sunnyvale, CA',\n",
" 'sushi restaurants in Sunnyvale, California',\n",
" 'ramen in sf',\n",
" 'sushi sf',\n",
" 'sushi sfo',\n",
" 'sushi sfo, CA',\n",
" 'ramen sfo',\n",
" 'sfo sushi'\n",
" 'phx ramen',\n",
"]\n",
"\n",
"model = trainer.model\n",
"\n",
"# Function to make predictions and group entities\n",
"def predict_ner(text_list):\n",
" model.eval() # Set the model to evaluation mode\n",
"\n",
" for text in text_list:\n",
" # Tokenize the input text\n",
" inputs = tokenizer(text, return_tensors=\"pt\", truncation=True, padding=True)\n",
" \n",
" # Move inputs to the same device as the model\n",
" inputs = {k: v.to(model.device) for k, v in inputs.items()}\n",
" \n",
" # Perform inference\n",
" with torch.no_grad():\n",
" outputs = model(**inputs)\n",
" \n",
" # Get predictions (logits -> predicted labels)\n",
" predictions = torch.argmax(outputs.logits, dim=-1).cpu().numpy()[0]\n",
" \n",
" # Map the predictions to labels and tokens\n",
" tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].cpu().numpy())\n",
" ner_labels = [model.config.id2label[pred] for pred in predictions]\n",
"\n",
" # Group tokens back into entities\n",
" current_entity = []\n",
" current_label = None\n",
" entities = []\n",
"\n",
" for token, label in zip(tokens, ner_labels):\n",
" print(token, label)\n",
" # Ignore special tokens like [CLS], [SEP]\n",
" if token in [\"[CLS]\", \"[SEP]\"]:\n",
" continue\n",
" # Handle subword tokens (tokens starting with ##)\n",
" if token.startswith(\"##\"):\n",
" if current_entity:\n",
" current_entity[-1] += token[2:] # Append the subword without \"##\"\n",
" elif label.startswith(\"B-\") or (label.startswith(\"I-\") and label != current_label):\n",
" # New entity starts, append the old one\n",
" if current_entity:\n",
" entities.append(\" \".join(current_entity))\n",
" current_entity = []\n",
" current_entity.append(token)\n",
" current_label = label\n",
" elif label.startswith(\"I-\") and label == current_label:\n",
" # Continue current entity\n",
" current_entity.append(token)\n",
" else:\n",
" # Non-entity token or 'O'\n",
" if current_entity:\n",
" entities.append(\" \".join(current_entity))\n",
" current_entity = []\n",
" current_label = None\n",
"\n",
" # Append any remaining entity\n",
" if current_entity:\n",
" entities.append(\" \".join(current_entity))\n",
"\n",
" # Clean up tokens (remove subword tokens and punctuation issues, etc.)\n",
" clean_entities = []\n",
" for entity in entities:\n",
" entity = entity.replace(\" ##\", \" \")\n",
" entity = entity.replace(\" .\", \".\") # Handle punctuation\n",
" entity = entity.replace(\" ,\", \",\")\n",
" clean_entities.append(entity)\n",
"\n",
" # Print the result for comparison\n",
" print(f\"Input: {text}\")\n",
" print(f\"Predicted entities: {' '.join(clean_entities)}\")\n",
" print()\n",
"\n",
"# Run predictions on the text list\n",
"predict_ner(text_list)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fdf1c664-80e1-47a1-a33f-98e2dd509623",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from transformers import AutoModelForTokenClassification, AutoTokenizer\n",
"from peft import PeftModel, PeftConfig\n",
"\n",
"# Load the base model (DistilBERT NER model)\n",
"base_model = AutoModelForTokenClassification.from_pretrained(\"distilbert/distilbert-base-uncased\",\n",
" num_labels=11,\n",
" id2label=id2label,\n",
" label2id=label2id)\n",
"\n",
"# Load the tokenizer\n",
"tokenizer = AutoTokenizer.from_pretrained(\"distilbert/distilbert-base-uncased\")\n",
"\n",
"# Load the LoRA-adapted model\n",
"peft_config = PeftConfig.from_pretrained(\"results/checkpoint-73128\")\n",
"lora_model = PeftModel.from_pretrained(base_model, \"results/checkpoint-73128\")\n",
"\n",
"# Merge the LoRA weights with the base model\n",
"merged_model = lora_model.merge_and_unload() # This merges LoRA into the base model\n"
]
},
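{
"cell_type": "markdown",
"id": "7a8b9c0d-1e2f-4a3b-c4d5-6e7f8a9b0c1d",
"metadata": {},
"source": [
"A quick smoke test of the merged model (illustrative; the query is taken from the list above): run a forward pass on one input and map the predicted ids back to labels."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8b9c0d1e-2f3a-4b4c-d5e6-7f8a9b0c1d2e",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Smoke test: run the merged model on a single query and print token-level labels\n",
"merged_model.eval()\n",
"sample_text = 'sushi restaurants in Sunnyvale, CA'\n",
"sample_inputs = tokenizer(sample_text, return_tensors='pt')\n",
"with torch.no_grad():\n",
"    sample_logits = merged_model(**sample_inputs).logits\n",
"sample_preds = sample_logits.argmax(dim=-1)[0].tolist()\n",
"sample_tokens = tokenizer.convert_ids_to_tokens(sample_inputs['input_ids'][0])\n",
"print(list(zip(sample_tokens, [id2label[p] for p in sample_preds])))"
]
},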
{
"cell_type": "code",
"execution_count": null,
"id": "d9d63ba4-2659-44b6-bf40-e33cc1516545",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Save the merged model and tokenizer\n",
"save_dir = \"tmp/merged_distilbert_uncased_ner\"\n",
"merged_model.save_pretrained(save_dir)\n",
"tokenizer.save_pretrained(save_dir)"
]
},
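{
"cell_type": "markdown",
"id": "9c0d1e2f-3a4b-4c5d-e6f7-8a9b0c1d2e3f",
"metadata": {},
"source": [
"To confirm the saved artifacts load cleanly, the directory can be reloaded through the `transformers` pipeline API (a sketch; `aggregation_strategy='simple'` groups B-/I- tokens into entity spans):"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0d1e2f3a-4b5c-4d6e-f7a8-9b0c1d2e3f4a",
"metadata": {},
"outputs": [],
"source": [
"from transformers import pipeline\n",
"\n",
"# Reload the merged model from disk and run it through the token-classification pipeline\n",
"ner_pipeline = pipeline('token-classification', model=save_dir, aggregation_strategy='simple')\n",
"print(ner_pipeline('weather in san jose'))"
]
},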
{
"cell_type": "code",
"execution_count": null,
"id": "efae28df-7596-4142-9aa8-9e8e5d291f9c",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# !huggingface-cli whoami"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b73fbf99-9165-4afe-a34a-a7a112427371",
"metadata": {},
"outputs": [],
"source": [
"# !huggingface-cli login"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0c889aa6-4b3d-479b-8dc7-23abf7a3164e",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"# Upload the merged model\n",
"merged_model_dir = \"tmp/merged_distilbert_uncased_ner\"\n",
"merged_repo_id = \"Mozilla/distilbert-uncased-NER-LoRA\" \n",
"\n",
"merged_model.push_to_hub(merged_repo_id)\n",
"tokenizer.push_to_hub(merged_repo_id)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "012f08c0-2488-4efb-93b2-2d89a8ff6cc4",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"environment": {
"kernel": "my_env",
"name": ".m124",
"type": "gcloud",
"uri": "us-docker.pkg.dev/deeplearning-platform-release/gcr.io/:m124"
},
"kernelspec": {
"display_name": "Python (my_env) (Local)",
"language": "python",
"name": "my_env"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.14"
}
},
"nbformat": 4,
"nbformat_minor": 5
}