notebooks/city_weather_exploration_and_dataprep.ipynb

{ "cells": [ { "cell_type": "markdown", "id": "73d1863d-1d54-4cdd-843c-c033b28f15f6", "metadata": {}, "source": [ "Explore whether the weather keywords and locations are captured correctly" ] }, { "cell_type": "code", "execution_count": null, "id": "bd4805cc-8d46-40fa-8d39-35158d9212d4", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import requests\n", "from bs4 import BeautifulSoup\n", "import re" ] }, { "cell_type": "code", "execution_count": null, "id": "b64db933-17ab-47cc-b0ba-ae37e89e450a", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import random" ] }, { "cell_type": "code", "execution_count": null, "id": "d70bc639-b4de-4544-bd0f-f18a0b263a66", "metadata": {}, "outputs": [], "source": [ "url = \"https://en.m.wikipedia.org/wiki/List_of_television_stations_in_North_America_by_media_market\"\n", "response = requests.get(url)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a1c49059-f982-46a5-a871-aeb2ec2a6688", "metadata": { "scrolled": true }, "outputs": [], "source": [ "if response.status_code == 200:\n", " soup = BeautifulSoup(response.content, 'html.parser')\n", " dma_heading = soup.find('h4', string='DMAs')\n", " dma_list = dma_heading.find_next('ul')\n", " \n", " dma_data = []\n", " if dma_list:\n", " for li in dma_list.find_all('li'):\n", " market_name = li.get_text(strip=True)\n", "\n", " # Split by dash (-) or en-dash (–) to handle cases like \"Dallas-Fort Worth\"\n", " split_names = re.split(r'–|-', market_name)\n", "\n", " # Process each split name\n", " for name in split_names:\n", " # Remove the (#NUM) part using regex\n", " name = re.sub(r'\\s*\\(#\\d+\\)', '', name).strip()\n", "\n", " # Check if there's a city in parentheses and split them\n", " match = re.match(r'(.+?)\\s*\\((.+?)\\)', name)\n", " if match:\n", " main_city = match.group(1).strip()\n", " parenthetical_city = match.group(2).strip()\n", " dma_data.append(main_city) # Add the main city\n", " dma_data.append(parenthetical_city) # Add the city in parentheses\n", " else:\n", " dma_data.append(name) \n", "\n", " for index, dma in enumerate(dma_data, start=1):\n", " print(f\"{index}. 
{dma}\")\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "52f9543e-4b78-46f1-828f-8f49340a4be0", "metadata": {}, "outputs": [], "source": [ "dma_data[:5]" ] }, { "cell_type": "markdown", "id": "8bcf91d7-8344-4b5e-9641-461b2630cb0f", "metadata": {}, "source": [ "#### Read the data/geonames-cities-states.json" ] }, { "cell_type": "code", "execution_count": null, "id": "738661a5-668f-4b2c-8823-dc3c0c92be94", "metadata": {}, "outputs": [], "source": [ "import json \n", "\n", "def get_geonames_city_state_data():\n", " geonames_file = \"../data/geonames-cities-states.json\"\n", " with open(geonames_file, 'r') as f:\n", " geonames_dict = json.load(f)\n", " \n", " \n", " cities_data = pd.DataFrame(geonames_dict['cities'])\\\n", " .rename(columns={'admin1_code': 'state_code', 'name': 'city_name', 'population': 'city_popln'})\n", " cities_data = cities_data[['id', 'state_code', 'city_name', 'city_popln', 'alternate_names']]\n", " states_data = pd.DataFrame(geonames_dict['states_by_abbr'].values())\\\n", " .rename(columns={'admin1_code': 'state_code', 'name': 'state_name'})\n", " states_data = states_data[['state_code', 'state_name']]\n", " city_states_data = cities_data.merge(states_data, how='left', on='state_code')\n", " city_states_data['city_weight'] = city_states_data['city_popln'] / city_states_data['city_popln'].sum()\n", " return city_states_data\n", "\n" ] }, { "cell_type": "code", "execution_count": null, "id": "a3aeb4bd-2e84-4121-84b7-8ffb1118ca37", "metadata": {}, "outputs": [], "source": [ "city_states_data = get_geonames_city_state_data()\n", "print(len(city_states_data))\n", "city_states_data" ] }, { "cell_type": "code", "execution_count": null, "id": "d35076ae-1d45-4699-8257-e98612500e43", "metadata": {}, "outputs": [], "source": [ "city_states_data.sort_values('city_weight', ascending=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "df043822-779c-4f9c-89eb-b331e2b0de19", "metadata": {}, "outputs": [], "source": [ "# useful for post processing to standardize the city names\n", "def build_lookup(dataframe):\n", " # Initialize an empty dictionary for the lookup\n", " lookup = {}\n", " \n", " # Iterate over each row in the DataFrame\n", " for index, row in dataframe.iterrows():\n", " city_name = row['city_name']\n", " alternate_names = row['alternate_names']\n", " \n", " # Iterate over the list of alternate names and map them to the city_name\n", " for alt_name in alternate_names:\n", " lookup[alt_name.lower()] = city_name # Convert alternate names to lowercase for consistency\n", " \n", " return lookup\n", "\n", "city_alternate_to_city_lkp = build_lookup(city_states_data)" ] }, { "cell_type": "code", "execution_count": null, "id": "62a392e3-e18e-470f-9f95-ad35ebaebca8", "metadata": {}, "outputs": [], "source": [ "len(city_alternate_to_city_lkp)" ] }, { "cell_type": "code", "execution_count": null, "id": "1f9d1453-8955-4dec-b0e2-4e8d29a82046", "metadata": {}, "outputs": [], "source": [ "city_states_data['alternate_names'].apply(len).value_counts()" ] }, { "cell_type": "code", "execution_count": null, "id": "dc9cc8fe-95ad-45e1-8bd8-134faf7aa37d", "metadata": {}, "outputs": [], "source": [ "np.random.seed(42)\n", "\n", "def get_alternate_or_actual_name(row):\n", " if row['alternate_names'] and isinstance(row['alternate_names'], list):\n", " return random.choice(row['alternate_names'])\n", " return row['city_name']\n", "\n", "def combine_city_with_states(row):\n", " if row['state_code'] is not None:\n", " # return row['city'] + \", \" + 
row['state_code']\n", " return row['city'] + \", \" + random.choice([row['state_code'], row['state_name']])\n", " return row['city']\n", " \n", "def sample_location(df, n_examples=10000, state_ratio=0.5):\n", " weights = df['city_weight']\n", " samples = df[['id', 'city_name', 'alternate_names', 'state_code', 'state_name', 'city_popln']].sample(n=n_examples, weights=weights, replace=True)\n", " states_idx = np.random.random(n_examples) <= state_ratio\n", " samples.loc[states_idx, 'state_code'] = None\n", " random_alternate_name = samples.apply(get_alternate_or_actual_name, axis=1)\n", " samples['city'] = random_alternate_name\n", " samples['location'] = samples.apply(combine_city_with_states, axis=1)\n", " return samples" ] }, { "cell_type": "code", "execution_count": null, "id": "04052587-fcb9-41bc-8533-7d08b9f689e4", "metadata": {}, "outputs": [], "source": [ "sample_df = sample_location(city_states_data, n_examples=100000, state_ratio=0.5)" ] }, { "cell_type": "code", "execution_count": null, "id": "57464006-4fae-44c6-907d-48b1b03fdb80", "metadata": {}, "outputs": [], "source": [ "sample_df" ] }, { "cell_type": "code", "execution_count": null, "id": "53af30a6-c612-4567-9377-c6fae129dfe6", "metadata": {}, "outputs": [], "source": [ "sample_df.loc[sample_df['location'] == 'san']" ] }, { "cell_type": "code", "execution_count": null, "id": "a4ea6b52-8bfa-4f07-84e3-072a73988f5a", "metadata": {}, "outputs": [], "source": [ "sample_df['location'].value_counts()[:60]" ] }, { "cell_type": "code", "execution_count": null, "id": "eea0b802-764e-4b94-a817-d14be1f5c661", "metadata": {}, "outputs": [], "source": [ "geo_city_state_data = sample_df['location'].values.tolist()\n", "print(len(geo_city_state_data))\n", "geo_city_state_data[:10]" ] }, { "cell_type": "code", "execution_count": null, "id": "c80504f1-4a31-4cd2-85c3-96da175074ba", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "85bdeff1-a3f2-443e-a31b-d80e836c6ebe", "metadata": {}, "outputs": [], "source": [ "# !python -m pip install onnxruntime" ] }, { "cell_type": "code", "execution_count": null, "id": "689e6844-2a90-4b7a-a9a5-bb298dce2b70", "metadata": {}, "outputs": [], "source": [ "# !python -m pip freeze| grep onnxruntime" ] }, { "cell_type": "code", "execution_count": null, "id": "fc61067c-6e8a-499a-9d08-07fb4fb0eb2f", "metadata": {}, "outputs": [], "source": [ "# !mkdir ../models" ] }, { "cell_type": "code", "execution_count": null, "id": "74bca5a8-0bb0-46c1-8429-598e172f34af", "metadata": {}, "outputs": [], "source": [ "import onnxruntime as ort\n", "import numpy as np\n", "from transformers import AutoTokenizer, BertTokenizer\n", "\n", "# Download the ONNX model\n", "# model_url = \"https://huggingface.co/Xenova/bert-base-NER/resolve/main/onnx/model_quantized.onnx\"\n", "# model_url = \"https://huggingface.co/Mozilla/distilbert-NER-LoRA/resolve/main/onnx/model_quantized.onnx\"\n", "# model_url = \"https://huggingface.co/Mozilla/distilbert-uncased-NER-LoRA/resolve/main/onnx/model_quantized.onnx\"\n", "model_url = \"https://huggingface.co/chidamnat2002/distilbert-uncased-NER-LoRA/resolve/main/onnx/model_quantized.onnx\"\n", "# model_path = \"../models/distilbert-NER-LoRA.onnx\"\n", "model_path = \"../models/distilbert-uncased-NER-LoRA.onnx\"\n", "\n", "# Download the ONNX model if not already present\n", "response = requests.get(model_url)\n", "with open(model_path, 'wb') as f:\n", " f.write(response.content)\n", "\n", "# Load the ONNX model using ONNX Runtime\n", "session = 
{ "cell_type": "code", "execution_count": null, "id": "57464006-4fae-44c6-907d-48b1b03fdb80", "metadata": {}, "outputs": [], "source": [ "sample_df" ] },
{ "cell_type": "code", "execution_count": null, "id": "53af30a6-c612-4567-9377-c6fae129dfe6", "metadata": {}, "outputs": [], "source": [ "sample_df.loc[sample_df['location'] == 'san']" ] },
{ "cell_type": "code", "execution_count": null, "id": "a4ea6b52-8bfa-4f07-84e3-072a73988f5a", "metadata": {}, "outputs": [], "source": [ "sample_df['location'].value_counts()[:60]" ] },
{ "cell_type": "code", "execution_count": null, "id": "eea0b802-764e-4b94-a817-d14be1f5c661", "metadata": {}, "outputs": [], "source": [ "geo_city_state_data = sample_df['location'].values.tolist()\n", "print(len(geo_city_state_data))\n", "geo_city_state_data[:10]" ] },
{ "cell_type": "code", "execution_count": null, "id": "c80504f1-4a31-4cd2-85c3-96da175074ba", "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "code", "execution_count": null, "id": "85bdeff1-a3f2-443e-a31b-d80e836c6ebe", "metadata": {}, "outputs": [], "source": [ "# !python -m pip install onnxruntime" ] },
{ "cell_type": "code", "execution_count": null, "id": "689e6844-2a90-4b7a-a9a5-bb298dce2b70", "metadata": {}, "outputs": [], "source": [ "# !python -m pip freeze | grep onnxruntime" ] },
{ "cell_type": "code", "execution_count": null, "id": "fc61067c-6e8a-499a-9d08-07fb4fb0eb2f", "metadata": {}, "outputs": [], "source": [ "# !mkdir ../models" ] },
{ "cell_type": "code", "execution_count": null, "id": "74bca5a8-0bb0-46c1-8429-598e172f34af", "metadata": {}, "outputs": [], "source": [ "import os\n", "import onnxruntime as ort\n", "import numpy as np\n", "from transformers import AutoTokenizer\n", "\n", "# Candidate ONNX checkpoints tried during this exploration:\n", "# model_url = \"https://huggingface.co/Xenova/bert-base-NER/resolve/main/onnx/model_quantized.onnx\"\n", "# model_url = \"https://huggingface.co/Mozilla/distilbert-NER-LoRA/resolve/main/onnx/model_quantized.onnx\"\n", "# model_url = \"https://huggingface.co/Mozilla/distilbert-uncased-NER-LoRA/resolve/main/onnx/model_quantized.onnx\"\n", "model_url = \"https://huggingface.co/chidamnat2002/distilbert-uncased-NER-LoRA/resolve/main/onnx/model_quantized.onnx\"\n", "# model_path = \"../models/distilbert-NER-LoRA.onnx\"\n", "model_path = \"../models/distilbert-uncased-NER-LoRA.onnx\"\n", "\n", "# Download the ONNX model only if it is not already present\n", "if not os.path.exists(model_path):\n", "    response = requests.get(model_url)\n", "    response.raise_for_status()\n", "    with open(model_path, 'wb') as f:\n", "        f.write(response.content)\n", "\n", "# Load the ONNX model using ONNX Runtime\n", "session = ort.InferenceSession(model_path)\n", "\n", "# Load the matching tokenizer\n", "# tokenizer = AutoTokenizer.from_pretrained(\"Mozilla/distilbert-NER-LoRA\")\n", "# tokenizer = AutoTokenizer.from_pretrained(\"Mozilla/distilbert-uncased-NER-LoRA\")\n", "tokenizer = AutoTokenizer.from_pretrained(\"chidamnat2002/distilbert-uncased-NER-LoRA\")" ] },
{ "cell_type": "code", "execution_count": null, "id": "838001d1-a252-4a4f-bfab-8c7698b7c79b", "metadata": {}, "outputs": [], "source": [ "def compute_model_inputs_and_outputs(session, tokenizer, query):\n", "    # Tokenize to fixed-length numpy arrays\n", "    inputs = tokenizer(query, return_tensors=\"np\", truncation=True, padding='max_length', max_length=64)\n", "\n", "    # The ONNX model expects 'input_ids' and 'attention_mask' as int64;\n", "    # DistilBERT has no 'token_type_ids' input\n", "    input_feed = {\n", "        'input_ids': inputs['input_ids'].astype(np.int64),\n", "        'attention_mask': inputs['attention_mask'].astype(np.int64),\n", "    }\n", "\n", "    # Run inference with the ONNX model\n", "    outputs = session.run(None, input_feed)\n", "    return inputs, outputs\n" ] },
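{ "cell_type": "markdown", "id": "a6b7c8d9-0e1f-4a2b-9c4d-5e6f7a8b9c0d", "metadata": {}, "source": [ "It can help to confirm what the exported graph actually declares; the input names and shapes below depend on how the checkpoint was converted to ONNX, so this is just a sanity check:" ] },
{ "cell_type": "code", "execution_count": null, "id": "b7c8d9e0-1f2a-4b3c-8d5e-6f7a8b9c0d1e", "metadata": {}, "outputs": [], "source": [ "# Names/shapes/types declared by the ONNX graph\n", "for node in session.get_inputs():\n", "    print('input: ', node.name, node.shape, node.type)\n", "for node in session.get_outputs():\n", "    print('output:', node.name, node.shape, node.type)" ] },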
label in [\"B-LOC\", \"I-LOC\"]:\n", " if label == \"B-LOC\":\n", " # If we encounter a B-LOC, we may need to store the previous location\n", " if current_location:\n", " location_entities.append(\" \".join(current_location).replace(\"##\", \"\"))\n", " # Start a new location entity\n", " current_location = [token]\n", " elif label == \"I-LOC\" and current_location:\n", " # Continue appending to the current location entity\n", " current_location.append(token)\n", " else:\n", " # If we encounter a non-location entity, store the current location and reset\n", " if current_location:\n", " location_entities.append(\" \".join(current_location).replace(\"##\", \"\"))\n", " current_location = []\n", " \n", " # Append the last location entity if it exists\n", " if current_location:\n", " location_entities.append(\" \".join(current_location).replace(\"##\", \"\"))\n", "\n", " # Return the detected location terms\n", " return location_entities[0] if location_entities != [] else None\n" ] }, { "cell_type": "code", "execution_count": null, "id": "827c23a5-367d-4f9b-80f5-471cd3d4f40b", "metadata": {}, "outputs": [], "source": [ "# query = \"restaurants in Philadelphia\"\n", "query = \"weather Boston\"\n", "# query = \"Boston weather\"\n", "inputs, outputs = compute_model_inputs_and_outputs(session, tokenizer, query)\n", "detect_location(inputs, outputs, tokenizer)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "7d537628-8255-4a5b-901b-8338060d8c25", "metadata": {}, "outputs": [], "source": [ "# inputs\n", "outputs[0].shape" ] }, { "cell_type": "code", "execution_count": null, "id": "a78ffff9-6bd0-4032-9679-4c20d902a56d", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "cea8d2c3-aaab-48fb-86f1-8445e667af6d", "metadata": { "scrolled": true }, "outputs": [], "source": [ "num_examples = len(dma_data)\n", "hit = 0\n", "match = 0\n", "missing_locations = set()\n", "for index, dma in enumerate(dma_data, start=1):\n", " # location = detect_location(session, tokenizer, dma)\n", " inputs, outputs = compute_model_inputs_and_outputs(session, tokenizer, dma)\n", " location = detect_location(inputs, outputs, tokenizer)\n", " print(f\"{index}. {dma} -> {location}, : {dma.lower() == location}\")\n", " if location:\n", " hit += 1\n", " if dma.lower() == location:\n", " match += 1\n", " else:\n", " missing_locations.add(dma)\n", "\n", "print()\n", "print(f\"Number of examples = {num_examples}\")\n", "print(f\"#hits = {hit}; #hit rate = {hit/num_examples}\")\n", "print(f\"#matches = {match}; #match rate = {match/num_examples}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "bfa7d5a6-c13a-48e1-9cb2-97ad020f27f5", "metadata": {}, "outputs": [], "source": [ "# num_examples = len(geo_city_state_data)\n", "# hit = 0\n", "# match = 0\n", "# missing_locations = set()\n", "# for index, city_data in enumerate(geo_city_state_data, start=1):\n", "# # location = detect_location(session, tokenizer, city_data)\n", "# inputs, outputs = compute_model_inputs_and_outputs(session, tokenizer, city_data)\n", "# location = detect_location(inputs, outputs, tokenizer)\n", "# print(f\"{index}. 
{ "cell_type": "code", "execution_count": null, "id": "bfa7d5a6-c13a-48e1-9cb2-97ad020f27f5", "metadata": {}, "outputs": [], "source": [ "# num_examples = len(geo_city_state_data)\n", "# hit = 0\n", "# match = 0\n", "# missing_locations = set()\n", "# for index, city_data in enumerate(geo_city_state_data, start=1):\n", "#     inputs, outputs = compute_model_inputs_and_outputs(session, tokenizer, city_data)\n", "#     location = detect_location(inputs, outputs, tokenizer)\n", "#     print(f\"{index}. {city_data} -> {location}, : {city_data == location}\")\n", "#     if location:\n", "#         hit += 1\n", "#     if city_data == location:\n", "#         match += 1\n", "#     else:\n", "#         missing_locations.add(city_data)\n", "\n", "# print()\n", "# print(f\"Number of examples = {num_examples}\")\n", "# print(f\"#hits = {hit}; hit rate = {hit/num_examples}\")\n", "# print(f\"#matches = {match}; match rate = {match/num_examples}\")" ] },
{ "cell_type": "code", "execution_count": null, "id": "08ecb315-3896-4a7e-8c03-37e3ecb1fa9a", "metadata": {}, "outputs": [], "source": [ "## With Xenova/bert-base-NER\n", "# Number of examples = 349\n", "# #hits = 135; hit rate = 0.3868194842406877\n", "\n", "## After finetuning Mozilla/distilbert-NER-LoRA\n", "# #hits = 220; hit rate = 0.6303724928366762\n", "\n", "## After finetuning chidamnat2002/distilbert-uncased-NER-LoRA\n", "# #hits = 207; hit rate = 0.5931232091690545\n", "\n", "## After finetuning Mozilla/distilbert-uncased-NER-LoRA\n", "# #hits = 252; hit rate = 0.7220630372492837" ] },
{ "cell_type": "code", "execution_count": null, "id": "1eed2554-784c-4f49-aad5-72b795f19295", "metadata": {}, "outputs": [], "source": [ "len(missing_locations)" ] },
{ "cell_type": "code", "execution_count": null, "id": "feaed0b3-5fb8-4686-b57a-3a8d9764ec79", "metadata": { "scrolled": true }, "outputs": [], "source": [ "print(missing_locations)" ] },
{ "cell_type": "code", "execution_count": null, "id": "d04d5258-16b4-4773-b585-b5f31db3926c", "metadata": {}, "outputs": [], "source": [] },
{ "cell_type": "markdown", "id": "ef09b219-dd01-4d66-92e2-c438935e8654", "metadata": {}, "source": [ "#### Looking into the CoNLL-2003 dataset" ] },
{ "cell_type": "code", "execution_count": null, "id": "4233afed-374f-4f2f-baaa-078447959367", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset, Dataset\n", "\n", "# Load the CoNLL-2003 dataset\n", "dataset = load_dataset(\"conll2003\")" ] },
{ "cell_type": "code", "execution_count": null, "id": "14216057-228f-467a-aa8e-02108d56cb92", "metadata": {}, "outputs": [], "source": [ "dataset['train'].to_pandas()" ] },
{ "cell_type": "code", "execution_count": null, "id": "e259586a-f67b-42b2-9665-a571da352f57", "metadata": {}, "outputs": [], "source": [ "dataset['train']" ] },
{ "cell_type": "code", "execution_count": null, "id": "b360becd-e584-4908-8b0b-c27291c5552a", "metadata": {}, "outputs": [], "source": [ "label_map = {\n", "    0: \"O\",       # Outside any named entity\n", "    1: \"B-PER\",   # Beginning of a person entity\n", "    2: \"I-PER\",   # Inside a person entity\n", "    3: \"B-ORG\",   # Beginning of an organization entity\n", "    4: \"I-ORG\",   # Inside an organization entity\n", "    5: \"B-LOC\",   # Beginning of a location entity\n", "    6: \"I-LOC\",   # Inside a location entity\n", "    7: \"B-MISC\",  # Beginning of a miscellaneous entity\n", "    8: \"I-MISC\"   # Inside a miscellaneous entity\n", "}" ] },
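{ "cell_type": "markdown", "id": "e0f1a2b3-4c5d-4e6f-9a8b-9c0d1e2f3a4b", "metadata": {}, "source": [ "The integer tags can be decoded either through this map or through the dataset's own ClassLabel feature; a quick cross-check on the first training example:" ] },
{ "cell_type": "code", "execution_count": null, "id": "f1a2b3c4-5d6e-4f7a-8b9c-0d1e2f3a4b5c", "metadata": {}, "outputs": [], "source": [ "# The dataset carries its own label names; they should line up with label_map\n", "print(dataset['train'].features['ner_tags'].feature.names)\n", "example0 = dataset['train'][0]\n", "print(list(zip(example0['tokens'], [label_map[t] for t in example0['ner_tags']])))" ] },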
{ "cell_type": "code", "execution_count": null, "id": "9b7191c6-db07-4c51-98a4-9a408d988092", "metadata": {}, "outputs": [], "source": [ "import random\n", "import pandas as pd\n", "from collections import Counter\n", "\n", "# Locations to fill into the templates\n", "# cities = list(missing_locations)\n", "# cities = dma_data[:]\n", "cities = geo_city_state_data[:]\n", "NUM_EXAMPLES = 50000\n", "\n", "# Sentence templates: generic sentences plus weather/local-search style queries\n", "templates = [\n", "    \"John visited {} last summer.\",\n", "    \"The headquarters is located in {}.\",\n", "    \"My cousin moved to {} recently.\",\n", "    \"{} is famous for its historical landmarks.\",\n", "    \"A new park was opened in {}.\",\n", "    \"The festival in {} was a great success.\",\n", "    \"I am planning a trip to {} next month.\",\n", "    \"The weather in {} has been wonderful this year.\",\n", "    \"{} is known for its beautiful scenery.\",\n", "    \"{} is home to several tech companies.\",\n", "    \"weather {}\",\n", "    'The weather in {}',\n", "    'What is the weather in {}',\n", "    \"What's the weather in {}\",\n", "    'Weather forecast in {}',\n", "    '{} weather',\n", "    'temperature {}',\n", "    '{} temperature',\n", "    'What are the best restaurants in {}',\n", "    'Top-rated restaurants in {}',\n", "    'Popular coffee shops in {}',\n", "    'Best pizza places in {}',\n", "    'Best sushi places in {}',\n", "    'Cheap restaurants in {}',\n", "    'Best places to eat in {}',\n", "    'Restaurants near me in {}',\n", "    '{} restaurants',\n", "    '{} hotels',\n", "    '{} food',\n", "]\n", "\n", "print(f\"Size of templates = {len(templates)}\")\n", "\n", "# Build tags for a tokenized sentence: 5 = B-LOC, 6 = I-LOC, 0 = O\n", "def create_ner_tags(tokens, city):\n", "    # Compare with punctuation stripped so tokens with attached periods or\n", "    # commas (e.g. \"Boston.\" from \"... located in {}.\") still match\n", "    city_tokens = [t.strip('.,!?') for t in city.split()]\n", "    ner_tags = []\n", "    for token in tokens:\n", "        bare = token.strip('.,!?')\n", "        if bare in city_tokens:\n", "            # B-LOC for the first token of the city name, I-LOC for the rest\n", "            ner_tags.append(5 if city_tokens.index(bare) == 0 else 6)\n", "        else:\n", "            ner_tags.append(0)  # O tag for non-entity words\n", "    return ner_tags\n", "\n", "# Generate NER examples with ids, tokens, and ner_tags\n", "ner_examples = []\n", "queries_set = set()\n", "pattern_counter = Counter()\n", "lower_case_prob = 0.4\n", "i = 0\n", "while i < NUM_EXAMPLES:\n", "    if i % 1000 == 0:\n", "        print(f\"completed {i} examples\")\n", "    city = random.choice(cities)\n", "    if random.random() < lower_case_prob:\n", "        city = city.lower()\n", "    template = random.choice(templates)\n", "    sentence = template.format(city)\n", "    # Skip duplicate sentences and cap any single template at ~1/6 of the data\n", "    if sentence in queries_set:\n", "        continue\n", "    if pattern_counter.get(template, 0) > NUM_EXAMPLES // 6:\n", "        continue\n", "    queries_set.add(sentence)\n", "    pattern_counter.update([template])\n", "    tokens = sentence.split()\n", "    ner_tags = create_ner_tags(tokens, city)\n", "\n", "    # Append the example in the format {'id', 'tokens', 'ner_tags'}\n", "    ner_examples.append({\n", "        'id': str(i),\n", "        'tokens': tokens,\n", "        'ner_tags': ner_tags\n", "    })\n", "    i += 1\n", "\n", "# Convert the examples into a pandas DataFrame\n", "df_ner_examples = pd.DataFrame(ner_examples)\n", "df_ner_examples" ] },
{ "cell_type": "code", "execution_count": null, "id": "12e91919-6dc4-4ad3-a388-e5b90d4efa79", "metadata": {}, "outputs": [], "source": [ "synthetic_loc_dataset = Dataset.from_pandas(df_ner_examples)\n", "synthetic_loc_dataset" ] },
{ "cell_type": "code", "execution_count": null, "id": "0d91ba34-cb67-418a-8a4e-4b442b144be6", "metadata": {}, "outputs": [], "source": [ "synthetic_loc_dataset[0]" ] },
{ "cell_type": "code", "execution_count": null, "id": "496a76a7-3329-4849-affa-63166d427183", "metadata": {}, "outputs": [], "source": [ "# loc_dataset = dataset['train'].filter(lambda example: 5 in example['ner_tags'])\n", "loc_dataset = dataset['train']\n", "# Keep only the 'id', 'tokens', and 'ner_tags' columns\n", "loc_dataset_filtered = loc_dataset.remove_columns(['pos_tags', 'chunk_tags'])\n", "\n", "loc_dataset_filtered[0]" ] },
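{ "cell_type": "markdown", "id": "a2b3c4d5-6e7f-4a8b-9c0d-1e2f3a4b5c6d", "metadata": {}, "source": [ "Before concatenating, it is worth comparing the two feature schemas: the CoNLL split types ner_tags as a sequence of ClassLabel while the synthetic set has plain integers, which is exactly why the cast below is needed:" ] },
{ "cell_type": "code", "execution_count": null, "id": "b3c4d5e6-7f8a-4b9c-8d0e-2f3a4b5c6d7e", "metadata": {}, "outputs": [], "source": [ "# Sequence(ClassLabel(...)) vs a bare integer sequence\n", "print(loc_dataset_filtered.features['ner_tags'])\n", "print(synthetic_loc_dataset.features['ner_tags'])" ] },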
"id": "42652aaf-399f-413f-a8f6-e082f1057e3f", "metadata": {}, "outputs": [], "source": [ "loc_dataset_filtered[-1]" ] }, { "cell_type": "code", "execution_count": null, "id": "c47584e0-0612-400b-81e9-212a61209b94", "metadata": {}, "outputs": [], "source": [ "from datasets import concatenate_datasets\n", "\n", "from datasets import Sequence, ClassLabel, Value\n", "\n", "# Step 1: Get the full feature schema from synthetic_loc_dataset\n", "features = synthetic_loc_dataset.features\n", "\n", "# Step 2: Update the 'ner_tags' feature to use ClassLabel from loc_dataset_filtered\n", "features['ner_tags'] = Sequence(feature=ClassLabel(names=loc_dataset_filtered.features['ner_tags'].feature.names))\n", "\n", "# Step 3: Cast synthetic_loc_dataset to the updated feature schema\n", "synthetic_loc_dataset = synthetic_loc_dataset.cast(features)\n", "\n", "# Check the updated features to confirm\n", "print(synthetic_loc_dataset.features)\n", "\n", "# Now concatenate the datasets\n", "combined_dataset = concatenate_datasets([loc_dataset_filtered, synthetic_loc_dataset])\n", "\n", "# Verify the combined dataset\n", "print(combined_dataset[0])\n" ] }, { "cell_type": "code", "execution_count": null, "id": "6138a427-f03b-4355-bdac-ffec783f5a2b", "metadata": {}, "outputs": [], "source": [ "len(combined_dataset)" ] }, { "cell_type": "code", "execution_count": null, "id": "caac8e36-6d1c-4a42-8acd-7e81f816fa9b", "metadata": {}, "outputs": [], "source": [ "combined_dataset[3]" ] }, { "cell_type": "code", "execution_count": null, "id": "2aa98e69-bf5f-4bcc-b387-2abdc60a99be", "metadata": {}, "outputs": [], "source": [ "combined_dataset = combined_dataset.map(\n", " lambda example, idx: {'id': idx}, # Assign running count as the new 'id'\n", " with_indices=True # Ensures we get an index for each example\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "5906e294-6a1b-436d-a229-628f99190887", "metadata": {}, "outputs": [], "source": [ "combined_dataset.to_pandas()" ] }, { "cell_type": "code", "execution_count": null, "id": "46c0d423-3b8c-47ed-a8ae-a3316cd78bd0", "metadata": {}, "outputs": [], "source": [ "combined_dataset[-1]" ] }, { "cell_type": "code", "execution_count": null, "id": "c35b1a0b-303c-4eee-bc31-770872c212e5", "metadata": {}, "outputs": [], "source": [ "combined_dataset.to_parquet(\"../data/combined_ner_examples_v3.parquet\")" ] }, { "cell_type": "code", "execution_count": null, "id": "d33bb9a1-bd49-49cd-aa90-5428d46fbad7", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer, AutoModelForTokenClassification\n", "from transformers import pipeline\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"Mozilla/distilbert-uncased-NER-LoRA\")\n", "model = AutoModelForTokenClassification.from_pretrained(\"Mozilla/distilbert-uncased-NER-LoRA\")\n", "\n", "nlp = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n", "example = \"New York\"\n", "\n", "ner_results = nlp(example)\n", "print(ner_results)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "27929164-3156-4ddf-b878-26d628daeace", "metadata": {}, "outputs": [], "source": [ "len(ner_examples)" ] }, { "cell_type": "code", "execution_count": null, "id": "5eb84c94-a94e-4f6c-976c-458b2d9a1a0d", "metadata": {}, "outputs": [], "source": [ "example = ' '.join( ner_examples[1]['tokens'])\n", "example" ] }, { "cell_type": "code", "execution_count": null, "id": "6c61e8bc-92f4-452d-b76d-ec2ed35b3963", "metadata": {}, "outputs": [], "source": [ "sample_inputs = tokenizer(example, return_tensors=\"np\", 
{ "cell_type": "code", "execution_count": null, "id": "d33bb9a1-bd49-49cd-aa90-5428d46fbad7", "metadata": {}, "outputs": [], "source": [ "from transformers import AutoTokenizer, AutoModelForTokenClassification\n", "from transformers import pipeline\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"Mozilla/distilbert-uncased-NER-LoRA\")\n", "model = AutoModelForTokenClassification.from_pretrained(\"Mozilla/distilbert-uncased-NER-LoRA\")\n", "\n", "nlp = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n", "example = \"New York\"\n", "\n", "ner_results = nlp(example)\n", "print(ner_results)\n" ] },
{ "cell_type": "code", "execution_count": null, "id": "27929164-3156-4ddf-b878-26d628daeace", "metadata": {}, "outputs": [], "source": [ "len(ner_examples)" ] },
{ "cell_type": "code", "execution_count": null, "id": "5eb84c94-a94e-4f6c-976c-458b2d9a1a0d", "metadata": {}, "outputs": [], "source": [ "example = ' '.join(ner_examples[1]['tokens'])\n", "example" ] },
{ "cell_type": "code", "execution_count": null, "id": "6c61e8bc-92f4-452d-b76d-ec2ed35b3963", "metadata": {}, "outputs": [], "source": [ "sample_inputs = tokenizer(example, return_tensors=\"np\", truncation=True, padding='max_length', max_length=64)\n", "sample_inputs['input_ids']" ] },
{ "cell_type": "code", "execution_count": null, "id": "3d3492e4-8783-43d2-b4ac-2b8d652f1324", "metadata": {}, "outputs": [], "source": [ "tokenizer.decode(tokenizer(example, return_tensors=\"np\", truncation=True, padding='max_length', max_length=64)['input_ids'][0])" ] },
{ "cell_type": "code", "execution_count": null, "id": "8d121130-81b3-48b4-ba70-7f3926d17ac6", "metadata": {}, "outputs": [], "source": [ "# Vocabulary id of the wordpiece 'land'\n", "tokenizer.vocab['land']" ] },
{ "cell_type": "code", "execution_count": null, "id": "0fac046c-a5f6-471c-9c41-1aac0dab439d", "metadata": {}, "outputs": [], "source": [ "df_ner_examples" ] },
{ "cell_type": "code", "execution_count": null, "id": "32524933-23f7-41ae-8597-da0300e6ac60", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 5 }