notebooks/city_state_exploration_and_dataprep.ipynb (2,700 lines of code) (raw):

{ "cells": [ { "cell_type": "markdown", "id": "73d1863d-1d54-4cdd-843c-c033b28f15f6", "metadata": {}, "source": [ "Explore whether the weather keywords and locations are captured correctly" ] }, { "cell_type": "code", "execution_count": 1, "id": "bd4805cc-8d46-40fa-8d39-35158d9212d4", "metadata": {}, "outputs": [], "source": [ "import pandas as pd\n", "import requests\n", "from bs4 import BeautifulSoup\n", "import re\n", "from datasets import load_dataset, Dataset" ] }, { "cell_type": "code", "execution_count": 2, "id": "b64db933-17ab-47cc-b0ba-ae37e89e450a", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import random\n", "from collections import Counter" ] }, { "cell_type": "markdown", "id": "8bcf91d7-8344-4b5e-9641-461b2630cb0f", "metadata": {}, "source": [ "#### Read the data/geonames-cities-states.json" ] }, { "cell_type": "code", "execution_count": 3, "id": "738661a5-668f-4b2c-8823-dc3c0c92be94", "metadata": {}, "outputs": [], "source": [ "import json \n", "\n", "def get_geonames_city_state_data():\n", " geonames_file = \"../data/geonames-cities-states.json\"\n", " with open(geonames_file, 'r') as f:\n", " geonames_dict = json.load(f)\n", " \n", " \n", " cities_data = pd.DataFrame(geonames_dict['cities'])\\\n", " .rename(columns={'admin1_code': 'state_code', 'name': 'city_name', 'population': 'city_popln'})\n", " cities_data = cities_data[['id', 'state_code', 'city_name', 'city_popln', 'alternate_names']]\n", " states_data = pd.DataFrame(geonames_dict['states_by_abbr'].values())\\\n", " .rename(columns={'admin1_code': 'state_code', 'name': 'state_name'})\n", " states_data = states_data[['state_code', 'state_name']]\n", " city_states_data = cities_data.merge(states_data, how='left', on='state_code')\n", " city_states_data['city_weight'] = city_states_data['city_popln'] / city_states_data['city_popln'].sum()\n", " return city_states_data\n", "\n" ] }, { "cell_type": "code", "execution_count": 4, "id": "a3aeb4bd-2e84-4121-84b7-8ffb1118ca37", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "962\n" ] }, { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>id</th>\n", " <th>state_code</th>\n", " <th>city_name</th>\n", " <th>city_popln</th>\n", " <th>alternate_names</th>\n", " <th>state_name</th>\n", " <th>city_weight</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>4049979</td>\n", " <td>AL</td>\n", " <td>Birmingham</td>\n", " <td>212461</td>\n", " <td>[birmingham, bhm]</td>\n", " <td>Alabama</td>\n", " <td>0.001409</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>4050552</td>\n", " <td>TN</td>\n", " <td>Cordova</td>\n", " <td>68779</td>\n", " <td>[cordova]</td>\n", " <td>Tennessee</td>\n", " <td>0.000456</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>4058553</td>\n", " <td>AL</td>\n", " <td>Decatur</td>\n", " <td>55437</td>\n", " <td>[decatur, dcu]</td>\n", " <td>Alabama</td>\n", " <td>0.000368</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>4059102</td>\n", " <td>AL</td>\n", " <td>Dothan</td>\n", " <td>68567</td>\n", " <td>[dhn, dothan]</td>\n", " <td>Alabama</td>\n", " <td>0.000455</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>4067994</td>\n", " <td>AL</td>\n", " <td>Hoover</td>\n", " <td>84848</td>\n", " <td>[hoover]</td>\n", " <td>Alabama</td>\n", " <td>0.000563</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>957</th>\n", " <td>11748973</td>\n", " <td>GA</td>\n", " <td>Stonecrest</td>\n", " <td>50000</td>\n", " <td>[stonecrest]</td>\n", " <td>Georgia</td>\n", " <td>0.000332</td>\n", " </tr>\n", " <tr>\n", " <th>958</th>\n", " <td>11838960</td>\n", " <td>CA</td>\n", " <td>Valley Glen</td>\n", " <td>60000</td>\n", " <td>[valley glen]</td>\n", " <td>California</td>\n", " <td>0.000398</td>\n", " </tr>\n", " <tr>\n", " <th>959</th>\n", " <td>11979227</td>\n", " <td>AZ</td>\n", " <td>Encanto</td>\n", " <td>54614</td>\n", " <td>[encanto]</td>\n", " <td>Arizona</td>\n", " <td>0.000362</td>\n", " </tr>\n", " <tr>\n", " <th>960</th>\n", " <td>11979238</td>\n", " <td>AZ</td>\n", " <td>Central City</td>\n", " <td>58161</td>\n", " <td>[central city]</td>\n", " <td>Arizona</td>\n", " <td>0.000386</td>\n", " </tr>\n", " <tr>\n", " <th>961</th>\n", " <td>12541728</td>\n", " <td>GA</td>\n", " <td>South Fulton</td>\n", " <td>107436</td>\n", " <td>[south fulton]</td>\n", " <td>Georgia</td>\n", " <td>0.000713</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>962 rows × 7 columns</p>\n", "</div>" ], "text/plain": [ " id state_code city_name city_popln alternate_names \\\n", "0 4049979 AL Birmingham 212461 [birmingham, bhm] \n", "1 4050552 TN Cordova 68779 [cordova] \n", "2 4058553 AL Decatur 55437 [decatur, dcu] \n", "3 4059102 AL Dothan 68567 [dhn, dothan] \n", "4 4067994 AL Hoover 84848 [hoover] \n", ".. ... ... ... ... ... \n", "957 11748973 GA Stonecrest 50000 [stonecrest] \n", "958 11838960 CA Valley Glen 60000 [valley glen] \n", "959 11979227 AZ Encanto 54614 [encanto] \n", "960 11979238 AZ Central City 58161 [central city] \n", "961 12541728 GA South Fulton 107436 [south fulton] \n", "\n", " state_name city_weight \n", "0 Alabama 0.001409 \n", "1 Tennessee 0.000456 \n", "2 Alabama 0.000368 \n", "3 Alabama 0.000455 \n", "4 Alabama 0.000563 \n", ".. ... ... \n", "957 Georgia 0.000332 \n", "958 California 0.000398 \n", "959 Arizona 0.000362 \n", "960 Arizona 0.000386 \n", "961 Georgia 0.000713 \n", "\n", "[962 rows x 7 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "city_states_data = get_geonames_city_state_data()\n", "print(len(city_states_data))\n", "city_states_data" ] }, { "cell_type": "code", "execution_count": 5, "id": "d35076ae-1d45-4699-8257-e98612500e43", "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>id</th>\n", " <th>state_code</th>\n", " <th>city_name</th>\n", " <th>city_popln</th>\n", " <th>alternate_names</th>\n", " <th>state_name</th>\n", " <th>city_weight</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>554</th>\n", " <td>5128581</td>\n", " <td>NY</td>\n", " <td>New York City</td>\n", " <td>8804190</td>\n", " <td>[the big apple, new york, new york city, nyc, ...</td>\n", " <td>New York</td>\n", " <td>0.058393</td>\n", " </tr>\n", " <tr>\n", " <th>718</th>\n", " <td>5368361</td>\n", " <td>CA</td>\n", " <td>Los Angeles</td>\n", " <td>3898747</td>\n", " <td>[los angeles, lax, l.a., la]</td>\n", " <td>California</td>\n", " <td>0.025858</td>\n", " </tr>\n", " <tr>\n", " <th>522</th>\n", " <td>5110302</td>\n", " <td>NY</td>\n", " <td>Brooklyn</td>\n", " <td>2736074</td>\n", " <td>[bk, borough of brooklyn, brooklyn]</td>\n", " <td>New York</td>\n", " <td>0.018147</td>\n", " </tr>\n", " <tr>\n", " <th>359</th>\n", " <td>4887398</td>\n", " <td>IL</td>\n", " <td>Chicago</td>\n", " <td>2696555</td>\n", " <td>[chi, chicago]</td>\n", " <td>Illinois</td>\n", " <td>0.017885</td>\n", " </tr>\n", " <tr>\n", " <th>266</th>\n", " <td>4699066</td>\n", " <td>TX</td>\n", " <td>Houston</td>\n", " <td>2304580</td>\n", " <td>[houston, hou]</td>\n", " <td>Texas</td>\n", " <td>0.015285</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>260</th>\n", " <td>4692883</td>\n", " <td>TX</td>\n", " <td>Galveston</td>\n", " <td>50180</td>\n", " <td>[galveston, gls]</td>\n", " <td>Texas</td>\n", " <td>0.000333</td>\n", " </tr>\n", " <tr>\n", " <th>758</th>\n", " <td>5384690</td>\n", " <td>CA</td>\n", " <td>Poway</td>\n", " <td>50157</td>\n", " <td>[poway]</td>\n", " <td>California</td>\n", " <td>0.000333</td>\n", " </tr>\n", " <tr>\n", " <th>470</th>\n", " <td>5025264</td>\n", " <td>MN</td>\n", " <td>Edina</td>\n", " <td>50138</td>\n", " <td>[edina]</td>\n", " <td>Minnesota</td>\n", " <td>0.000333</td>\n", " </tr>\n", " <tr>\n", " <th>475</th>\n", " <td>5037790</td>\n", " <td>MN</td>\n", " <td>Minnetonka Mills</td>\n", " <td>50117</td>\n", " <td>[minnetonka mills]</td>\n", " <td>Minnesota</td>\n", " <td>0.000332</td>\n", " </tr>\n", " <tr>\n", " <th>957</th>\n", " <td>11748973</td>\n", " <td>GA</td>\n", " <td>Stonecrest</td>\n", " <td>50000</td>\n", " <td>[stonecrest]</td>\n", " <td>Georgia</td>\n", " <td>0.000332</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>962 rows × 7 columns</p>\n", "</div>" ], "text/plain": [ " id state_code city_name city_popln \\\n", "554 5128581 NY New York City 8804190 \n", "718 5368361 CA Los Angeles 3898747 \n", "522 5110302 NY Brooklyn 2736074 \n", "359 4887398 IL Chicago 2696555 \n", "266 4699066 TX Houston 2304580 \n", ".. ... ... ... ... \n", "260 4692883 TX Galveston 50180 \n", "758 5384690 CA Poway 50157 \n", "470 5025264 MN Edina 50138 \n", "475 5037790 MN Minnetonka Mills 50117 \n", "957 11748973 GA Stonecrest 50000 \n", "\n", " alternate_names state_name \\\n", "554 [the big apple, new york, new york city, nyc, ... New York \n", "718 [los angeles, lax, l.a., la] California \n", "522 [bk, borough of brooklyn, brooklyn] New York \n", "359 [chi, chicago] Illinois \n", "266 [houston, hou] Texas \n", ".. ... ... \n", "260 [galveston, gls] Texas \n", "758 [poway] California \n", "470 [edina] Minnesota \n", "475 [minnetonka mills] Minnesota \n", "957 [stonecrest] Georgia \n", "\n", " city_weight \n", "554 0.058393 \n", "718 0.025858 \n", "522 0.018147 \n", "359 0.017885 \n", "266 0.015285 \n", ".. ... \n", "260 0.000333 \n", "758 0.000333 \n", "470 0.000333 \n", "475 0.000332 \n", "957 0.000332 \n", "\n", "[962 rows x 7 columns]" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "city_states_data.sort_values('city_weight', ascending=False)" ] }, { "cell_type": "code", "execution_count": 6, "id": "ad8ca7e6-7511-42f2-92df-a02526637f23", "metadata": {}, "outputs": [], "source": [ "city_weights = city_states_data[['city_name', 'city_weight']].set_index('city_name').to_dict()['city_weight']\n", "# city_weights" ] }, { "cell_type": "code", "execution_count": 7, "id": "0eca185b-0ff3-4cad-a878-92f1d065081c", "metadata": {}, "outputs": [], "source": [ "city_info = city_states_data[['city_name', 'alternate_names']].set_index('city_name').to_dict()['alternate_names']\n", "state_info = city_states_data[['state_code', 'state_name']].set_index('state_code').to_dict()['state_name']\n", "city_state_code_info = city_states_data[['city_name', 'state_code', 'city_weight']].copy()\n", "city_state_name_info = city_states_data[['city_name', 'state_name', 'city_weight']].copy()" ] }, { "cell_type": "code", "execution_count": 8, "id": "9cc31eb8-b5eb-4daa-a466-f873db8e3038", "metadata": {}, "outputs": [], "source": [ "# city_info" ] }, { "cell_type": "code", "execution_count": 9, "id": "ac238a0e-7526-40d3-8606-42d65fda3bd9", "metadata": {}, "outputs": [], "source": [ "label_map = {\n", " 0: \"O\", # Outside any named entity\n", " 1: \"B-PER\", # Beginning of a person entity\n", " 2: \"I-PER\", # Inside a person entity\n", " 3: \"B-ORG\", # Beginning of an organization entity\n", " 4: \"I-ORG\", # Inside an organization entity\n", " 5: \"B-CITY\", # Beginning of a city entity\n", " 6: \"I-CITY\", # Inside a city entity\n", " 7: \"B-STATE\", # Beginning of a state entity\n", " 8: \"I-STATE\", # Inside a state entity\n", " 9: \"B-CITYSTATE\", # Beginning of a city_state entity\n", " 10: \"I-CITYSTATE\", # Inside a city_state entity\n", " }\n", "\n", "\n", "persons = [\n", " 'Donald Trump', 'John Smith', 'Roger Williams', 'Michelle Obama', 'Elon Musk',\n", " 'Barack Obama', 'Bill Gates', 'Steve Jobs', 'Warren Buffett', 'Oprah Winfrey',\n", " 'Jeff Bezos', 'Taylor Swift', 'Jennifer Lawrence', 'Brad Pitt', 'Leonardo DiCaprio',\n", " 'Katy Perry', 'Tom Hanks', 'Emma Watson', 'Johnny Depp', 'Scarlett Johansson',\n", " 'Mark Zuckerberg', 'Sheryl Sandberg', 'Ivanka Trump', 'Joe Biden', 'Kamala Harris',\n", " 'Serena Williams', 'Michael Jordan', 'LeBron James', 'Tiger Woods', 'Cristiano Ronaldo',\n", " 'Lionel Messi', 'Roger Federer', 'Usain Bolt', 'Simone Biles', 'Tom Brady',\n", " 'Peyton Manning', 'David Beckham', 'Rafael Nadal', 'Novak Djokovic', 'Andy Murray',\n", " 'George Clooney', 'Matt Damon', 'Julia Roberts', 'Angelina Jolie', 'Morgan Freeman',\n", " 'Chris Hemsworth', 'Dwayne Johnson', 'Vin Diesel', 'Keanu Reeves', 'Robert Downey Jr.',\n", " 'Chris Evans', 'Will Smith', 'Johnny Cash', 'Bob Dylan', 'Paul McCartney',\n", " 'Ringo Starr', 'John Lennon', 'George Harrison', 'Madonna', 'Prince',\n", " 'Bruce Springsteen', 'Elton John', 'David Bowie', 'Whitney Houston', 'Celine Dion',\n", " 'Marilyn Monroe', 'Audrey Hepburn', 'Albert Einstein', 'Isaac Newton', 'Marie Curie',\n", " 'Galileo Galilei', 'Nikola Tesla', 'Stephen Hawking', 'Richard Feynman', 'Carl Sagan',\n", " 'Neil Armstrong', 'Yuri Gagarin', 'Sally Ride', 'Jane Goodall', 'Charles Darwin',\n", " 'Mahatma Gandhi', 'Nelson Mandela', 'Martin Luther King Jr.', 'Malala Yousafzai', 'Angela Merkel',\n", " 'Theresa May', 'Vladimir Putin', 'Xi Jinping', 'Justin Trudeau', 'Jacinda Ardern',\n", " 'Pope Francis', 'Dalai Lama', 'Queen Elizabeth II', 'Prince William', 'Prince Harry',\n", " 'James Anderson', 'Michael Brown', 'David Clark', 'John Doe', 'Robert Evans',\n", " 'Christopher Foster', 'William Garcia', 'Charles Hall', 'Joseph Harris', 'Daniel Jackson',\n", " 'Matthew Johnson', 'George King', 'Anthony Lewis', 'Mark Miller', 'Paul Moore',\n", " 'Steven Nelson', 'Kevin Perry', 'Thomas Reed', 'Brian Roberts', 'Jason Scott',\n", " 'Andrew Smith', 'Joshua Thompson', 'Ryan Turner', 'Brandon Walker', 'Nicholas White',\n", " 'Jonathan Young', 'Adam Baker', 'Justin Carter', 'Benjamin Collins', 'Aaron Cook',\n", " 'Alexander Davis', 'Tyler Edwards', 'Zachary Fisher', 'Ethan Graham', 'Jacob Green',\n", " 'Austin Hernandez', 'Mason Hill', 'Logan Hughes', 'Owen Jenkins', 'Lucas Kelly',\n", " 'Nathan Lee', 'Caleb Long', 'Henry Martinez', 'Dylan Mitchell', 'Gabriel Morris',\n", " 'Jack Murphy', 'Connor Myers', 'Liam Parker', 'Isaac Patterson', 'Evan Phillips',\n", " 'Hunter Price', 'Noah Richardson', 'Samuel Rivera', 'Gavin Rogers', 'Aiden Ross',\n", " 'Christian Russell', 'Ian Sanders', 'Eli Simmons', 'Chase Stewart', 'Cameron Sullivan',\n", " 'Bryan Taylor', 'Cole Thomas', 'Jake Thompson', 'Luke Torres', 'Blake Turner',\n", " 'Jesse Ward', 'Joel Watson', 'Derek Williams', 'Mitchell Wright', 'Dustin Young',\n", " 'Megan Allen', 'Jennifer Bailey', 'Jessica Bennett', 'Emily Brooks', 'Sarah Campbell',\n", " 'Amanda Carter', 'Rebecca Collins', 'Samantha Cooper', 'Stephanie Diaz', 'Rachel Evans',\n", " 'Christine Flores', 'Laura Foster', 'Michelle Garcia', 'Amber Gonzales', 'Lisa Gray',\n", " 'Kimberly Green', 'Heather Harris', 'Tiffany Henderson', 'Natalie Hernandez', 'Crystal Hill',\n", " 'Victoria Hughes', 'Erica Jenkins', 'Nicole Johnson', 'Katherine Kelly', 'Danielle Lee',\n", " 'Hannah Lewis', 'Melissa Lopez', 'Patricia Martin', 'Brittany Moore', 'Brenda Morgan',\n", "\n", " ]\n", "organizations = [\n", " 'Google Inc.', 'Apple Inc.', 'Amazon.com', 'Facebook Inc.', 'Microsoft Corporation',\n", " 'Tesla Motors', 'Netflix Inc.', 'The New York Times', 'The Washington Post', 'Wall Street Journal',\n", " 'Intel Corporation', 'Oracle Corporation', 'IBM', 'Coca-Cola Company', 'PepsiCo',\n", " 'Starbucks', 'Walmart Inc.', 'Target Corporation', 'ExxonMobil', 'Shell Oil Company',\n", " 'Ford Motor Company', 'General Motors', 'Toyota Motor Corporation', 'Volkswagen Group', 'BMW Group',\n", " 'American Airlines', 'Delta Airlines', 'United Airlines', 'Boeing Company', 'Lockheed Martin',\n", " 'SpaceX', 'NASA', 'Harvard University', 'Stanford University', 'Massachusetts Institute of Technology',\n", " 'University of California, Berkeley', 'University of Oxford', 'University of Cambridge', 'Princeton University', 'Yale University',\n", " 'University of Chicago', 'Columbia University', 'Johns Hopkins University', 'University of Southern California', 'University of Michigan',\n", " 'Goldman Sachs', 'JPMorgan Chase', 'Citibank', 'Morgan Stanley', 'Bank of America',\n", " 'Deloitte', 'Ernst & Young', 'PricewaterhouseCoopers', 'KPMG', 'McKinsey & Company',\n", " 'Boston Consulting Group', 'Accenture', 'BlackRock', 'Fidelity Investments', 'Vanguard Group',\n", " 'Nike Inc.', 'Adidas', 'Under Armour', 'Patagonia', 'The Walt Disney Company',\n", " 'Time Warner', 'NBCUniversal', 'Sony Corporation', 'Warner Bros.', 'Paramount Pictures',\n", " 'Universal Music Group', 'Sony Music Entertainment', 'Warner Music Group', 'Pfizer Inc.', 'Johnson & Johnson',\n", " 'Novartis', 'Merck & Co.', 'GlaxoSmithKline', 'AstraZeneca', 'Moderna',\n", " 'New York City Hospital', 'Los Angeles County Library', 'San Francisco Community College',\n", " 'Miami International University', 'Chicago Regional Bank', 'Dallas Medical Center',\n", " 'Boston Tech Solutions', 'Atlanta City Bank', 'Seattle Software Hub', 'Phoenix Energy Solutions',\n", " 'Denver Financial Group', 'Houston General Hospital', 'Portland Health Services', 'Las Vegas Convention Center',\n", " 'San Diego Software Innovations', 'Philadelphia Law Firm', 'Orlando Realty Group',\n", " 'Austin Engineering Solutions', 'Cleveland City Schools', 'Detroit Manufacturing Hub',\n", " 'Baltimore Technology Inc.', 'Minneapolis Insurance Group', 'St. Louis Transportation Services',\n", " 'Tampa Healthcare Network', 'Pittsburgh Steelworks Corporation', 'Sacramento Business Ventures',\n", " 'Indianapolis Marketing Solutions', 'Columbus Financial Advisors', 'Fort Worth Electric Company',\n", " 'Charlotte Digital Marketing', 'Milwaukee Industrial Solutions', 'Memphis Logistics Services',\n", " 'Washington DC Development', 'Nashville Business Enterprises', 'Louisville Fitness Center',\n", " 'Kansas City Architectural Firm', 'Oklahoma City University', 'Virginia Beach Law Associates',\n", " 'Raleigh Research Institute', 'Salt Lake City Analytics', 'Richmond Financial Group',\n", " 'Newark Data Solutions', 'Anchorage Energy Solutions', 'Fresno Water Authority',\n", " 'Omaha Financial Services', 'Colorado Springs Health Institute', 'Mesa Auto Parts',\n", " 'Virginia Beach Shipping', 'Sacramento Community Center', 'Albuquerque Electronics Company',\n", " 'Tucson Data Science Center', 'Miami Lakes Software Solutions', 'Wichita Steel Corporation',\n", " 'Arlington Cybersecurity Group', 'Bakersfield Construction Services', 'Aurora Logistics Firm',\n", " 'Anaheim Technology Hub', 'Santa Ana Healthcare Services', 'Riverside Manufacturing Co.',\n", " 'St. Paul Medical Associates', 'Lexington University Hospital', 'Plano Technology Solutions',\n", " 'Lincoln Manufacturing Inc.', 'Greensboro Industrial Partners', 'Jersey City Financial Group',\n", " 'Chandler Electronics', 'Madison Biotechnology Solutions', 'Lubbock Medical Supplies',\n", " 'Scottsdale Real Estate Group', 'Reno Venture Capitalists', 'Henderson Engineering Consultants',\n", " 'Norfolk Health Services', 'Chesapeake Data Systems', 'Fremont Software Group',\n", " 'Irvine Legal Services', 'San Bernardino Logistics Group', 'Boise Energy Technologies',\n", " 'Spokane Steel Fabricators', 'Glendale Solar Power Corporation', 'Garland Medical Services',\n", " 'Hialeah Shipping and Logistics', 'Chesapeake Financial Advisors', 'Frisco Software Hub',\n", " 'McKinney Electronics Corporation', 'Gilbert Transportation Group', 'Baton Rouge Financial Services',\n", " 'Shreveport Data Analytics', 'Mobile Business Solutions', 'Huntsville Rocket Technologies',\n", " 'Knoxville Agricultural Partners', 'Dayton Software Innovations', 'Grand Rapids Healthcare Network',\n", " 'Fort Lauderdale Construction Group', 'Tempe Electric Vehicles', 'Winston-Salem Marketing Firm',\n", " 'Fayetteville Consulting Services', 'Springfield Realty Group', 'Yonkers Manufacturing Hub',\n", " 'Augusta Insurance Group', 'Salem Solar Energy Solutions', 'Pasadena Legal Consultants',\n", " 'Seattle Pacific University', 'San Diego Zoo', 'Portland Art Museum',\n", " 'Boston Medical Group', 'Chicago Tribune', 'Dallas Cowboys Football Club',\n", " 'Los Angeles Philharmonic Orchestra', 'New York University', 'Houston Community College',\n", " 'Phoenix Solar Power', 'Denver Public Library', 'Miami International Airport',\n", " 'Atlanta Symphony Orchestra', 'San Francisco Opera', 'Orlando City Soccer Club',\n", " 'Nashville Symphony', 'Baltimore Ravens Football Team', 'Cleveland Clinic',\n", " 'Pittsburgh Steelers Football Team', 'Detroit Institute of Arts',\n", " 'Tampa Bay Buccaneers Football Club', 'St. Louis Cardinals Baseball Team',\n", " 'Indianapolis Colts Football Team', 'Austin Film Society', 'Seattle Sounders Football Club',\n", " 'Minneapolis Institute of Art', 'Charlotte Hornets Basketball Club', 'Portland Trail Blazers Basketball Team',\n", " 'Las Vegas Convention and Visitors Authority', 'New Orleans Saints Football Club',\n", " 'San Antonio Spurs Basketball Club', 'Philadelphia Eagles Football Club',\n", " 'Kansas City Chiefs Football Team', 'Cincinnati Reds Baseball Club',\n", " 'Memphis Grizzlies Basketball Team', 'Washington Wizards Basketball Club',\n", " 'Milwaukee Bucks Basketball Club', 'Sacramento Kings Basketball Team',\n", " 'Salt Lake City Ballet', 'Boise State University', 'Albuquerque International Balloon Fiesta',\n", " 'Raleigh-Durham International Airport', 'Richmond Symphony', 'Fresno Pacific University',\n", " 'Spokane Transit Authority', 'Henderson Engineering', 'Mesa Public Schools',\n", " 'Scottsdale Museum of Contemporary Art', 'Chandler Regional Medical Center', 'Glendale Unified School District',\n", " 'Riverside Community Hospital', 'Aurora Public Schools', 'Anaheim Ducks Hockey Team',\n", " 'Santa Ana College', 'Stockton Unified School District', 'Irvine Company', 'San Bernardino Community College District',\n", " 'Modesto Junior College', 'Bakersfield Condors Hockey Team', 'Fresno State University',\n", " 'Chesapeake Energy Corporation', 'Omaha World-Herald', 'Tucson Medical Center',\n", " 'Virginia Beach Public Schools', 'Norfolk Naval Shipyard', 'Newark Beth Israel Medical Center',\n", " 'Fort Wayne Mad Ants Basketball Team', 'Fremont High School', 'Shreveport Regional Airport',\n", " 'Mobile Public Library', 'Huntsville Hospital', 'Knoxville Symphony Orchestra',\n", " 'Dayton International Airport', 'Grand Rapids Symphony', 'Winston-Salem Dash Baseball Team',\n", " 'Fayetteville Technical Community College', 'Springfield Cardinals Baseball Team',\n", " 'Augusta National Golf Club', 'Salem Health', 'Pasadena Playhouse', 'Yonkers Public Schools',\n", " 'Boulder Community Health', 'Naperville North High School', 'Lansing Community College',\n", " 'Reno-Tahoe International Airport', 'Columbia University Medical Center', 'Albany Law School',\n", " 'Buffalo Sabres Hockey Team', 'Syracuse University', 'Toledo Museum of Art', 'Akron Public Schools',\n", " 'Daytona International Speedway', 'Des Moines Public Library', 'Rochester Philharmonic Orchestra',\n", " 'Flint Institute of Arts', 'Lincoln Memorial University', 'Baton Rouge Community College',\n", " 'Chattanooga Symphony and Opera', 'Greenville Technical College', 'Cedar Rapids Opera Theatre',\n", " 'Pensacola Naval Air Station'\n", " ]\n", "\n", "products = [\n", " 'iPhone', 'Samsung Galaxy', 'MacBook', 'PlayStation 5', 'Nike shoes', \n", " 'AirPods', 'Xbox Series X', 'Canon DSLR', 'GoPro', 'Adidas sneakers', \n", " 'Fitbit', 'Google Pixel', 'Kindle', 'Bose headphones', 'Sony TV', \n", " 'Dyson vacuum', 'KitchenAid mixer', 'Surface Pro', 'Roomba', 'Apple Watch'\n", "]\n", "\n", "countries = [\n", " 'USA', 'France', 'Japan', 'Germany', 'Canada', \n", " 'Australia', 'Mexico', 'China', 'Brazil', 'India', \n", " 'Italy', 'Spain', 'South Korea', 'Russia', 'Netherlands', \n", " 'United Kingdom', 'Sweden', 'Norway', 'Switzerland', 'Argentina'\n", "]\n", "\n", "services = [\n", " 'Netflix', 'Spotify', 'Uber', 'Amazon Prime', 'Google Drive', \n", " 'Zoom', 'Dropbox', 'Slack', 'LinkedIn', 'Disney+', \n", " 'YouTube Premium', 'Venmo', 'DoorDash', 'Postmates', 'Hulu', \n", " 'Skype', 'Grubhub', 'Twitch', 'Instacart', 'Lyft'\n", "]\n", "\n", "cars = [\n", " 'Tesla Model S', 'Ford Mustang', 'Chevrolet Camaro', 'Toyota Corolla', 'Honda Civic', \n", " 'BMW 3 Series', 'Audi A4', 'Mercedes-Benz C-Class', 'Jeep Wrangler', 'Ford F-150', \n", " 'Hyundai Elantra', 'Mazda CX-5', 'Chevrolet Tahoe', 'Nissan Altima', 'Kia Sorento', \n", " 'Volkswagen Golf', 'Subaru Outback', 'Tesla Model 3', 'Dodge Charger', 'Volvo XC90'\n", "]\n", "\n", "gadgets = [\n", " 'smartwatch', 'Bluetooth headphones', 'fitness tracker', 'smart speaker', 'tablet', \n", " 'laptop', 'gaming mouse', 'wireless charger', 'VR headset', 'noise-canceling headphones', \n", " 'dashcam', 'e-reader', 'action camera', 'portable hard drive', 'gaming console', \n", " 'mechanical keyboard', '4K monitor', 'digital camera', 'portable power bank', 'USB-C hub'\n", "]\n", "\n", "stocks = [\n", " 'AAPL', 'GOOGL', 'AMZN', 'MSFT', 'TSLA', \n", " 'NFLX', 'FB', 'BABA', 'NVDA', 'JPM', \n", " 'V', 'PYPL', 'BRK.A', 'DIS', 'INTC', \n", " 'PFE', 'NKE', 'ORCL', 'VZ', 'BA'\n", "]\n", "\n", "moneys = [\n", " 'cryptocurrency', 'cash', 'PayPal', 'credit card', 'Bitcoin', \n", " 'Ethereum', 'bank transfer', 'wire transfer', 'Western Union', 'Venmo', \n", " 'debit card', 'Zelle', 'Apple Pay', 'Google Pay', 'Coinbase', \n", " 'Tether', 'Litecoin', 'Dogecoin', 'cash app', 'Ripple'\n", "]\n", "\n", "finances = [\n", " '401(k)', 'IRA', 'mutual funds', 'mortgage', 'student loan', \n", " 'savings account', 'retirement fund', 'bond', 'annuity', 'index fund', \n", " 'Roth IRA', 'tax-free savings account', 'pension', 'trust fund', 'hedge fund', \n", " 'credit score', 'auto loan', 'home equity loan', 'personal loan', 'debt consolidation'\n", "]\n", "\n", "travels = [\n", " 'flights', 'hotels', 'car rentals', 'vacation packages', 'cruise trips', \n", " 'road trips', 'train tickets', 'adventure tours', 'guided tours', 'backpacking trips',\n", " 'honeymoon destinations', 'beach resorts', 'luxury travel', 'budget travel', 'camping gear', \n", " 'family vacations', 'ski trips', 'all-inclusive resorts', 'last-minute deals', 'travel insurance'\n", "]\n", "\n", "foods = [\n", " 'pizza', 'sushi', 'burgers', 'pasta', 'salads', \n", " 'vegan food', 'barbecue', 'fried chicken', 'ramen', 'tacos', \n", " 'sandwiches', 'noodles', 'soups', 'cakes', 'ice cream', \n", " 'steak', 'seafood', 'breakfast food', 'brunch', 'desserts'\n", "]\n", "\n", "restaurants = [\n", " 'Italian restaurants', 'Mexican restaurants', 'Japanese restaurants', 'Chinese restaurants', 'Indian restaurants', \n", " 'fast food chains', 'fine dining', 'vegan restaurants', 'steakhouses', 'seafood restaurants', \n", " 'barbecue joints', 'sushi bars', 'cafes', 'pizzerias', 'buffet restaurants', \n", " 'food trucks', 'family-friendly restaurants', 'gastropubs', 'brunch spots', 'diner'\n", "]\n", "\n", "## Additional partial terms\n", "sports_terms_missing = [\n", " \"footbal\", \"baske\", \"socce\", \"golf\", \"cricke\", \"rugby\", \"hocke\", \"tenni\", \n", " \"swimmin\", \"athleti\", \"fishi\", \"basebal\", \"volleybal\", \"badminto\", \"maratho\", \n", " \"skatin\", \"climbin\", \"racquetball\", \"bowlin\", \"darts\", \"gymnasti\", \"bikin\"\n", "]\n", "\n", "locations_and_landmarks = [\n", " \"statue\", \"museum\", \"plaza\", \"zoo\", \"church\", \"theater\", \"stadium\", \"mountain\", \n", " \"park\", \"lake\", \"beach\", \"river\", \"palace\", \"cathedra\", \"mansion\", \"monument\", \n", " \"temple\", \"observato\", \"canyon\", \"garden\", \"conservato\", \"boardwal\", \"forest\", \n", " \"pier\", \"lighthouse\", \"arena\"\n", "]\n", "\n", "activities_and_events = [\n", " \"conc\", \"exhib\", \"meet\", \"parad\", \"festi\", \"tourn\", \"game\", \"sho\", \"even\", \n", " \"gala\", \"confere\", \"seminar\", \"webina\", \"worksho\", \"lectur\", \"symposiu\", \n", " \"screenin\", \"rall\", \"celebratio\", \"ceremon\", \"get-togethe\", \"perfor\", \n", " \"gatherin\", \"competitio\", \"maratho\", \"speec\", \"workout\", \"showcas\"\n", "]\n", "\n", "food_missing = [\n", " \"sush\", \"pizz\", \"ramen\", \"bbq\", \"vega\", \"steak\", \"taco\", \"burg\", \"pasta\", \n", " \"brunc\", \"desse\", \"drink\", \"grill\", \"bake\", \"buffet\", \"sandwich\", \"noodle\", \n", " \"cafe\", \"taver\", \"gastro\", \"bistro\", \"del\", \"saloo\", \"barbecue\", \"snack\", \n", " \"confectio\", \"pub\"\n", "]\n", "\n", "transport_and_directions = [\n", " \"direc\", \"map\", \"bus\", \"train\", \"car\", \"park\", \"taxi\", \"subwa\", \"fly\", \n", " \"plane\", \"ticke\", \"pass\", \"ferr\", \"bicycl\", \"scoote\", \"shuttl\", \"walkin\", \n", " \"rideshar\", \"transi\", \"toll\", \"metr\", \"road\", \"route\", \"stop\"\n", "]\n" ] }, { "cell_type": "code", "execution_count": 10, "id": "2a166f90-7262-4047-92e4-f83a18c7c5d4", "metadata": {}, "outputs": [], "source": [ "\n", "def get_sample_from_cities(city_info, city_weights, actual_threshold=0.7):\n", " cities = list(city_info.keys())\n", " weights = [city_weights[city] for city in cities]\n", " city_random = random.choices(cities, weights=weights, k=1)[0]\n", " rand_val = random.random()\n", " if rand_val <= actual_threshold:\n", " return city_random\n", " return random.choice(city_info[city_random])\n", "\n", "def get_sample_from_states(state_info, actual_threshold=0.5):\n", " states = list(state_info.keys())\n", " state_random = random.choice(states)\n", " rand_val = random.random()\n", " if rand_val <= actual_threshold:\n", " return state_random\n", " return random.choice([state_info[state_random]])\n", "\n", "def get_sample_from_cities_and_states(city_state_code_info, city_state_name_info, state_code_threshold=0.8, comma_threshold=0.6):\n", " rand_val = random.random()\n", " if rand_val <= state_code_threshold:\n", " if rand_val <= comma_threshold:\n", " return ', '.join(city_state_code_info.sample(1, weights='city_weight', replace=True)[['city_name', 'state_code']].values.tolist()[0])\n", " else:\n", " return ' '.join(city_state_code_info.sample(1, weights='city_weight', replace=True)[['city_name', 'state_code']].values.tolist()[0])\n", " return ', '.join(city_state_name_info.sample(1, weights='city_weight', replace=True)[['city_name', 'state_name']].values.tolist()[0])\n", "\n", "def get_random_choice_from_list(choices_list):\n", " return random.choice(choices_list)\n", " \n" ] }, { "cell_type": "code", "execution_count": 11, "id": "a4f1fd1a-cd74-4ebe-b9aa-c3cd07dcf2bb", "metadata": {}, "outputs": [], "source": [ "# for _ in range(100):\n", "# print(get_sample_from_cities_and_states(city_state_code_info, city_state_name_info, state_code_threshold=0.8))" ] }, { "cell_type": "code", "execution_count": 12, "id": "cb760f94-1dba-4ec3-816f-7cfb3bb9b7b0", "metadata": {}, "outputs": [], "source": [ "templates = [\n", " # Simple City-Based Queries\n", " \"weather {city}\",\n", " \"{city} temperature\",\n", " \"sushi {city}\",\n", " \"ramen {city}\",\n", " \"pizza {city}\",\n", " \"plumber {city}\",\n", " \"electrician {city}\",\n", " \"roof repair {city}\",\n", " \"physio therapy {city}\",\n", " \"hospital {city}\",\n", " \"doctor {city}\",\n", " \"nurse {city}\",\n", " \"home improvement {city}\",\n", " \"home services {city}\",\n", " \"weather forecast {city}\",\n", " \"current weather {city}\",\n", " \"best restaurants {city}\",\n", " \"top yelp reviews {city}\",\n", " \"places to visit in {city}\",\n", " \"best cafes in {city}\",\n", " \"emergency services {city}\",\n", " \"gyms in {city}\",\n", " \"car repair {city}\",\n", " \"florist {city}\",\n", " \"lawyers in {city}\",\n", " \"real estate agents {city}\",\n", " \"hiking trails {city}\",\n", " \"parks in {city}\",\n", " \"movie theaters {city}\",\n", " \"top hotels in {city}\",\n", " \"events in {city} this weekend\",\n", " \"pharmacies {city}\",\n", "\n", " # State-Based Queries\n", " \"home services in {state}\",\n", " \"best restaurants in {state}\",\n", " \"real estate agents {state}\",\n", " \"roof repair services {state}\",\n", " \"hospitals in {state}\",\n", " \"weather {state}\",\n", " \"temperature {state}\",\n", " \"physio therapy {state}\",\n", " \"doctors in {state}\",\n", " \"top-rated plumbers {state}\",\n", " \"electricians {state}\",\n", " \"emergency services {state}\",\n", " \"sushi {state}\",\n", " \"ramen {state}\",\n", " \"pizza {state}\",\n", " \"parks in {state}\",\n", " \"hiking trails {state}\",\n", " \"pharmacies in {state}\",\n", " \"best cafes {state}\",\n", " \"movie theaters {state}\",\n", "\n", " # City-State Combination Queries (Now using {city_state})\n", " \"weather {city_state}\",\n", " \"{city_state} temperature\",\n", " \"sushi {city_state}\",\n", " \"plumber {city_state}\",\n", " \"best restaurants in {city_state}\",\n", " \"top-rated roof repair {city_state}\",\n", " \"hospital {city_state}\",\n", " \"physio therapy {city_state}\",\n", " \"doctor {city_state}\",\n", " \"events in {city_state} this weekend\",\n", " \"lawyers in {city_state}\",\n", " \"home improvement services {city_state}\",\n", " \"florist {city_state}\",\n", " \"best cafes in {city_state}\",\n", " \"parks in {city_state}\",\n", " \"movie theaters {city_state}\",\n", " \"top hotels in {city_state}\",\n", " \"emergency services {city_state}\",\n", " \"car repair {city_state}\",\n", " \"pharmacies {city_state}\",\n", "\n", " \"sushi {city_state}\",\n", " \"ramen {city_state}\",\n", " \"pizza {city_state}\",\n", " \"parks {city_state}\",\n", " \"hiking trails {city_state}\",\n", " \"pharmacies {city_state}\",\n", " \"best cafes {city_state}\",\n", " \"movie theaters {city_state}\",\n", " \"hamburgers {city_state}\",\n", " \"burgers {city_state}\",\n", " \"pasta {city_state}\",\n", " \"salads {city_state}\",\n", " \"vegan food {city_state}\",\n", " \"fried chicken {city_state}\",\n", " \"ramen {city_state}\",\n", " \"tacos {city_state}\",\n", " \"sandwiches {city_state}\",\n", " \"noodles {city_state}\",\n", " \"soups {city_state}\",\n", " \"cakes {city_state}\",\n", " \"ice cream {city_state}\",\n", " \"steak {city_state}\",\n", " \"seafood {city_state}\",\n", " \"breakfast food {city_state}\",\n", " \"brunch {city_state}\",\n", " \"desserts {city_state}\",\n", " \n", " # CITY state order swapped\n", " \"{city_state} sushi\",\n", " \"{city_state} ramen\",\n", " \"{city_state} pizza\",\n", " \"{city_state} parks\",\n", " \"{city_state} hiking trails\",\n", " \"{city_state} pharmacies\",\n", " \"{city_state} best cafes\",\n", " \"{city_state} movie theaters\",\n", " \"{city_state} hamburgers\",\n", " \"{city_state} burgers\",\n", " \"{city_state} pasta\",\n", " \"{city_state} salads\",\n", " \"{city_state} vegan food\",\n", " \"{city_state} fried chicken\",\n", " \"{city_state} ramen\",\n", " \"{city_state} tacos\",\n", " \"{city_state} sandwiches\",\n", " \"{city_state} noodles\",\n", " \"{city_state} soups\",\n", " \"{city_state} cakes\",\n", " \"{city_state} ice cream\",\n", " \"{city_state} steak\",\n", " \"{city_state} seafood\",\n", " \"{city_state} breakfast food\",\n", " \"{city_state} brunch\",\n", " \"{city_state} desserts\",\n", " \n", " # Organization-Based Queries\n", " \"{organization} in {city_state}\",\n", " \"contact {organization} in {city}\",\n", " \"locations of {organization} in {state}\",\n", " \"does {organization} provide home repair services in {city}?\",\n", " \"can I book a doctor appointment at {organization} in {state}?\",\n", " \"does {organization} offer roof repair in {city_state}?\",\n", " \"hours of {organization} in {city}\",\n", " \"{organization} reviews in {state}\",\n", " \"best rated {organization} in {city_state}\",\n", " \"nearest branch of {organization} in {city}\",\n", " \n", " # Person-Based Queries\n", " \"Where is {person} hosting an event?\",\n", " \"Can I meet {person} in {city_state}?\",\n", " \"Is {person} available for an appointment in {city}?\",\n", " \"Is {person} traveling to {state} next week?\",\n", " \"Does {person} have a speech in {city_state}?\",\n", " \n", " # Mixed and Specialized Queries\n", " \"roof repair near {city}\",\n", " \"best sushi in {city_state}\",\n", " \"what's the weather forecast for {city}?\",\n", " \"who are the top doctors in {city_state}?\",\n", " \"restaurants near {city} with good reviews\",\n", " \"plumbing services in {city_state}\",\n", " \"upcoming events in {city} this weekend\",\n", " \"find hiking trails in {city_state}\",\n", " \"local electricians in {city_state}\",\n", " \"ramen places in {city}\",\n", " \"home improvement contractors near {city_state}\",\n", " \"best pizza near {city}\",\n", " \"does {organization} operate in {city_state}?\",\n", " \"find top-rated hospitals in {city_state}\",\n", " \"home maintenance services in {city_state}\",\n", " \"weather forecast for {city} this weekend\",\n", " \"roof repair specialists in {city}\",\n", " \"top-rated movie theaters in {city_state}\",\n", "\n", " # City-State Queries\n", " \"Best {restaurant} in {city_state}\",\n", " \"Top-rated {restaurant} in {city_state}\",\n", " \"Affordable {restaurant} in {city_state}\",\n", " \"Where to find the best {food} in {city_state}?\",\n", " \"Popular {food} places in {city_state}\",\n", " \"Top destinations for {travel} in {city_state}\",\n", " \"Best deals on {travel} in {city_state}\",\n", " \"Where to eat {food} in {city_state}?\",\n", " \"What are the most famous {restaurant} in {city_state}?\",\n", " \"Top {food} restaurants in {city_state} this weekend\",\n", "\n", " # Non-City/State Queries\n", " \"Best {restaurant} in the country\",\n", " \"Where to find the best {food} near me?\",\n", " \"Top destinations for {travel} this summer\",\n", " \"Best deals on {travel} packages\",\n", " \"Where to find cheap {travel} options?\",\n", " \"Popular {food} dishes in the USA\",\n", " \"Best {restaurant} chains in the country\",\n", " \"What are the healthiest {food} options?\",\n", " \"How to book affordable {travel} for families?\",\n", " \"Most popular {restaurant} for takeout\",\n", "\n", " # Additional Templates\n", " \"What is the best {food} to eat for dinner?\",\n", " \"Where to order {food} online?\",\n", " \"Best {restaurant} for date night\",\n", " \"Top {travel} websites for booking vacations\",\n", " \"Where to find {restaurant} reviews?\",\n", " \"What are the top-rated {travel} apps?\",\n", " \"Best {restaurant} near tourist attractions\",\n", " \"What is the most popular {food} in the USA?\",\n", " \"Best deals on {travel} for students\",\n", " \"Top {restaurant} for family gatherings\",\n", " \"Most affordable {food} delivery services\",\n", " \"What are the best {travel} insurance options?\",\n", " \"How to find luxury {restaurant} reservations\",\n", " \"Where to get authentic {food} near me?\",\n", " \"Top {restaurant} for business lunches\",\n", " \"How to plan a {travel} adventure?\",\n", " \"Best {restaurant} for weekend brunch\",\n", " \"What are the most popular {food} trends?\",\n", " \"Best {restaurant} for a large group\",\n", " \"How to get discounts on {travel} bookings?\"\n", "\n", " # Product-Based Queries\n", " \"Where to buy {product} online?\",\n", " \"Best deals on {product}\",\n", " \"How to repair a {product}?\",\n", " \"Latest reviews of {product}\",\n", " \"When will the next {product} be released?\",\n", " \"Top features of {product}\",\n", " \"Is {product} worth buying in 2024?\",\n", " \"User reviews of {product}\",\n", " \"Alternatives to {product}\",\n", " \"What is the price of {product}?\",\n", "\n", " # Country-Based Queries\n", " \"How to travel to {country}?\",\n", " \"Best tourist destinations in {country}\",\n", " \"Top hotels to stay in {country}\",\n", " \"Do I need a visa to visit {country}?\",\n", " \"Cultural traditions in {country}\",\n", " \"What is the official language of {country}?\",\n", " \"How to do business in {country}?\",\n", " \"What are the top exports of {country}?\",\n", " \"Current political situation in {country}\",\n", " \"Famous landmarks in {country}\",\n", "\n", " # Service-Based Queries\n", " \"How to cancel my {service} subscription?\",\n", " \"Is {service} worth the price?\",\n", " \"How does {service} compare to competitors?\",\n", " \"User reviews of {service}\",\n", " \"How to get a discount on {service}?\",\n", " \"What are the benefits of {service}?\",\n", " \"Best alternatives to {service}\",\n", " \"How to troubleshoot issues with {service}?\",\n", " \"Does {service} have a free trial?\",\n", " \"Is {service} available internationally?\",\n", "\n", " # Cars-Based Queries\n", " \"What is the top speed of {car}?\",\n", " \"User reviews of {car}\",\n", " \"How to finance a {car}?\",\n", " \"Fuel efficiency of {car}\",\n", " \"How to buy a second-hand {car}?\",\n", " \"What are the safety features of {car}?\",\n", " \"Maintenance costs of owning a {car}\",\n", " \"What is the resale value of {car}?\",\n", " \"Is {car} electric or gas-powered?\",\n", " \"Best upgrades for {car}\",\n", "\n", " # Gadgets-Based Queries\n", " \"What are the best apps for {gadget}?\",\n", " \"How to set up a {gadget}?\",\n", " \"User reviews of {gadget}\",\n", " \"Best accessories for {gadget}\",\n", " \"What are the health benefits of using a {gadget}?\",\n", " \"What is the battery life of {gadget}?\",\n", " \"How to sync {gadget} with my phone?\",\n", " \"Alternatives to {gadget}\",\n", " \"What are the best productivity apps for {gadget}?\",\n", " \"Is {gadget} waterproof?\",\n", "\n", " # Stocks-Based Queries\n", " \"What is the latest price of {stock}?\",\n", " \"How to buy shares of {stock}?\",\n", " \"Is {stock} a good investment in 2024?\",\n", " \"What are analysts saying about {stock}?\",\n", " \"Current stock performance of {stock}\",\n", " \"What is the market cap of {stock}?\",\n", " \"How to invest in {stock}?\",\n", " \"Latest earnings report of {stock}\",\n", " \"What are the dividend yields of {stock}?\",\n", " \"How to trade {stock} on the stock market?\",\n", "\n", " # Money-Based Queries\n", " \"How to convert {money} to another currency?\",\n", " \"Best ways to transfer {money} internationally\",\n", " \"What are the risks of using {money}?\",\n", " \"How to save {money} for the future?\",\n", " \"What is the best way to invest {money}?\",\n", " \"How to protect {money} from fraud?\",\n", " \"What are the fees for using {money}?\",\n", " \"Is {money} safe for online transactions?\",\n", " \"Best apps for managing {money}\",\n", " \"How to track spending with {money}?\",\n", "\n", " # Finance-Based Queries\n", " \"How to invest in a {finance}?\",\n", " \"What are the benefits of having a {finance}?\",\n", " \"How to calculate the returns on {finance}?\",\n", " \"What are the risks of investing in {finance}?\",\n", " \"How to get advice for managing my {finance}?\",\n", " \"How to apply for a {finance}?\",\n", " \"What are the tax benefits of {finance}?\",\n", " \"What are the best options for a {finance}?\",\n", " \"How to open a {finance} account?\",\n", " \"What is the interest rate on {finance}?\",\n", "\n", " # sports_term, location_and_landmark, activity_and_event, food_m, transport_and_direction\n", " # incomplete or misspelled sport/activity names\n", " \"{sports_term} near me\", \n", " \"find {sports_term}\", \n", " \"{sports_term} schedule\", \n", " \"{sports_term} news\", \n", " \"book {sports_term} tickets\", \n", " \"{sports_term} team\", \n", " \"{sports_term} game time\", \n", " \"when is the {sports_term} game\", \n", " \"top {sports_term} players\", \n", " \"local {sports_term} clubs\", \n", " \"where to play {sports_term}\", \n", " \"best {sports_term} venues\", \n", " \"{sports_term} tournament\",\n", "\n", " # Generic landmarks and location queries\n", " \"{location_and_landmark} nearby\", \n", " \"famous {location_and_landmark}\", \n", " \"{location_and_landmark} open now\", \n", " \"visit {location_and_landmark}\", \n", " \"{location_and_landmark} directions\", \n", " \"how to get to {location_and_landmark}\", \n", " \"nearest {location_and_landmark}\", \n", " \"{location_and_landmark} address\", \n", " \"top-rated {location_and_landmark}\", \n", " \"{location_and_landmark} hours\", \n", " \"find {location_and_landmark} near me\", \n", " \"{location_and_landmark} entry fee\", \n", " \"best {location_and_landmark} in {city}\",\n", "\n", " # Food and dining queries\n", " \"{food_m} place\", \n", " \"find {food_m}\", \n", " \"best {food_m} spot\", \n", " \"{food_m} delivery\", \n", " \"{food_m} open near me\", \n", " \"order {food_m}\", \n", " \"{food_m} deals\", \n", " \"{food_m} options\", \n", " \"{food_m} near me\", \n", " \"{food_m} reservation\", \n", " \"top-rated {food_m} restaurants\", \n", " \"{food_m} reviews\", \n", " \"{food_m} menu\", \n", " \"popular {food_m} dishes\", \n", " \"where to eat {food_m}\",\n", "\n", " # activities_and_events\n", " \"{activity_and_event} tickets\", \n", " \"nearest {activity_and_event}\", \n", " \"{activity_and_event} today\", \n", " \"upcoming {activity_and_event}\", \n", " \"book {activity_and_event}\", \n", " \"{activity_and_event} in {city}\", \n", " \"find {activity_and_event}\", \n", " \"{activity_and_event} schedule\", \n", " \"{activity_and_event} near me\", \n", " \"top-rated {activity_and_event} venues\", \n", " \"{activity_and_event} details\", \n", " \"how to attend {activity_and_event}\", \n", " \"{activity_and_event} location\", \n", " \"{activity_and_event} opening hours\",\n", "\n", " # Single-word incomplete or ambiguous queries (standalone)\n", " # Sports and Games (single or incomplete)\n", " \"footbal\", \"baske\", \"golf\", \"sush\", \"pizz\", \"zoo\", \"conc\", \"direc\", \n", " \"theate\", \"stadiu\", \"brunc\", \"tourn\", \"parad\", \"swimmin\", \"train\", \"taxi\", \n", " \"game\", \"meet\", \"mountain\", \"beac\", \"lake\", \"forest\", \"ligh\", \"restauran\", \n", " \"parki\", \"stor\", \"monumen\", \"aren\", \"boardwal\",\n", " # Locations and Landmarks (single or incomplete)\n", " \"statue\", \"museum\", \"plaza\", \"zoo\", \"church\", \"theater\", \"stadium\", \"mountain\", \n", " \"park\", \"lake\", \"beach\", \"river\", \"palace\", \"cathedra\", \"mansion\", \"monument\", \n", " \"temple\", \"observato\", \"canyon\", \"garden\", \"conservato\", \"boardwal\", \"forest\", \n", " \"pier\", \"lighthouse\", \"arena\", \"campgroun\", \"arch\", \"reservoi\", \"dam\", \"fountai\", \n", " \"waterfal\", \"galleri\", \"amphitheate\", \"sculptur\", \"trail\", \"cliff\", \"tower\", \"islan\",\n", " # Activities and Events (single or incomplete)\n", " \"conc\", \"exhib\", \"meet\", \"parad\", \"festi\", \"tourn\", \"game\", \"sho\", \"even\", \"gala\", \n", " \"confere\", \"seminar\", \"webina\", \"worksho\", \"lectur\", \"symposiu\", \"screenin\", \n", " \"rall\", \"celebratio\", \"ceremon\", \"get-togethe\", \"perfor\", \"gatherin\", \"competitio\", \n", " \"maratho\", \"speec\", \"workout\", \"exercis\", \"demonstratio\", \"ceremony\", \"readin\", \n", " \"daytrip\", \"lectur\", \"social\", \"activit\", \"performanc\", \"worksho\", \"openin\", \n", " \"finale\", \"comedy\", \"poetr\", \"talent\", \"match\",\n", " # Restaurants and Food Types (single or incomplete)\n", " \"sush\", \"pizz\", \"ramen\", \"bbq\", \"vega\", \"steak\", \"taco\", \"burg\", \"pasta\", \"brunc\", \n", " \"desse\", \"drink\", \"grill\", \"bake\", \"buffet\", \"sandwich\", \"noodle\", \"cafe\", \n", " \"taver\", \"gastro\", \"bistro\", \"deli\", \"saloo\", \"barbecue\", \"snack\", \"confectio\", \n", " \"pub\", \"salad\", \"cuisine\", \"fries\", \"wings\", \"pantr\", \"meatbal\", \"sub\", \"omel\", \n", " \"crepe\", \"wrap\", \"beverag\", \"dessert\", \"smoothie\", \"juice\", \"shake\", \"frappe\", \"coffee\",\n", " # Transport and Directions (single or incomplete)\n", " \"direc\", \"map\", \"bus\", \"train\", \"car\", \"park\", \"taxi\", \"subwa\", \"fly\", \"plane\", \n", " \"ticke\", \"pass\", \"ferr\", \"bicycl\", \"scoote\", \"shuttl\", \"walkin\", \"rideshar\", \n", " \"transi\", \"toll\", \"metr\", \"road\", \"route\", \"stop\", \"junctio\", \"termina\", \"highwa\", \n", " \"pathwa\", \"drivewa\", \"loop\", \"intersectio\", \"trailhead\", \"tub\", \"sidestro\", \n", " \"crosswal\", \"rout\", \"navigatio\", \"crossing\", \"pave\", \"deck\", \"lane\",\n", " # Technology and Gadgets (single or incomplete)\n", " \"lapt\", \"smartphon\", \"comput\", \"tablet\", \"earbuds\", \"bluetooth\", \"charg\", \"cabl\", \n", " \"headset\", \"monitor\", \"consol\", \"keyboard\", \"drive\", \"storag\", \"gaming\", \"mouse\", \n", " \"projector\", \"flashdriv\", \"powerban\", \"adapter\", \"webcam\", \"router\", \"modem\", \n", " \"camcorder\", \"printer\", \"copier\", \"recorde\", \"remote\", \"surge\", \"extend\", \"plug\", \n", " \"portabl\", \"backu\", \"networ\", \"recharge\", \"uplo\", \"downlo\", \"strea\", \"screencas\", \n", " \"googl\", \"apple\", \"micros\", \"andr\",\n", "]" ] }, { "cell_type": "code", "execution_count": 13, "id": "cc3a0477-6ae4-4794-8fef-b562d79dbbe9", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "570" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(templates)" ] }, { "cell_type": "code", "execution_count": 33, "id": "ad319a0f-7d5c-4e97-8632-c23007424ae0", "metadata": {}, "outputs": [], "source": [ "PERSON_ENTITY = \"{person}\"\n", "ORG_ENTITY = \"{organization}\"\n", "CITY_ENTITY = \"{city}\"\n", "STATE_ENTITY = \"{state}\"\n", "CITY_STATE_ENTITY = \"{city_state}\"\n", "PRODUCT_ENTITY = \"{product}\"\n", "COUNTRY_ENTITY = \"{country}\"\n", "SERVICE_ENTITY = \"{services}\"\n", "CAR_ENTITY = \"{car}\"\n", "GADGET_ENTITY = \"{gadget}\"\n", "STOCK_ENTITY = \"{stock}\"\n", "MONEY_ENTITY = \"{money}\"\n", "FINANCE_ENTITY = \"{finance}\"\n", "TRAVEL_ENTITY = \"{travel}\"\n", "FOOD_ENTITY = \"{food}\"\n", "RESTAURANT_ENTITY = \"{restaurant}\"\n", "SPORTS_TERMS_MISSING_ENTITY = \"{sports_term}\"\n", "LOCATIONS_AND_LANDMARKS_ENTITY = \"{location_and_landmark}\"\n", "ACTIVTIES_AND_EVENTS_ENTITY = \"{activity_and_event}\"\n", "FOOD_MISSING_ENTITY = \"{food_m}\"\n", "TRANSPORT_AND_DIRECTIONS_ENTITY = \"{transport_and_direction}\"\n", "\n", "\n", "def detect_entity(entity_name, template):\n", " return entity_name in template\n", "\n", "def tokenize(text):\n", " # Use regular expression to split words while keeping punctuation as separate tokens\n", " return re.findall(r'\\w+|[^\\w\\s]', text)\n", "\n", "# Tokenize the query and generate corresponding NER labels\n", "def tokenize_and_label(query, city, state, city_state, organization, person):\n", " tokens = tokenize(query) # Tokenize the query using the improved function\n", " ner_labels = [0] * len(tokens) # Initialize all labels as \"O\" (outside any entity)\n", " \n", " # Label city_state entity\n", " if city_state:\n", " city_state_tokens = tokenize(city_state)\n", " start_idx = find_token_index(tokens, city_state_tokens)\n", " if start_idx is not None:\n", " ner_labels[start_idx] = 9 # CSB-LOC (beginning of city_state)\n", " for i in range(1, len(city_state_tokens)):\n", " ner_labels[start_idx + i] = 10 # CSI-LOC (inside city_state)\n", "\n", " # Label city entity\n", " if city:\n", " city_tokens = tokenize(city)\n", " start_idx = find_token_index(tokens, city_tokens)\n", " if start_idx is not None:\n", " ner_labels[start_idx] = 5 # CB-LOC (beginning of city)\n", " for i in range(1, len(city_tokens)):\n", " ner_labels[start_idx + i] = 6 # CI-LOC (inside city)\n", " \n", " # Label state entity\n", " if state:\n", " state_tokens = tokenize(state)\n", " start_idx = find_token_index(tokens, state_tokens)\n", " if start_idx is not None:\n", " ner_labels[start_idx] = 7 # SB-LOC (beginning of state)\n", " for i in range(1, len(state_tokens)):\n", " ner_labels[start_idx + i] = 8 # SI-LOC (inside state)\n", "\n", " # Label organization entity\n", " if organization:\n", " org_tokens = tokenize(organization)\n", " start_idx = find_token_index(tokens, org_tokens)\n", " if start_idx is not None:\n", " ner_labels[start_idx] = 3 # B-ORG (beginning of organization)\n", " for i in range(1, len(org_tokens)):\n", " ner_labels[start_idx + i] = 4 # I-ORG (inside organization)\n", "\n", " # Label person entity\n", " if person:\n", " person_tokens = tokenize(person)\n", " start_idx = find_token_index(tokens, person_tokens)\n", " if start_idx is not None:\n", " ner_labels[start_idx] = 1 # B-PER (beginning of person)\n", " for i in range(1, len(person_tokens)):\n", " ner_labels[start_idx + i] = 2 # I-PER (inside person)\n", " \n", " return tokens, ner_labels\n", "\n", "# Function to find the starting index of an entity's tokens in the query tokens\n", "def find_token_index(tokens, entity_tokens):\n", " for i in range(len(tokens) - len(entity_tokens) + 1):\n", " if tokens[i:i + len(entity_tokens)] == entity_tokens:\n", " return i\n", " return None\n", "\n", "def generate_queries(templates, n_queries=10000):\n", " cnt = 0\n", " queries_with_labels = []\n", " query_counter = Counter()\n", " while cnt < n_queries:\n", " if (cnt %10000) == 0:\n", " print(f\"complted generating {cnt} queries\")\n", " template = random.choice(templates)\n", " # print(template)\n", " person, organization, city, state, city_state = (None,) * 5\n", " product, country, service, car, gadget, stock, money, finance, travel, food, restaurant = (None,) * 11\n", " sports_term, location_and_landmark, activity_and_event, food_m, transport_and_direction = (None,) * 5\n", "\n", " if detect_entity(PERSON_ENTITY, template):\n", " person=get_random_choice_from_list(persons)\n", " if detect_entity(ORG_ENTITY, template):\n", " organization = get_random_choice_from_list(organizations)\n", " if detect_entity(PRODUCT_ENTITY, template):\n", " product = get_random_choice_from_list(products)\n", " if detect_entity(COUNTRY_ENTITY, template):\n", " country = get_random_choice_from_list(countries)\n", " if detect_entity(COUNTRY_ENTITY, template):\n", " service = get_random_choice_from_list(services)\n", " if detect_entity(CAR_ENTITY, template):\n", " car = get_random_choice_from_list(cars)\n", " if detect_entity(GADGET_ENTITY, template):\n", " gadget = get_random_choice_from_list(gadgets)\n", " if detect_entity(STOCK_ENTITY, template):\n", " stock = get_random_choice_from_list(stocks)\n", " if detect_entity(MONEY_ENTITY, template):\n", " money = get_random_choice_from_list(moneys)\n", " if detect_entity(FINANCE_ENTITY, template):\n", " finance = get_random_choice_from_list(finances)\n", " if detect_entity(TRAVEL_ENTITY, template):\n", " travel = get_random_choice_from_list(travels)\n", " if detect_entity(FOOD_ENTITY, template):\n", " food = get_random_choice_from_list(foods)\n", " if detect_entity(RESTAURANT_ENTITY, template):\n", " restaurant = get_random_choice_from_list(restaurants)\n", " if detect_entity(SPORTS_TERMS_MISSING_ENTITY, template):\n", " sports_term = get_random_choice_from_list(sports_terms_missing)\n", " if detect_entity(LOCATIONS_AND_LANDMARKS_ENTITY, template):\n", " location_and_landmark = get_random_choice_from_list(locations_and_landmarks)\n", " if detect_entity(ACTIVTIES_AND_EVENTS_ENTITY, template):\n", " activity_and_event = get_random_choice_from_list(activities_and_events)\n", " if detect_entity(FOOD_MISSING_ENTITY, template):\n", " food_m = get_random_choice_from_list(food_missing)\n", " if detect_entity(TRANSPORT_AND_DIRECTIONS_ENTITY, template):\n", " transport_and_direction = get_random_choice_from_list(transport_and_directions)\n", "\n", " if detect_entity(CITY_ENTITY, template):\n", " city=get_sample_from_cities(city_info, city_weights, actual_threshold=0.7)\n", " if detect_entity(STATE_ENTITY, template):\n", " state=get_sample_from_states(state_info, actual_threshold=0.5)\n", " if detect_entity(CITY_STATE_ENTITY, template):\n", " city_state=get_sample_from_cities_and_states(city_state_code_info, city_state_name_info, state_code_threshold=0.8)\n", " \n", " query = template.format(person=person,\n", " organization=organization,\n", " city=city,\n", " state=state,\n", " city_state=city_state,\n", " product=product,\n", " country=country,\n", " service=service,\n", " car=car,\n", " gadget=gadget,\n", " stock=stock,\n", " money=money,\n", " finance=finance,\n", " travel=travel,\n", " food=food,\n", " restaurant=restaurant,\n", " sports_term=sports_term,\n", " location_and_landmark=location_and_landmark,\n", " activity_and_event=activity_and_event,\n", " food_m=food_m,\n", " transport_and_direction=transport_and_direction\n", " )\n", " tokens, ner_labels = tokenize_and_label(query, city, state, city_state, organization, person)\n", " if query_counter.get(query, 0) == 0:\n", " queries_with_labels.append((query, tokens, ner_labels))\n", " query_counter.update([query])\n", " cnt += 1\n", " return queries_with_labels" ] }, { "cell_type": "code", "execution_count": 34, "id": "fad4a249-7151-4c50-833d-9584819b4105", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "complted generating 0 queries\n", "complted generating 10000 queries\n", "complted generating 10000 queries\n", "complted generating 10000 queries\n", "complted generating 20000 queries\n", "complted generating 30000 queries\n", "complted generating 40000 queries\n", "complted generating 40000 queries\n", "complted generating 40000 queries\n", "complted generating 40000 queries\n", "complted generating 40000 queries\n", "complted generating 40000 queries\n", "complted generating 40000 queries\n", "complted generating 40000 queries\n", "complted generating 40000 queries\n", "complted generating 50000 queries\n", "complted generating 50000 queries\n", "complted generating 50000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 60000 queries\n", "complted generating 70000 queries\n", "complted generating 70000 queries\n", "complted generating 70000 queries\n", "complted generating 70000 queries\n", "complted generating 80000 queries\n", "complted generating 80000 queries\n", "complted generating 80000 queries\n", "complted generating 80000 queries\n", "complted generating 80000 queries\n", "complted generating 80000 queries\n", "complted generating 80000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 90000 queries\n", "complted generating 100000 queries\n", "complted generating 100000 queries\n", "complted generating 100000 queries\n", "complted generating 100000 queries\n", "complted generating 110000 queries\n", "complted generating 120000 queries\n", "complted generating 120000 queries\n", "complted generating 130000 queries\n", "complted generating 130000 queries\n", "complted generating 130000 queries\n", "complted generating 130000 queries\n", "complted generating 130000 queries\n", "complted generating 130000 queries\n", "complted generating 130000 queries\n", "complted generating 130000 queries\n", "complted generating 130000 queries\n", "complted generating 130000 queries\n", "complted generating 130000 queries\n", "complted generating 130000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 140000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 150000 queries\n", "complted generating 160000 queries\n", "complted generating 160000 queries\n", "complted generating 160000 queries\n", "complted generating 160000 queries\n", "complted generating 160000 queries\n", "complted generating 170000 queries\n", "complted generating 170000 queries\n", "complted generating 170000 queries\n", "complted generating 170000 queries\n", "complted generating 170000 queries\n", "complted generating 170000 queries\n", "complted generating 170000 queries\n", "complted generating 170000 queries\n", "complted generating 170000 queries\n", "complted generating 170000 queries\n", "complted generating 180000 queries\n", "complted generating 180000 queries\n", "complted generating 180000 queries\n", "complted generating 180000 queries\n", "complted generating 180000 queries\n", "complted generating 180000 queries\n", "complted generating 190000 queries\n", "complted generating 190000 queries\n", "complted generating 200000 queries\n", "complted generating 200000 queries\n", "complted generating 200000 queries\n", "complted generating 200000 queries\n", "complted generating 200000 queries\n", "complted generating 200000 queries\n", "complted generating 200000 queries\n", "complted generating 200000 queries\n", "complted generating 210000 queries\n", "complted generating 220000 queries\n", "complted generating 220000 queries\n", "complted generating 220000 queries\n", "complted generating 230000 queries\n", "complted generating 230000 queries\n", "complted generating 230000 queries\n", "complted generating 230000 queries\n", "complted generating 230000 queries\n", "complted generating 230000 queries\n", "complted generating 230000 queries\n", "complted generating 240000 queries\n", "complted generating 240000 queries\n", "complted generating 240000 queries\n", "complted generating 240000 queries\n", "complted generating 240000 queries\n", "complted generating 250000 queries\n", "complted generating 250000 queries\n", "complted generating 250000 queries\n", "complted generating 250000 queries\n", "complted generating 250000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 260000 queries\n", "complted generating 270000 queries\n", "complted generating 280000 queries\n", "complted generating 290000 queries\n", "complted generating 290000 queries\n" ] } ], "source": [ "queries_with_labels = generate_queries(templates, n_queries=300000)" ] }, { "cell_type": "code", "execution_count": 35, "id": "f1270765-2a56-4dda-9e25-f0f6ac26f473", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "300000" ] }, "execution_count": 35, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(queries_with_labels)" ] }, { "cell_type": "code", "execution_count": 36, "id": "4145760e-8b8c-4896-a072-87ac5484dcf3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>query</th>\n", " <th>tokens</th>\n", " <th>ner_tags</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>Best upgrades for Volkswagen Golf</td>\n", " <td>[Best, upgrades, for, Volkswagen, Golf]</td>\n", " <td>[0, 0, 0, 0, 0]</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>top-rated conc venues</td>\n", " <td>[top, -, rated, conc, venues]</td>\n", " <td>[0, 0, 0, 0, 0]</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>nearest arena</td>\n", " <td>[nearest, arena]</td>\n", " <td>[0, 0]</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>saloo</td>\n", " <td>[saloo]</td>\n", " <td>[0]</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>tourn in San Diego</td>\n", " <td>[tourn, in, San, Diego]</td>\n", " <td>[0, 0, 5, 6]</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>299995</th>\n", " <td>movie theaters Albany, OR</td>\n", " <td>[movie, theaters, Albany, ,, OR]</td>\n", " <td>[0, 0, 9, 10, 10]</td>\n", " </tr>\n", " <tr>\n", " <th>299996</th>\n", " <td>home improvement Meridian</td>\n", " <td>[home, improvement, Meridian]</td>\n", " <td>[0, 0, 5]</td>\n", " </tr>\n", " <tr>\n", " <th>299997</th>\n", " <td>ice cream The Villages FL</td>\n", " <td>[ice, cream, The, Villages, FL]</td>\n", " <td>[0, 0, 9, 10, 10]</td>\n", " </tr>\n", " <tr>\n", " <th>299998</th>\n", " <td>weather Caldwell ID</td>\n", " <td>[weather, Caldwell, ID]</td>\n", " <td>[0, 9, 10]</td>\n", " </tr>\n", " <tr>\n", " <th>299999</th>\n", " <td>Flagstaff, Arizona salads</td>\n", " <td>[Flagstaff, ,, Arizona, salads]</td>\n", " <td>[9, 10, 10, 0]</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>300000 rows × 3 columns</p>\n", "</div>" ], "text/plain": [ " query \\\n", "0 Best upgrades for Volkswagen Golf \n", "1 top-rated conc venues \n", "2 nearest arena \n", "3 saloo \n", "4 tourn in San Diego \n", "... ... \n", "299995 movie theaters Albany, OR \n", "299996 home improvement Meridian \n", "299997 ice cream The Villages FL \n", "299998 weather Caldwell ID \n", "299999 Flagstaff, Arizona salads \n", "\n", " tokens ner_tags \n", "0 [Best, upgrades, for, Volkswagen, Golf] [0, 0, 0, 0, 0] \n", "1 [top, -, rated, conc, venues] [0, 0, 0, 0, 0] \n", "2 [nearest, arena] [0, 0] \n", "3 [saloo] [0] \n", "4 [tourn, in, San, Diego] [0, 0, 5, 6] \n", "... ... ... \n", "299995 [movie, theaters, Albany, ,, OR] [0, 0, 9, 10, 10] \n", "299996 [home, improvement, Meridian] [0, 0, 5] \n", "299997 [ice, cream, The, Villages, FL] [0, 0, 9, 10, 10] \n", "299998 [weather, Caldwell, ID] [0, 9, 10] \n", "299999 [Flagstaff, ,, Arizona, salads] [9, 10, 10, 0] \n", "\n", "[300000 rows x 3 columns]" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# queries_with_labels[:10]\n", "df_ner_examples = pd.DataFrame(queries_with_labels, columns=['query', 'tokens', 'ner_tags'])\n", "df_ner_examples" ] }, { "cell_type": "code", "execution_count": 37, "id": "a6b071e2-c7bf-4891-a662-5c14898da9f8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "ner_tags\n", "3 107543\n", "1 73502\n", "2 57480\n", "4 48084\n", "5 7978\n", "0 4194\n", "6 995\n", "7 134\n", "8 90\n", "Name: count, dtype: int64" ] }, "execution_count": 37, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_ner_examples['ner_tags'].apply(lambda tags: len([tag for tag in tags if tag > 4])).value_counts()" ] }, { "cell_type": "code", "execution_count": 38, "id": "022690de-9348-4d98-9007-b730571d6d6a", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{0: 'O',\n", " 1: 'B-PER',\n", " 2: 'I-PER',\n", " 3: 'B-ORG',\n", " 4: 'I-ORG',\n", " 5: 'B-CITY',\n", " 6: 'I-CITY',\n", " 7: 'B-STATE',\n", " 8: 'I-STATE',\n", " 9: 'B-CITYSTATE',\n", " 10: 'I-CITYSTATE'}" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "label_map" ] }, { "cell_type": "code", "execution_count": 39, "id": "d5d49f63-72c7-41da-a16c-536789713297", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "query\n", "Best upgrades for Volkswagen Golf 1\n", "Rochester Philharmonic Orchestra in Haverhill, MA 1\n", "Where to find the best noodles in Springfield, MA? 1\n", "home improvement services Kenosha WI 1\n", "home services Ahwatukee Foothills 1\n", " ..\n", "burgers Peoria AZ 1\n", "real estate agents nashville 1\n", "florist Largo, FL 1\n", "hours of Newark Data Solutions in Cicero 1\n", "Flagstaff, Arizona salads 1\n", "Name: count, Length: 300000, dtype: int64" ] }, "execution_count": 39, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df_ner_examples['query'].value_counts()" ] }, { "cell_type": "code", "execution_count": 40, "id": "2de441c8-4e9c-4b59-9bcc-cfbe009718bf", "metadata": {}, "outputs": [], "source": [ "df_ner_examples.to_csv(\"../data/df_ner_examples_v3.csv\", index=False)" ] }, { "cell_type": "code", "execution_count": 41, "id": "df043822-779c-4f9c-89eb-b331e2b0de19", "metadata": {}, "outputs": [], "source": [ "# useful for post processing to standardize the city names\n", "def build_lookup(dataframe):\n", " # Initialize an empty dictionary for the lookup\n", " lookup = {}\n", " \n", " # Iterate over each row in the DataFrame\n", " for index, row in dataframe.iterrows():\n", " city_name = row['city_name']\n", " alternate_names = row['alternate_names']\n", " \n", " # Iterate over the list of alternate names and map them to the city_name\n", " for alt_name in alternate_names:\n", " lookup[alt_name.lower()] = city_name # Convert alternate names to lowercase for consistency\n", " \n", " return lookup\n", "\n", "city_alternate_to_city_lkp = build_lookup(city_states_data)" ] }, { "cell_type": "code", "execution_count": 42, "id": "62a392e3-e18e-470f-9f95-ad35ebaebca8", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1356" ] }, "execution_count": 42, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(city_alternate_to_city_lkp)" ] }, { "cell_type": "code", "execution_count": 43, "id": "85bdeff1-a3f2-443e-a31b-d80e836c6ebe", "metadata": {}, "outputs": [], "source": [ "# !python -m pip install onnxruntime" ] }, { "cell_type": "code", "execution_count": 44, "id": "689e6844-2a90-4b7a-a9a5-bb298dce2b70", "metadata": {}, "outputs": [], "source": [ "# !python -m pip freeze| grep onnxruntime" ] }, { "cell_type": "code", "execution_count": 45, "id": "fc61067c-6e8a-499a-9d08-07fb4fb0eb2f", "metadata": {}, "outputs": [], "source": [ "# !mkdir ../models" ] }, { "cell_type": "code", "execution_count": 46, "id": "74bca5a8-0bb0-46c1-8429-598e172f34af", "metadata": {}, "outputs": [], "source": [ "import onnxruntime as ort\n", "import numpy as np\n", "from transformers import AutoTokenizer, BertTokenizer\n", "\n", "# Download the ONNX model\n", "# model_url = \"https://huggingface.co/Xenova/bert-base-NER/resolve/main/onnx/model_quantized.onnx\"\n", "# model_url = \"https://huggingface.co/Mozilla/distilbert-NER-LoRA/resolve/main/onnx/model_quantized.onnx\"\n", "model_url = \"https://huggingface.co/Mozilla/distilbert-uncased-NER-LoRA/resolve/main/onnx/model_quantized.onnx\"\n", "# model_url = \"https://huggingface.co/chidamnat2002/distilbert-uncased-NER-LoRA/resolve/main/onnx/model_quantized.onnx\"\n", "# model_path = \"../models/distilbert-NER-LoRA.onnx\"\n", "model_path = \"../models/distilbert-uncased-NER-LoRA.onnx\"\n", "\n", "# Download the ONNX model if not already present\n", "response = requests.get(model_url)\n", "with open(model_path, 'wb') as f:\n", " f.write(response.content)\n", "\n", "# Load the ONNX model using ONNX Runtime\n", "session = ort.InferenceSession(model_path)\n", "\n", "# Load the tokenizer (assuming it's based on BERT)\n", "# tokenizer = BertTokenizer.from_pretrained(\"Mozilla/distilbert-NER-LoRA\")\n", "tokenizer = AutoTokenizer.from_pretrained(\"Mozilla/distilbert-uncased-NER-LoRA\")" ] }, { "cell_type": "code", "execution_count": 47, "id": "838001d1-a252-4a4f-bfab-8c7698b7c79b", "metadata": {}, "outputs": [], "source": [ "def compute_model_inputs_and_outputs(session, tokenizer, query):\n", " # Tokenize the input\n", " # inputs = tokenizer(query, return_tensors=\"np\", truncation=True, padding=True)\n", " inputs = tokenizer(query, return_tensors=\"np\", truncation=True, padding='max_length', max_length=64)\n", " # is_split_into_words=True,\n", " # truncation=True,\n", " # padding='max_length',\n", " # max_length=64\n", " \n", " # The ONNX model expects 'input_ids', 'attention_mask', and 'token_type_ids'\n", " # Convert all necessary inputs to numpy arrays and prepare the input feed\n", " input_feed = {\n", " 'input_ids': inputs['input_ids'].astype(np.int64),\n", " 'attention_mask': inputs['attention_mask'].astype(np.int64),\n", " # 'token_type_ids': inputs['token_type_ids'].astype(np.int64) # Some models might not need this; check if it's really required\n", " }\n", " \n", " # Run inference with the ONNX model\n", " outputs = session.run(None, input_feed)\n", " # print(outputs)\n", " return inputs, outputs\n" ] }, { "cell_type": "code", "execution_count": 48, "id": "f66190d3-5601-4593-b7b9-0eebde13e23e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{0: 'O',\n", " 1: 'B-PER',\n", " 2: 'I-PER',\n", " 3: 'B-ORG',\n", " 4: 'I-ORG',\n", " 5: 'B-CITY',\n", " 6: 'I-CITY',\n", " 7: 'B-STATE',\n", " 8: 'I-STATE',\n", " 9: 'B-CITYSTATE',\n", " 10: 'I-CITYSTATE'}" ] }, "execution_count": 48, "metadata": {}, "output_type": "execute_result" } ], "source": [ "label_map" ] }, { "cell_type": "code", "execution_count": 49, "id": "08ecb315-3896-4a7e-8c03-37e3ecb1fa9a", "metadata": {}, "outputs": [], "source": [ "## With Xenova/bert-base-NER\n", "# Number of examples = 349\n", "# #hits = 135; #hit rate = 0.3868194842406877\n", "\n", "## After finetuning the Mozilla/distilbert-NER-LoRA\n", "#hits = 220; #hit rate = 0.6303724928366762\n", "\n", "## After finetuning the chidamnat2002/distilbert-uncased-NER-LoRA\n", "#hits = 207; #hit rate = 0.5931232091690545\n", "\n", "## After finetuning the Mozilla/distilbert-uncased-NER-LoRA\n", "#hits = 252; #hit rate = 0.7220630372492837" ] }, { "cell_type": "code", "execution_count": 50, "id": "1eed2554-784c-4f49-aad5-72b795f19295", "metadata": {}, "outputs": [], "source": [ "# len(missing_locations)" ] }, { "cell_type": "code", "execution_count": 51, "id": "feaed0b3-5fb8-4686-b57a-3a8d9764ec79", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# print(missing_locations)" ] }, { "cell_type": "code", "execution_count": null, "id": "d04d5258-16b4-4773-b585-b5f31db3926c", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "id": "ef09b219-dd01-4d66-92e2-c438935e8654", "metadata": {}, "source": [ "#### Looking into CONLL 2003 dataset" ] }, { "cell_type": "code", "execution_count": 52, "id": "4233afed-374f-4f2f-baaa-078447959367", "metadata": {}, "outputs": [], "source": [ "from datasets import load_dataset, Dataset\n", "import re\n", "\n", "# Load the CoNLL-2003 dataset\n", "dataset = load_dataset(\"conll2003\")\n", "\n", "loc_examples = dataset" ] }, { "cell_type": "code", "execution_count": 53, "id": "14216057-228f-467a-aa8e-02108d56cb92", "metadata": {}, "outputs": [], "source": [ "# dataset['train'].to_pandas()" ] }, { "cell_type": "code", "execution_count": 54, "id": "e259586a-f67b-42b2-9665-a571da352f57", "metadata": {}, "outputs": [], "source": [ "# dataset['train']" ] }, { "cell_type": "code", "execution_count": 55, "id": "12e91919-6dc4-4ad3-a388-e5b90d4efa79", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Dataset({\n", " features: ['tokens', 'ner_tags'],\n", " num_rows: 300000\n", "})\n", "{'tokens': ['Best', 'upgrades', 'for', 'Volkswagen', 'Golf'], 'ner_tags': [0, 0, 0, 0, 0]}\n" ] } ], "source": [ "synthetic_loc_dataset = Dataset.from_pandas(df_ner_examples.drop('query', axis=1))\n", "print(synthetic_loc_dataset)\n", "\n", "print(synthetic_loc_dataset[0])" ] }, { "cell_type": "code", "execution_count": null, "id": "0d91ba34-cb67-418a-8a4e-4b442b144be6", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": 56, "id": "496a76a7-3329-4849-affa-63166d427183", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'id': '0',\n", " 'tokens': ['EU',\n", " 'rejects',\n", " 'German',\n", " 'call',\n", " 'to',\n", " 'boycott',\n", " 'British',\n", " 'lamb',\n", " '.'],\n", " 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}" ] }, "execution_count": 56, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# loc_dataset = dataset['train'].filter(lambda example: 5 in example['ner_tags'])\n", "loc_dataset = dataset['train']\n", "loc_dataset_filtered = loc_dataset.remove_columns(['pos_tags', 'chunk_tags'])\n", "\n", "# Set the format to ensure the order is 'id', 'tokens', and 'ner_tags'\n", "loc_dataset_filtered[0]" ] }, { "cell_type": "code", "execution_count": 57, "id": "42652aaf-399f-413f-a8f6-e082f1057e3f", "metadata": {}, "outputs": [], "source": [ "# loc_dataset_filtered[-1]" ] }, { "cell_type": "code", "execution_count": 58, "id": "c47584e0-0612-400b-81e9-212a61209b94", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "384a0ff272f643b992a18537611757e1", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Casting the dataset: 0%| | 0/300000 [00:00<?, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "{'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'ner_tags': Sequence(feature=ClassLabel(names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-CITY', 'I-CITY', 'B-STATE', 'I-STATE', 'B-CITYSTATE', 'I-CITYSTATE'], id=None), length=-1, id=None)}\n", "{'tokens': ['Best', 'upgrades', 'for', 'Volkswagen', 'Golf'], 'ner_tags': [0, 0, 0, 0, 0]}\n" ] } ], "source": [ "from datasets import concatenate_datasets\n", "\n", "from datasets import Sequence, ClassLabel, Value\n", "\n", "# Step 1: Get the full feature schema from synthetic_loc_dataset\n", "features = synthetic_loc_dataset.features\n", "\n", "# Step 2: Update the 'ner_tags' feature to use ClassLabel from loc_dataset_filtered\n", "# features['ner_tags'] = Sequence(feature=ClassLabel(names=loc_dataset_filtered.features['ner_tags'].feature.names))\n", "features['ner_tags'] = Sequence(feature=ClassLabel(names=list(label_map.values())))\n", "\n", "# Step 3: Cast synthetic_loc_dataset to the updated feature schema\n", "synthetic_loc_dataset = synthetic_loc_dataset.cast(features)\n", "\n", "# Check the updated features to confirm\n", "print(synthetic_loc_dataset.features)\n", "\n", "# Now concatenate the datasets\n", "# combined_dataset = concatenate_datasets([loc_dataset_filtered, synthetic_loc_dataset])\n", "\n", "# Verify the combined dataset\n", "print(synthetic_loc_dataset[0])\n" ] }, { "cell_type": "code", "execution_count": 59, "id": "15f8ec72-8a43-43f2-932a-ef76b5efb4d2", "metadata": {}, "outputs": [], "source": [ "# ClassLabel(names=loc_dataset_filtered.features['ner_tags'].feature.names)" ] }, { "cell_type": "code", "execution_count": 60, "id": "6e3b90ed-9bbf-4b8a-9990-b5db059de0ea", "metadata": {}, "outputs": [], "source": [ "# ClassLabel(names=list(label_map.values()))" ] }, { "cell_type": "code", "execution_count": 61, "id": "6138a427-f03b-4355-bdac-ffec783f5a2b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "300000" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(synthetic_loc_dataset)" ] }, { "cell_type": "code", "execution_count": 62, "id": "caac8e36-6d1c-4a42-8acd-7e81f816fa9b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'tokens': ['saloo'], 'ner_tags': [0]}" ] }, "execution_count": 62, "metadata": {}, "output_type": "execute_result" } ], "source": [ "synthetic_loc_dataset[3]" ] }, { "cell_type": "code", "execution_count": 63, "id": "2aa98e69-bf5f-4bcc-b387-2abdc60a99be", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2b6e3b9859134f56875e0fb1cc95f475", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Map: 0%| | 0/300000 [00:00<?, ? examples/s]" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "synthetic_loc_dataset = synthetic_loc_dataset.map(\n", " lambda example, idx: {'id': idx}, # Assign running count as the new 'id'\n", " with_indices=True # Ensures we get an index for each example\n", ")" ] }, { "cell_type": "code", "execution_count": 64, "id": "5906e294-6a1b-436d-a229-628f99190887", "metadata": {}, "outputs": [ { "data": { "text/html": [ "<div>\n", "<style scoped>\n", " .dataframe tbody tr th:only-of-type {\n", " vertical-align: middle;\n", " }\n", "\n", " .dataframe tbody tr th {\n", " vertical-align: top;\n", " }\n", "\n", " .dataframe thead th {\n", " text-align: right;\n", " }\n", "</style>\n", "<table border=\"1\" class=\"dataframe\">\n", " <thead>\n", " <tr style=\"text-align: right;\">\n", " <th></th>\n", " <th>tokens</th>\n", " <th>ner_tags</th>\n", " <th>id</th>\n", " </tr>\n", " </thead>\n", " <tbody>\n", " <tr>\n", " <th>0</th>\n", " <td>[Best, upgrades, for, Volkswagen, Golf]</td>\n", " <td>[0, 0, 0, 0, 0]</td>\n", " <td>0</td>\n", " </tr>\n", " <tr>\n", " <th>1</th>\n", " <td>[top, -, rated, conc, venues]</td>\n", " <td>[0, 0, 0, 0, 0]</td>\n", " <td>1</td>\n", " </tr>\n", " <tr>\n", " <th>2</th>\n", " <td>[nearest, arena]</td>\n", " <td>[0, 0]</td>\n", " <td>2</td>\n", " </tr>\n", " <tr>\n", " <th>3</th>\n", " <td>[saloo]</td>\n", " <td>[0]</td>\n", " <td>3</td>\n", " </tr>\n", " <tr>\n", " <th>4</th>\n", " <td>[tourn, in, San, Diego]</td>\n", " <td>[0, 0, 5, 6]</td>\n", " <td>4</td>\n", " </tr>\n", " <tr>\n", " <th>...</th>\n", " <td>...</td>\n", " <td>...</td>\n", " <td>...</td>\n", " </tr>\n", " <tr>\n", " <th>299995</th>\n", " <td>[movie, theaters, Albany, ,, OR]</td>\n", " <td>[0, 0, 9, 10, 10]</td>\n", " <td>299995</td>\n", " </tr>\n", " <tr>\n", " <th>299996</th>\n", " <td>[home, improvement, Meridian]</td>\n", " <td>[0, 0, 5]</td>\n", " <td>299996</td>\n", " </tr>\n", " <tr>\n", " <th>299997</th>\n", " <td>[ice, cream, The, Villages, FL]</td>\n", " <td>[0, 0, 9, 10, 10]</td>\n", " <td>299997</td>\n", " </tr>\n", " <tr>\n", " <th>299998</th>\n", " <td>[weather, Caldwell, ID]</td>\n", " <td>[0, 9, 10]</td>\n", " <td>299998</td>\n", " </tr>\n", " <tr>\n", " <th>299999</th>\n", " <td>[Flagstaff, ,, Arizona, salads]</td>\n", " <td>[9, 10, 10, 0]</td>\n", " <td>299999</td>\n", " </tr>\n", " </tbody>\n", "</table>\n", "<p>300000 rows × 3 columns</p>\n", "</div>" ], "text/plain": [ " tokens ner_tags id\n", "0 [Best, upgrades, for, Volkswagen, Golf] [0, 0, 0, 0, 0] 0\n", "1 [top, -, rated, conc, venues] [0, 0, 0, 0, 0] 1\n", "2 [nearest, arena] [0, 0] 2\n", "3 [saloo] [0] 3\n", "4 [tourn, in, San, Diego] [0, 0, 5, 6] 4\n", "... ... ... ...\n", "299995 [movie, theaters, Albany, ,, OR] [0, 0, 9, 10, 10] 299995\n", "299996 [home, improvement, Meridian] [0, 0, 5] 299996\n", "299997 [ice, cream, The, Villages, FL] [0, 0, 9, 10, 10] 299997\n", "299998 [weather, Caldwell, ID] [0, 9, 10] 299998\n", "299999 [Flagstaff, ,, Arizona, salads] [9, 10, 10, 0] 299999\n", "\n", "[300000 rows x 3 columns]" ] }, "execution_count": 64, "metadata": {}, "output_type": "execute_result" } ], "source": [ "synthetic_loc_dataset.to_pandas()" ] }, { "cell_type": "code", "execution_count": 65, "id": "46c0d423-3b8c-47ed-a8ae-a3316cd78bd0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'tokens': ['Flagstaff', ',', 'Arizona', 'salads'],\n", " 'ner_tags': [9, 10, 10, 0],\n", " 'id': 299999}" ] }, "execution_count": 65, "metadata": {}, "output_type": "execute_result" } ], "source": [ "synthetic_loc_dataset[-1]" ] }, { "cell_type": "code", "execution_count": 66, "id": "c35b1a0b-303c-4eee-bc31-770872c212e5", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "5a14788e479b4c83bc473f580edd7d92", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Creating parquet from Arrow format: 0%| | 0/300 [00:00<?, ?ba/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/plain": [ "36125163" ] }, "execution_count": 66, "metadata": {}, "output_type": "execute_result" } ], "source": [ "synthetic_loc_dataset.to_parquet(\"../data/synthetic_loc_dataset_v3.parquet\")" ] }, { "cell_type": "code", "execution_count": 67, "id": "d33bb9a1-bd49-49cd-aa90-5428d46fbad7", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[{'entity': 'B-STATE', 'score': np.float32(0.827376), 'index': 1, 'word': 'new', 'start': 0, 'end': 3}, {'entity': 'I-STATE', 'score': np.float32(0.69074583), 'index': 2, 'word': 'york', 'start': 4, 'end': 8}]\n" ] } ], "source": [ "from transformers import AutoTokenizer, AutoModelForTokenClassification\n", "from transformers import pipeline\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"Mozilla/distilbert-uncased-NER-LoRA\")\n", "model = AutoModelForTokenClassification.from_pretrained(\"Mozilla/distilbert-uncased-NER-LoRA\")\n", "\n", "nlp = pipeline(\"ner\", model=model, tokenizer=tokenizer)\n", "example = \"New York\"\n", "\n", "ner_results = nlp(example)\n", "print(ner_results)\n" ] }, { "cell_type": "code", "execution_count": null, "id": "32524933-23f7-41ae-8597-da0300e6ac60", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.0" } }, "nbformat": 4, "nbformat_minor": 5 }