{
"cells": [
{
"cell_type": "markdown",
"id": "73d1863d-1d54-4cdd-843c-c033b28f15f6",
"metadata": {},
"source": [
"Explore whether the weather keywords and locations are captured correctly"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd4805cc-8d46-40fa-8d39-35158d9212d4",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import requests\n",
"from bs4 import BeautifulSoup\n",
"import re"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b64db933-17ab-47cc-b0ba-ae37e89e450a",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import random"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d70bc639-b4de-4544-bd0f-f18a0b263a66",
"metadata": {},
"outputs": [],
"source": [
"url = \"https://en.m.wikipedia.org/wiki/List_of_television_stations_in_North_America_by_media_market\"\n",
"response = requests.get(url)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a1c49059-f982-46a5-a871-aeb2ec2a6688",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"if response.status_code == 200:\n",
" soup = BeautifulSoup(response.content, 'html.parser')\n",
" dma_heading = soup.find('h4', string='DMAs')\n",
" dma_list = dma_heading.find_next('ul')\n",
" \n",
" dma_data = []\n",
" if dma_list:\n",
" for li in dma_list.find_all('li'):\n",
" market_name = li.get_text(strip=True)\n",
"\n",
" # Split by dash (-) or en-dash (–) to handle cases like \"Dallas-Fort Worth\"\n",
" split_names = re.split(r'–|-', market_name)\n",
"\n",
" # Process each split name\n",
" for name in split_names:\n",
" # Remove the (#NUM) part using regex\n",
" name = re.sub(r'\\s*\\(#\\d+\\)', '', name).strip()\n",
"\n",
" # Check if there's a city in parentheses and split them\n",
" match = re.match(r'(.+?)\\s*\\((.+?)\\)', name)\n",
" if match:\n",
" main_city = match.group(1).strip()\n",
" parenthetical_city = match.group(2).strip()\n",
" dma_data.append(main_city) # Add the main city\n",
" dma_data.append(parenthetical_city) # Add the city in parentheses\n",
" else:\n",
" dma_data.append(name) \n",
"\n",
" for index, dma in enumerate(dma_data, start=1):\n",
" print(f\"{index}. {dma}\")\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "52f9543e-4b78-46f1-828f-8f49340a4be0",
"metadata": {},
"outputs": [],
"source": [
"dma_data[:5]"
]
},
{
"cell_type": "markdown",
"id": "8bcf91d7-8344-4b5e-9641-461b2630cb0f",
"metadata": {},
"source": [
"#### Read the data/geonames-cities-states.json"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "738661a5-668f-4b2c-8823-dc3c0c92be94",
"metadata": {},
"outputs": [],
"source": [
"import json \n",
"\n",
"def get_geonames_city_state_data():\n",
" geonames_file = \"../data/geonames-cities-states.json\"\n",
" with open(geonames_file, 'r') as f:\n",
" geonames_dict = json.load(f)\n",
" \n",
" \n",
" cities_data = pd.DataFrame(geonames_dict['cities'])\\\n",
" .rename(columns={'admin1_code': 'state_code', 'name': 'city_name', 'population': 'city_popln'})\n",
" cities_data = cities_data[['id', 'state_code', 'city_name', 'city_popln', 'alternate_names']]\n",
" states_data = pd.DataFrame(geonames_dict['states_by_abbr'].values())\\\n",
" .rename(columns={'admin1_code': 'state_code', 'name': 'state_name'})\n",
" states_data = states_data[['state_code', 'state_name']]\n",
" city_states_data = cities_data.merge(states_data, how='left', on='state_code')\n",
" city_states_data['city_weight'] = city_states_data['city_popln'] / city_states_data['city_popln'].sum()\n",
" return city_states_data\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a3aeb4bd-2e84-4121-84b7-8ffb1118ca37",
"metadata": {},
"outputs": [],
"source": [
"city_states_data = get_geonames_city_state_data()\n",
"print(len(city_states_data))\n",
"city_states_data"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d35076ae-1d45-4699-8257-e98612500e43",
"metadata": {},
"outputs": [],
"source": [
"city_states_data.sort_values('city_weight', ascending=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "df043822-779c-4f9c-89eb-b331e2b0de19",
"metadata": {},
"outputs": [],
"source": [
"# useful for post processing to standardize the city names\n",
"def build_lookup(dataframe):\n",
" # Initialize an empty dictionary for the lookup\n",
" lookup = {}\n",
" \n",
" # Iterate over each row in the DataFrame\n",
" for index, row in dataframe.iterrows():\n",
" city_name = row['city_name']\n",
" alternate_names = row['alternate_names']\n",
" \n",
" # Iterate over the list of alternate names and map them to the city_name\n",
" for alt_name in alternate_names:\n",
" lookup[alt_name.lower()] = city_name # Convert alternate names to lowercase for consistency\n",
" \n",
" return lookup\n",
"\n",
"city_alternate_to_city_lkp = build_lookup(city_states_data)"
]
},
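{
"cell_type": "markdown",
"id": "alt-name-lookup-usage-note",
"metadata": {},
"source": [
"The lookup above maps lowercased alternate spellings to the canonical GeoNames city name, which is useful for post-processing model output. Below is a minimal sketch of how a detected location string could be normalized with `city_alternate_to_city_lkp`; the sample strings are hypothetical."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "alt-name-lookup-usage-demo",
"metadata": {},
"outputs": [],
"source": [
"# Minimal sketch: normalize a detected location via the alternate-name lookup.\n",
"# Keys are lowercased alternate names, values are canonical GeoNames city names;\n",
"# the sample inputs below are hypothetical.\n",
"def normalize_city(detected, lookup=city_alternate_to_city_lkp):\n",
"    if detected is None:\n",
"        return None\n",
"    key = detected.strip().lower()\n",
"    # Fall back to the original string when no alternate name matches\n",
"    return lookup.get(key, detected)\n",
"\n",
"for sample in [\"nyc\", \"new york city\", \"Boston\"]:\n",
"    print(sample, \"->\", normalize_city(sample))"
]
},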
{
"cell_type": "code",
"execution_count": null,
"id": "62a392e3-e18e-470f-9f95-ad35ebaebca8",
"metadata": {},
"outputs": [],
"source": [
"len(city_alternate_to_city_lkp)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1f9d1453-8955-4dec-b0e2-4e8d29a82046",
"metadata": {},
"outputs": [],
"source": [
"city_states_data['alternate_names'].apply(len).value_counts()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "dc9cc8fe-95ad-45e1-8bd8-134faf7aa37d",
"metadata": {},
"outputs": [],
"source": [
"np.random.seed(42)\n",
"\n",
"def get_alternate_or_actual_name(row):\n",
" if row['alternate_names'] and isinstance(row['alternate_names'], list):\n",
" return random.choice(row['alternate_names'])\n",
" return row['city_name']\n",
"\n",
"def combine_city_with_states(row):\n",
" if row['state_code'] is not None:\n",
" # return row['city'] + \", \" + row['state_code']\n",
" return row['city'] + \", \" + random.choice([row['state_code'], row['state_name']])\n",
" return row['city']\n",
" \n",
"def sample_location(df, n_examples=10000, state_ratio=0.5):\n",
" weights = df['city_weight']\n",
" samples = df[['id', 'city_name', 'alternate_names', 'state_code', 'state_name', 'city_popln']].sample(n=n_examples, weights=weights, replace=True)\n",
" states_idx = np.random.random(n_examples) <= state_ratio\n",
" samples.loc[states_idx, 'state_code'] = None\n",
" random_alternate_name = samples.apply(get_alternate_or_actual_name, axis=1)\n",
" samples['city'] = random_alternate_name\n",
" samples['location'] = samples.apply(combine_city_with_states, axis=1)\n",
" return samples"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04052587-fcb9-41bc-8533-7d08b9f689e4",
"metadata": {},
"outputs": [],
"source": [
"sample_df = sample_location(city_states_data, n_examples=100000, state_ratio=0.5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "57464006-4fae-44c6-907d-48b1b03fdb80",
"metadata": {},
"outputs": [],
"source": [
"sample_df"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "53af30a6-c612-4567-9377-c6fae129dfe6",
"metadata": {},
"outputs": [],
"source": [
"sample_df.loc[sample_df['location'] == 'san']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a4ea6b52-8bfa-4f07-84e3-072a73988f5a",
"metadata": {},
"outputs": [],
"source": [
"sample_df['location'].value_counts()[:60]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eea0b802-764e-4b94-a817-d14be1f5c661",
"metadata": {},
"outputs": [],
"source": [
"geo_city_state_data = sample_df['location'].values.tolist()\n",
"print(len(geo_city_state_data))\n",
"geo_city_state_data[:10]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c80504f1-4a31-4cd2-85c3-96da175074ba",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "85bdeff1-a3f2-443e-a31b-d80e836c6ebe",
"metadata": {},
"outputs": [],
"source": [
"# !python -m pip install onnxruntime"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "689e6844-2a90-4b7a-a9a5-bb298dce2b70",
"metadata": {},
"outputs": [],
"source": [
"# !python -m pip freeze| grep onnxruntime"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "fc61067c-6e8a-499a-9d08-07fb4fb0eb2f",
"metadata": {},
"outputs": [],
"source": [
"# !mkdir ../models"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "74bca5a8-0bb0-46c1-8429-598e172f34af",
"metadata": {},
"outputs": [],
"source": [
"import onnxruntime as ort\n",
"import numpy as np\n",
"from transformers import AutoTokenizer, BertTokenizer\n",
"\n",
"# Download the ONNX model\n",
"# model_url = \"https://huggingface.co/Xenova/bert-base-NER/resolve/main/onnx/model_quantized.onnx\"\n",
"# model_url = \"https://huggingface.co/Mozilla/distilbert-NER-LoRA/resolve/main/onnx/model_quantized.onnx\"\n",
"# model_url = \"https://huggingface.co/Mozilla/distilbert-uncased-NER-LoRA/resolve/main/onnx/model_quantized.onnx\"\n",
"model_url = \"https://huggingface.co/chidamnat2002/distilbert-uncased-NER-LoRA/resolve/main/onnx/model_quantized.onnx\"\n",
"# model_path = \"../models/distilbert-NER-LoRA.onnx\"\n",
"model_path = \"../models/distilbert-uncased-NER-LoRA.onnx\"\n",
"\n",
"# Download the ONNX model if not already present\n",
"response = requests.get(model_url)\n",
"with open(model_path, 'wb') as f:\n",
" f.write(response.content)\n",
"\n",
"# Load the ONNX model using ONNX Runtime\n",
"session = ort.InferenceSession(model_path)\n",
"\n",
"# Load the tokenizer (assuming it's based on BERT)\n",
"# tokenizer = BertTokenizer.from_pretrained(\"Mozilla/distilbert-NER-LoRA\")\n",
"# tokenizer = AutoTokenizer.from_pretrained(\"Mozilla/distilbert-uncased-NER-LoRA\")\n",
"tokenizer = AutoTokenizer.from_pretrained(\"chidamnat2002/distilbert-uncased-NER-LoRA\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "838001d1-a252-4a4f-bfab-8c7698b7c79b",
"metadata": {},
"outputs": [],
"source": [
"def compute_model_inputs_and_outputs(session, tokenizer, query):\n",
" # Tokenize the input\n",
" # inputs = tokenizer(query, return_tensors=\"np\", truncation=True, padding=True)\n",
" inputs = tokenizer(query, return_tensors=\"np\", truncation=True, padding='max_length', max_length=64)\n",
" # is_split_into_words=True,\n",
" # truncation=True,\n",
" # padding='max_length',\n",
" # max_length=64\n",
" \n",
" # The ONNX model expects 'input_ids', 'attention_mask', and 'token_type_ids'\n",
" # Convert all necessary inputs to numpy arrays and prepare the input feed\n",
" input_feed = {\n",
" 'input_ids': inputs['input_ids'].astype(np.int64),\n",
" 'attention_mask': inputs['attention_mask'].astype(np.int64),\n",
" # 'token_type_ids': inputs['token_type_ids'].astype(np.int64) # Some models might not need this; check if it's really required\n",
" }\n",
" \n",
" # Run inference with the ONNX model\n",
" outputs = session.run(None, input_feed)\n",
" # print(outputs)\n",
" return inputs, outputs\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "638ac070-a689-4c06-a47b-0b9a21eb1373",
"metadata": {},
"outputs": [],
"source": [
"def detect_location(inputs, outputs, tokenizer):\n",
" # print(\"Shape of outputs:\", [o.shape for o in outputs])\n",
"\n",
" # Post-process the output (this will depend on the model's output structure)\n",
" logits = outputs[0] # Assuming the model output is logits\n",
" probabilities = np.exp(logits) / np.sum(np.exp(logits), axis=-1, keepdims=True)\n",
" \n",
" predicted_ids = np.argmax(logits, axis=-1)\n",
" predicted_probs = np.max(probabilities, axis=-1)\n",
" \n",
" # Define the threshold for NER probability\n",
" threshold = 0.5\n",
" \n",
" label_map = {\n",
" 0: \"O\", # Outside any named entity\n",
" 1: \"B-PER\", # Beginning of a person entity\n",
" 2: \"I-PER\", # Inside a person entity\n",
" 3: \"B-ORG\", # Beginning of an organization entity\n",
" 4: \"I-ORG\", # Inside an organization entity\n",
" 5: \"B-LOC\", # Beginning of a location entity\n",
" 6: \"I-LOC\", # Inside a location entity\n",
" 7: \"B-MISC\", # Beginning of a miscellaneous entity (for example)\n",
" 8: \"I-MISC\" # Inside a miscellaneous entity (for example)\n",
" }\n",
" \n",
" tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])\n",
" \n",
" # List to hold the detected location terms\n",
" location_entities = []\n",
" current_location = []\n",
" \n",
" # Loop through each token and its predicted label and probability\n",
" for i, (token, predicted_id, prob) in enumerate(zip(tokens, predicted_ids[0], predicted_probs[0])):\n",
" # for i, (token, predicted_id, prob) in enumerate(zip(tokens, predicted_ids.flatten(), predicted_probs.flatten())):\n",
" label = label_map[predicted_id]\n",
"\n",
" # Ignore special tokens like [CLS], [SEP]\n",
" if token in [\"[CLS]\", \"[SEP]\", \"[PAD]\"]:\n",
" continue\n",
" \n",
" # Only consider tokens with probability above the threshold\n",
" if prob > threshold:\n",
" # If the token is a part of a location entity (B-LOC or I-LOC)\n",
" if label in [\"B-LOC\", \"I-LOC\"]:\n",
" if label == \"B-LOC\":\n",
" # If we encounter a B-LOC, we may need to store the previous location\n",
" if current_location:\n",
" location_entities.append(\" \".join(current_location).replace(\"##\", \"\"))\n",
" # Start a new location entity\n",
" current_location = [token]\n",
" elif label == \"I-LOC\" and current_location:\n",
" # Continue appending to the current location entity\n",
" current_location.append(token)\n",
" else:\n",
" # If we encounter a non-location entity, store the current location and reset\n",
" if current_location:\n",
" location_entities.append(\" \".join(current_location).replace(\"##\", \"\"))\n",
" current_location = []\n",
" \n",
" # Append the last location entity if it exists\n",
" if current_location:\n",
" location_entities.append(\" \".join(current_location).replace(\"##\", \"\"))\n",
"\n",
" # Return the detected location terms\n",
" return location_entities[0] if location_entities != [] else None\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "827c23a5-367d-4f9b-80f5-471cd3d4f40b",
"metadata": {},
"outputs": [],
"source": [
"# query = \"restaurants in Philadelphia\"\n",
"query = \"weather Boston\"\n",
"# query = \"Boston weather\"\n",
"inputs, outputs = compute_model_inputs_and_outputs(session, tokenizer, query)\n",
"detect_location(inputs, outputs, tokenizer)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7d537628-8255-4a5b-901b-8338060d8c25",
"metadata": {},
"outputs": [],
"source": [
"# inputs\n",
"outputs[0].shape"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a78ffff9-6bd0-4032-9679-4c20d902a56d",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"id": "cea8d2c3-aaab-48fb-86f1-8445e667af6d",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"num_examples = len(dma_data)\n",
"hit = 0\n",
"match = 0\n",
"missing_locations = set()\n",
"for index, dma in enumerate(dma_data, start=1):\n",
" # location = detect_location(session, tokenizer, dma)\n",
" inputs, outputs = compute_model_inputs_and_outputs(session, tokenizer, dma)\n",
" location = detect_location(inputs, outputs, tokenizer)\n",
" print(f\"{index}. {dma} -> {location}, : {dma.lower() == location}\")\n",
" if location:\n",
" hit += 1\n",
" if dma.lower() == location:\n",
" match += 1\n",
" else:\n",
" missing_locations.add(dma)\n",
"\n",
"print()\n",
"print(f\"Number of examples = {num_examples}\")\n",
"print(f\"#hits = {hit}; #hit rate = {hit/num_examples}\")\n",
"print(f\"#matches = {match}; #match rate = {match/num_examples}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bfa7d5a6-c13a-48e1-9cb2-97ad020f27f5",
"metadata": {},
"outputs": [],
"source": [
"# num_examples = len(geo_city_state_data)\n",
"# hit = 0\n",
"# match = 0\n",
"# missing_locations = set()\n",
"# for index, city_data in enumerate(geo_city_state_data, start=1):\n",
"# # location = detect_location(session, tokenizer, city_data)\n",
"# inputs, outputs = compute_model_inputs_and_outputs(session, tokenizer, city_data)\n",
"# location = detect_location(inputs, outputs, tokenizer)\n",
"# print(f\"{index}. {city_data} -> {location}, : {city_data == location}\")\n",
"# if location:\n",
"# hit += 1\n",
"# if city_data == location:\n",
"# match += 1\n",
"# else:\n",
"# missing_locations.add(city_data)\n",
"\n",
"# print()\n",
"# print(f\"Number of examples = {num_examples}\")\n",
"# print(f\"#hits = {hit}; #hit rate = {hit/num_examples}\")\n",
"# print(f\"#matches = {match}; #match rate = {match/num_examples}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "08ecb315-3896-4a7e-8c03-37e3ecb1fa9a",
"metadata": {},
"outputs": [],
"source": [
"## With Xenova/bert-base-NER\n",
"# Number of examples = 349\n",
"# #hits = 135; #hit rate = 0.3868194842406877\n",
"\n",
"## After finetuning the Mozilla/distilbert-NER-LoRA\n",
"#hits = 220; #hit rate = 0.6303724928366762\n",
"\n",
"## After finetuning the chidamnat2002/distilbert-uncased-NER-LoRA\n",
"#hits = 207; #hit rate = 0.5931232091690545\n",
"\n",
"## After finetuning the Mozilla/distilbert-uncased-NER-LoRA\n",
"#hits = 252; #hit rate = 0.7220630372492837"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1eed2554-784c-4f49-aad5-72b795f19295",
"metadata": {},
"outputs": [],
"source": [
"len(missing_locations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "feaed0b3-5fb8-4686-b57a-3a8d9764ec79",
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"print(missing_locations)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d04d5258-16b4-4773-b585-b5f31db3926c",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "markdown",
"id": "ef09b219-dd01-4d66-92e2-c438935e8654",
"metadata": {},
"source": [
"#### Looking into CONLL 2003 dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "4233afed-374f-4f2f-baaa-078447959367",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset, Dataset\n",
"import re\n",
"\n",
"# Load the CoNLL-2003 dataset\n",
"dataset = load_dataset(\"conll2003\")\n",
"\n",
"loc_examples = dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "14216057-228f-467a-aa8e-02108d56cb92",
"metadata": {},
"outputs": [],
"source": [
"dataset['train'].to_pandas()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e259586a-f67b-42b2-9665-a571da352f57",
"metadata": {},
"outputs": [],
"source": [
"dataset['train']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b360becd-e584-4908-8b0b-c27291c5552a",
"metadata": {},
"outputs": [],
"source": [
"label_map = {\n",
" 0: \"O\", # Outside any named entity\n",
" 1: \"B-PER\", # Beginning of a person entity\n",
" 2: \"I-PER\", # Inside a person entity\n",
" 3: \"B-ORG\", # Beginning of an organization entity\n",
" 4: \"I-ORG\", # Inside an organization entity\n",
" 5: \"B-LOC\", # Beginning of a location entity\n",
" 6: \"I-LOC\", # Inside a location entity\n",
" 7: \"B-MISC\", # Beginning of a miscellaneous entity (for example)\n",
" 8: \"I-MISC\" # Inside a miscellaneous entity (for example)\n",
" }"
]
},
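{
"cell_type": "markdown",
"id": "conll-tag-decoding-note",
"metadata": {},
"source": [
"As a quick sanity check on the tag scheme, the sketch below decodes the integer `ner_tags` of one CoNLL-2003 training example into their string labels using `label_map` (assuming the `dataset` loaded above)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "conll-tag-decoding-demo",
"metadata": {},
"outputs": [],
"source": [
"# Decode one CoNLL-2003 example: pair each token with its string NER label\n",
"example_row = dataset['train'][0]\n",
"for token, tag in zip(example_row['tokens'], example_row['ner_tags']):\n",
"    print(f\"{token}\\t{label_map[tag]}\")"
]
},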
{
"cell_type": "code",
"execution_count": null,
"id": "9b7191c6-db07-4c51-98a4-9a408d988092",
"metadata": {},
"outputs": [],
"source": [
"import random\n",
"import pandas as pd\n",
"from collections import Counter\n",
"\n",
"# List of sample cities\n",
"# cities = list(missing_locations)\n",
"# cities = dma_data[:]\n",
"cities = geo_city_state_data[:]\n",
"NUM_EXAMPLES = 50000\n",
"# Sample sentence templates\n",
"templates = [\n",
" \"John visited {} last summer.\",\n",
" \"The headquarters is located in {}.\",\n",
" \"My cousin moved to {} recently.\",\n",
" \"{} is famous for its historical landmarks.\",\n",
" \"A new park was opened in {}.\",\n",
" \"The festival in {} was a great success.\",\n",
" \"I am planning a trip to {} next month.\",\n",
" \"The weather in {} has been wonderful this year.\",\n",
" \"{} is known for its beautiful scenery.\",\n",
" \"{} is home to several tech companies.\",\n",
" # \"{} weather\",\n",
" \"weather {}\",\n",
"# # ]\n",
"# # addtional_weather_yelp_templates = [\n",
" 'The weather in {}',\n",
" 'What is the weather in {}',\n",
" \"What's the weather in {}\",\n",
" 'Weather forecast in {}',\n",
" '{} weather',\n",
" 'temperature {}',\n",
" '{} temperature',\n",
" 'What are the best restaurants in {}',\n",
" 'Top-rated restaurants in {}',\n",
" 'Popular coffee shops in {}',\n",
" 'Best pizza places in {}',\n",
" 'Best sushi places in {}',\n",
" 'Cheap restaurants in {}',\n",
" 'Best places to eat in {}',\n",
" 'Restaurants near me in {}',\n",
" '{} restaurants',\n",
" '{} hotels',\n",
" '{} food',\n",
"]\n",
"\n",
"print(f\"Size of templates = {len(templates)}\")\n",
"\n",
"# Function to create NER tags\n",
"def create_ner_tags(tokens, city):\n",
" ner_tags = []\n",
" for token in tokens:\n",
" if token in city.split():\n",
" # Assign B-LOC for the first token of the city, and I-LOC for the rest\n",
" ner_tag = 5 if city.split().index(token) == 0 else 6\n",
" ner_tags.append(ner_tag)\n",
" else:\n",
" ner_tags.append(0) # O tag for non-entity words\n",
" return ner_tags\n",
"\n",
"# Generate 10000 NER examples with IDs, tokens, and ner_tags\n",
"ner_examples = []\n",
"queries_set = set()\n",
"pattern_counter = Counter()\n",
"lower_case_prob = 0.4\n",
"i = 0\n",
"# for i in range(NUM_EXAMPLES):\n",
"while i < NUM_EXAMPLES:\n",
" if i % 1000 == 0:\n",
" print(f\"completed {i+1} examples\")\n",
" city = random.choice(cities)\n",
" if random.random() < lower_case_prob:\n",
" city = city.lower()\n",
" # if i%2 == 0:\n",
" # city = city.lower()\n",
" template = random.choice(templates)\n",
" sentence = template.format(city)\n",
" if sentence in queries_set:\n",
" continue\n",
" if pattern_counter.get(template, 0) > NUM_EXAMPLES//6:\n",
" continue\n",
" queries_set.add(sentence)\n",
" pattern_counter.update([template])\n",
" tokens = sentence.split()\n",
" ner_tags = create_ner_tags(tokens, city)\n",
" \n",
" # Append the example in the format of {'id', 'tokens', 'ner_tags'}\n",
" ner_examples.append({\n",
" 'id': str(i),\n",
" 'tokens': tokens,\n",
" 'ner_tags': ner_tags\n",
" })\n",
" i += 1\n",
"\n",
"\n",
"# Convert the examples into a pandas DataFrame\n",
"df_ner_examples = pd.DataFrame(ner_examples)\n",
"df_ner_examples"
]
},
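{
"cell_type": "markdown",
"id": "template-distribution-check-note",
"metadata": {},
"source": [
"Quick sanity check on the generation above: the per-template cap (`NUM_EXAMPLES//6`) should keep any single template from dominating. The cell below just inspects `pattern_counter`."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "template-distribution-check",
"metadata": {},
"outputs": [],
"source": [
"# Inspect how often each template was used; the cap above limits each one\n",
"print(f\"templates used: {len(pattern_counter)}\")\n",
"print(\"most used:\", pattern_counter.most_common(3))\n",
"print(\"least used:\", pattern_counter.most_common()[-3:])"
]
},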
{
"cell_type": "code",
"execution_count": null,
"id": "12e91919-6dc4-4ad3-a388-e5b90d4efa79",
"metadata": {},
"outputs": [],
"source": [
"synthetic_loc_dataset = Dataset.from_pandas(df_ner_examples)\n",
"synthetic_loc_dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0d91ba34-cb67-418a-8a4e-4b442b144be6",
"metadata": {},
"outputs": [],
"source": [
"synthetic_loc_dataset[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "496a76a7-3329-4849-affa-63166d427183",
"metadata": {},
"outputs": [],
"source": [
"# loc_dataset = dataset['train'].filter(lambda example: 5 in example['ner_tags'])\n",
"loc_dataset = dataset['train']\n",
"loc_dataset_filtered = loc_dataset.remove_columns(['pos_tags', 'chunk_tags'])\n",
"\n",
"# Set the format to ensure the order is 'id', 'tokens', and 'ner_tags'\n",
"loc_dataset_filtered[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "42652aaf-399f-413f-a8f6-e082f1057e3f",
"metadata": {},
"outputs": [],
"source": [
"loc_dataset_filtered[-1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c47584e0-0612-400b-81e9-212a61209b94",
"metadata": {},
"outputs": [],
"source": [
"from datasets import concatenate_datasets\n",
"\n",
"from datasets import Sequence, ClassLabel, Value\n",
"\n",
"# Step 1: Get the full feature schema from synthetic_loc_dataset\n",
"features = synthetic_loc_dataset.features\n",
"\n",
"# Step 2: Update the 'ner_tags' feature to use ClassLabel from loc_dataset_filtered\n",
"features['ner_tags'] = Sequence(feature=ClassLabel(names=loc_dataset_filtered.features['ner_tags'].feature.names))\n",
"\n",
"# Step 3: Cast synthetic_loc_dataset to the updated feature schema\n",
"synthetic_loc_dataset = synthetic_loc_dataset.cast(features)\n",
"\n",
"# Check the updated features to confirm\n",
"print(synthetic_loc_dataset.features)\n",
"\n",
"# Now concatenate the datasets\n",
"combined_dataset = concatenate_datasets([loc_dataset_filtered, synthetic_loc_dataset])\n",
"\n",
"# Verify the combined dataset\n",
"print(combined_dataset[0])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6138a427-f03b-4355-bdac-ffec783f5a2b",
"metadata": {},
"outputs": [],
"source": [
"len(combined_dataset)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "caac8e36-6d1c-4a42-8acd-7e81f816fa9b",
"metadata": {},
"outputs": [],
"source": [
"combined_dataset[3]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2aa98e69-bf5f-4bcc-b387-2abdc60a99be",
"metadata": {},
"outputs": [],
"source": [
"combined_dataset = combined_dataset.map(\n",
" lambda example, idx: {'id': idx}, # Assign running count as the new 'id'\n",
" with_indices=True # Ensures we get an index for each example\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5906e294-6a1b-436d-a229-628f99190887",
"metadata": {},
"outputs": [],
"source": [
"combined_dataset.to_pandas()"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "46c0d423-3b8c-47ed-a8ae-a3316cd78bd0",
"metadata": {},
"outputs": [],
"source": [
"combined_dataset[-1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c35b1a0b-303c-4eee-bc31-770872c212e5",
"metadata": {},
"outputs": [],
"source": [
"combined_dataset.to_parquet(\"../data/combined_ner_examples_v3.parquet\")"
]
},
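{
"cell_type": "markdown",
"id": "parquet-roundtrip-check-note",
"metadata": {},
"source": [
"Optional sanity check (a minimal sketch): reload the parquet file written above and confirm the row count and feature schema survive the round trip."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "parquet-roundtrip-check",
"metadata": {},
"outputs": [],
"source": [
"# Reload the combined dataset from the parquet file written above and compare sizes\n",
"reloaded_dataset = Dataset.from_parquet(\"../data/combined_ner_examples_v3.parquet\")\n",
"print(len(reloaded_dataset), len(combined_dataset))\n",
"print(reloaded_dataset.features)"
]
},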
{
"cell_type": "code",
"execution_count": null,
"id": "d33bb9a1-bd49-49cd-aa90-5428d46fbad7",
"metadata": {},
"outputs": [],
"source": [
"from transformers import AutoTokenizer, AutoModelForTokenClassification\n",
"from transformers import pipeline\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(\"Mozilla/distilbert-uncased-NER-LoRA\")\n",
"model = AutoModelForTokenClassification.from_pretrained(\"Mozilla/distilbert-uncased-NER-LoRA\")\n",
"\n",
"nlp = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n",
"example = \"New York\"\n",
"\n",
"ner_results = nlp(example)\n",
"print(ner_results)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "27929164-3156-4ddf-b878-26d628daeace",
"metadata": {},
"outputs": [],
"source": [
"len(ner_examples)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5eb84c94-a94e-4f6c-976c-458b2d9a1a0d",
"metadata": {},
"outputs": [],
"source": [
"example = ' '.join( ner_examples[1]['tokens'])\n",
"example"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6c61e8bc-92f4-452d-b76d-ec2ed35b3963",
"metadata": {},
"outputs": [],
"source": [
"sample_inputs = tokenizer(example, return_tensors=\"np\", truncation=True, padding='max_length', max_length=64)\n",
"sample_inputs['input_ids']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d3492e4-8783-43d2-b4ac-2b8d652f1324",
"metadata": {},
"outputs": [],
"source": [
"tokenizer.decode(tokenizer(example, return_tensors=\"np\", truncation=True, padding='max_length', max_length=64)['input_ids'][0])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8d121130-81b3-48b4-ba70-7f3926d17ac6",
"metadata": {},
"outputs": [],
"source": [
"tokenizer.vocab['land']"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0fac046c-a5f6-471c-9c41-1aac0dab439d",
"metadata": {},
"outputs": [],
"source": [
"df_ner_examples"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "32524933-23f7-41ae-8597-da0300e6ac60",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}