Explore whether the weather keywords and locations are captured correctly

In [1]:
import pandas as pd
import requests
from bs4 import BeautifulSoup
import re
from datasets import load_dataset, Dataset

In [2]:
import numpy as np
import random
from collections import Counter

#### Read the GeoNames city/state data from `../data/geonames-cities-states.json`

In [3]:
import json 

def get_geonames_city_state_data(geonames_file="../data/geonames-cities-states.json"):
    """Load the GeoNames JSON dump and return one row per city with its state.

    Parameters
    ----------
    geonames_file : str or path-like, optional
        Location of the JSON file. It must contain a ``cities`` list and a
        ``states_by_abbr`` mapping. Defaults to the path this notebook has
        always read from, so existing zero-argument calls are unaffected.

    Returns
    -------
    pandas.DataFrame
        Columns ``id``, ``state_code``, ``city_name``, ``city_popln``,
        ``alternate_names``, ``state_name`` and ``city_weight``, where
        ``city_weight`` is the city's share of the summed population of all
        listed cities.
    """
    # JSON is UTF-8 by specification; be explicit so the platform default
    # encoding cannot garble non-ASCII names.
    with open(geonames_file, 'r', encoding='utf-8') as f:
        geonames_dict = json.load(f)

    cities_data = pd.DataFrame(geonames_dict['cities']).rename(
        columns={'admin1_code': 'state_code', 'name': 'city_name', 'population': 'city_popln'}
    )
    cities_data = cities_data[['id', 'state_code', 'city_name', 'city_popln', 'alternate_names']]

    states_data = pd.DataFrame(list(geonames_dict['states_by_abbr'].values())).rename(
        columns={'admin1_code': 'state_code', 'name': 'state_name'}
    )
    states_data = states_data[['state_code', 'state_name']]

    # Left join: every city is kept; a city whose state code has no entry in
    # states_by_abbr ends up with a missing state_name.
    city_states_data = cities_data.merge(states_data, how='left', on='state_code')
    city_states_data['city_weight'] = city_states_data['city_popln'] / city_states_data['city_popln'].sum()
    return city_states_data



In [4]:
# Build the merged city/state table, report its row count, then display it.
city_states_data = get_geonames_city_state_data()
print(city_states_data.shape[0])
city_states_data

962


Unnamed: 0,id,state_code,city_name,city_popln,alternate_names,state_name,city_weight
0,4049979,AL,Birmingham,212461,"[birmingham, bhm]",Alabama,0.001409
1,4050552,TN,Cordova,68779,[cordova],Tennessee,0.000456
2,4058553,AL,Decatur,55437,"[decatur, dcu]",Alabama,0.000368
3,4059102,AL,Dothan,68567,"[dhn, dothan]",Alabama,0.000455
4,4067994,AL,Hoover,84848,[hoover],Alabama,0.000563
...,...,...,...,...,...,...,...
957,11748973,GA,Stonecrest,50000,[stonecrest],Georgia,0.000332
958,11838960,CA,Valley Glen,60000,[valley glen],California,0.000398
959,11979227,AZ,Encanto,54614,[encanto],Arizona,0.000362
960,11979238,AZ,Central City,58161,[central city],Arizona,0.000386


In [5]:
# Show the cities ordered from the largest to the smallest sampling weight.
city_states_data.sort_values(by='city_weight', ascending=False)

Unnamed: 0,id,state_code,city_name,city_popln,alternate_names,state_name,city_weight
554,5128581,NY,New York City,8804190,"[the big apple, new york, new york city, nyc, ...",New York,0.058393
718,5368361,CA,Los Angeles,3898747,"[los angeles, lax, l.a., la]",California,0.025858
522,5110302,NY,Brooklyn,2736074,"[bk, borough of brooklyn, brooklyn]",New York,0.018147
359,4887398,IL,Chicago,2696555,"[chi, chicago]",Illinois,0.017885
266,4699066,TX,Houston,2304580,"[houston, hou]",Texas,0.015285
...,...,...,...,...,...,...,...
260,4692883,TX,Galveston,50180,"[galveston, gls]",Texas,0.000333
758,5384690,CA,Poway,50157,[poway],California,0.000333
470,5025264,MN,Edina,50138,[edina],Minnesota,0.000333
475,5037790,MN,Minnetonka Mills,50117,[minnetonka mills],Minnesota,0.000332


In [6]:
# city name -> sampling weight. Keys are city names only, so a name that
# occurs in more than one row keeps a single entry.
city_weights = city_states_data.set_index('city_name')['city_weight'].to_dict()
# city_weights

In [7]:
# Lookup tables derived from the merged table:
#   city_info   city name  -> list of alternate names
#   state_info  state code -> state name
city_info = city_states_data.set_index('city_name')['alternate_names'].to_dict()
state_info = city_states_data.set_index('state_code')['state_name'].to_dict()

# Independent copies used for weighted sampling of (city, state) pairs.
city_state_code_info = city_states_data.loc[:, ['city_name', 'state_code', 'city_weight']].copy()
city_state_name_info = city_states_data.loc[:, ['city_name', 'state_name', 'city_weight']].copy()

In [8]:
# city_info

In [9]:
# Integer id -> BIO tag. The id of each tag is its position in the sequence.
label_map = dict(enumerate([
    "O",            # Outside any named entity
    "B-PER",        # Beginning of a person entity
    "I-PER",        # Inside a person entity
    "B-ORG",        # Beginning of an organization entity
    "I-ORG",        # Inside an organization entity
    "B-CITY",       # Beginning of a city entity
    "I-CITY",       # Inside a city entity
    "B-STATE",      # Beginning of a state entity
    "I-STATE",      # Inside a state entity
    "B-CITYSTATE",  # Beginning of a city_state entity
    "I-CITYSTATE",  # Inside a city_state entity
]))


# Person names substituted into the {person} slot of the query templates.
persons = [
    # Well-known public figures
    "Donald Trump", "John Smith", "Roger Williams", "Michelle Obama", "Elon Musk",
    "Barack Obama", "Bill Gates", "Steve Jobs", "Warren Buffett", "Oprah Winfrey",
    "Jeff Bezos", "Taylor Swift", "Jennifer Lawrence", "Brad Pitt", "Leonardo DiCaprio",
    "Katy Perry", "Tom Hanks", "Emma Watson", "Johnny Depp", "Scarlett Johansson",
    "Mark Zuckerberg", "Sheryl Sandberg", "Ivanka Trump", "Joe Biden", "Kamala Harris",
    "Serena Williams", "Michael Jordan", "LeBron James", "Tiger Woods", "Cristiano Ronaldo",
    "Lionel Messi", "Roger Federer", "Usain Bolt", "Simone Biles", "Tom Brady",
    "Peyton Manning", "David Beckham", "Rafael Nadal", "Novak Djokovic", "Andy Murray",
    "George Clooney", "Matt Damon", "Julia Roberts", "Angelina Jolie", "Morgan Freeman",
    "Chris Hemsworth", "Dwayne Johnson", "Vin Diesel", "Keanu Reeves", "Robert Downey Jr.",
    "Chris Evans", "Will Smith", "Johnny Cash", "Bob Dylan", "Paul McCartney",
    "Ringo Starr", "John Lennon", "George Harrison", "Madonna", "Prince",
    "Bruce Springsteen", "Elton John", "David Bowie", "Whitney Houston", "Celine Dion",
    "Marilyn Monroe", "Audrey Hepburn", "Albert Einstein", "Isaac Newton", "Marie Curie",
    "Galileo Galilei", "Nikola Tesla", "Stephen Hawking", "Richard Feynman", "Carl Sagan",
    "Neil Armstrong", "Yuri Gagarin", "Sally Ride", "Jane Goodall", "Charles Darwin",
    "Mahatma Gandhi", "Nelson Mandela", "Martin Luther King Jr.", "Malala Yousafzai", "Angela Merkel",
    "Theresa May", "Vladimir Putin", "Xi Jinping", "Justin Trudeau", "Jacinda Ardern",
    "Pope Francis", "Dalai Lama", "Queen Elizabeth II", "Prince William", "Prince Harry",
    # Generic first-name / surname combinations
    "James Anderson", "Michael Brown", "David Clark", "John Doe", "Robert Evans",
    "Christopher Foster", "William Garcia", "Charles Hall", "Joseph Harris", "Daniel Jackson",
    "Matthew Johnson", "George King", "Anthony Lewis", "Mark Miller", "Paul Moore",
    "Steven Nelson", "Kevin Perry", "Thomas Reed", "Brian Roberts", "Jason Scott",
    "Andrew Smith", "Joshua Thompson", "Ryan Turner", "Brandon Walker", "Nicholas White",
    "Jonathan Young", "Adam Baker", "Justin Carter", "Benjamin Collins", "Aaron Cook",
    "Alexander Davis", "Tyler Edwards", "Zachary Fisher", "Ethan Graham", "Jacob Green",
    "Austin Hernandez", "Mason Hill", "Logan Hughes", "Owen Jenkins", "Lucas Kelly",
    "Nathan Lee", "Caleb Long", "Henry Martinez", "Dylan Mitchell", "Gabriel Morris",
    "Jack Murphy", "Connor Myers", "Liam Parker", "Isaac Patterson", "Evan Phillips",
    "Hunter Price", "Noah Richardson", "Samuel Rivera", "Gavin Rogers", "Aiden Ross",
    "Christian Russell", "Ian Sanders", "Eli Simmons", "Chase Stewart", "Cameron Sullivan",
    "Bryan Taylor", "Cole Thomas", "Jake Thompson", "Luke Torres", "Blake Turner",
    "Jesse Ward", "Joel Watson", "Derek Williams", "Mitchell Wright", "Dustin Young",
    "Megan Allen", "Jennifer Bailey", "Jessica Bennett", "Emily Brooks", "Sarah Campbell",
    "Amanda Carter", "Rebecca Collins", "Samantha Cooper", "Stephanie Diaz", "Rachel Evans",
    "Christine Flores", "Laura Foster", "Michelle Garcia", "Amber Gonzales", "Lisa Gray",
    "Kimberly Green", "Heather Harris", "Tiffany Henderson", "Natalie Hernandez", "Crystal Hill",
    "Victoria Hughes", "Erica Jenkins", "Nicole Johnson", "Katherine Kelly", "Danielle Lee",
    "Hannah Lewis", "Melissa Lopez", "Patricia Martin", "Brittany Moore", "Brenda Morgan",
]
# Organization names substituted into the {organization} slot of the templates.
# Many of them begin with a city name.
organizations = [
    "Google Inc.", "Apple Inc.", "Amazon.com", "Facebook Inc.", "Microsoft Corporation",
    "Tesla Motors", "Netflix Inc.", "The New York Times", "The Washington Post", "Wall Street Journal",
    "Intel Corporation", "Oracle Corporation", "IBM", "Coca-Cola Company", "PepsiCo",
    "Starbucks", "Walmart Inc.", "Target Corporation", "ExxonMobil", "Shell Oil Company",
    "Ford Motor Company", "General Motors", "Toyota Motor Corporation", "Volkswagen Group", "BMW Group",
    "American Airlines", "Delta Airlines", "United Airlines", "Boeing Company", "Lockheed Martin",
    "SpaceX", "NASA", "Harvard University", "Stanford University", "Massachusetts Institute of Technology",
    "University of California, Berkeley", "University of Oxford", "University of Cambridge", "Princeton University", "Yale University",
    "University of Chicago", "Columbia University", "Johns Hopkins University", "University of Southern California", "University of Michigan",
    "Goldman Sachs", "JPMorgan Chase", "Citibank", "Morgan Stanley", "Bank of America",
    "Deloitte", "Ernst & Young", "PricewaterhouseCoopers", "KPMG", "McKinsey & Company",
    "Boston Consulting Group", "Accenture", "BlackRock", "Fidelity Investments", "Vanguard Group",
    "Nike Inc.", "Adidas", "Under Armour", "Patagonia", "The Walt Disney Company",
    "Time Warner", "NBCUniversal", "Sony Corporation", "Warner Bros.", "Paramount Pictures",
    "Universal Music Group", "Sony Music Entertainment", "Warner Music Group", "Pfizer Inc.", "Johnson & Johnson",
    "Novartis", "Merck & Co.", "GlaxoSmithKline", "AstraZeneca", "Moderna",
    "New York City Hospital", "Los Angeles County Library", "San Francisco Community College",
    "Miami International University", "Chicago Regional Bank", "Dallas Medical Center",
    "Boston Tech Solutions", "Atlanta City Bank", "Seattle Software Hub", "Phoenix Energy Solutions",
    "Denver Financial Group", "Houston General Hospital", "Portland Health Services", "Las Vegas Convention Center",
    "San Diego Software Innovations", "Philadelphia Law Firm", "Orlando Realty Group",
    "Austin Engineering Solutions", "Cleveland City Schools", "Detroit Manufacturing Hub",
    "Baltimore Technology Inc.", "Minneapolis Insurance Group", "St. Louis Transportation Services",
    "Tampa Healthcare Network", "Pittsburgh Steelworks Corporation", "Sacramento Business Ventures",
    "Indianapolis Marketing Solutions", "Columbus Financial Advisors", "Fort Worth Electric Company",
    "Charlotte Digital Marketing", "Milwaukee Industrial Solutions", "Memphis Logistics Services",
    "Washington DC Development", "Nashville Business Enterprises", "Louisville Fitness Center",
    "Kansas City Architectural Firm", "Oklahoma City University", "Virginia Beach Law Associates",
    "Raleigh Research Institute", "Salt Lake City Analytics", "Richmond Financial Group",
    "Newark Data Solutions", "Anchorage Energy Solutions", "Fresno Water Authority",
    "Omaha Financial Services", "Colorado Springs Health Institute", "Mesa Auto Parts",
    "Virginia Beach Shipping", "Sacramento Community Center", "Albuquerque Electronics Company",
    "Tucson Data Science Center", "Miami Lakes Software Solutions", "Wichita Steel Corporation",
    "Arlington Cybersecurity Group", "Bakersfield Construction Services", "Aurora Logistics Firm",
    "Anaheim Technology Hub", "Santa Ana Healthcare Services", "Riverside Manufacturing Co.",
    "St. Paul Medical Associates", "Lexington University Hospital", "Plano Technology Solutions",
    "Lincoln Manufacturing Inc.", "Greensboro Industrial Partners", "Jersey City Financial Group",
    "Chandler Electronics", "Madison Biotechnology Solutions", "Lubbock Medical Supplies",
    "Scottsdale Real Estate Group", "Reno Venture Capitalists", "Henderson Engineering Consultants",
    "Norfolk Health Services", "Chesapeake Data Systems", "Fremont Software Group",
    "Irvine Legal Services", "San Bernardino Logistics Group", "Boise Energy Technologies",
    "Spokane Steel Fabricators", "Glendale Solar Power Corporation", "Garland Medical Services",
    "Hialeah Shipping and Logistics", "Chesapeake Financial Advisors", "Frisco Software Hub",
    "McKinney Electronics Corporation", "Gilbert Transportation Group", "Baton Rouge Financial Services",
    "Shreveport Data Analytics", "Mobile Business Solutions", "Huntsville Rocket Technologies",
    "Knoxville Agricultural Partners", "Dayton Software Innovations", "Grand Rapids Healthcare Network",
    "Fort Lauderdale Construction Group", "Tempe Electric Vehicles", "Winston-Salem Marketing Firm",
    "Fayetteville Consulting Services", "Springfield Realty Group", "Yonkers Manufacturing Hub",
    "Augusta Insurance Group", "Salem Solar Energy Solutions", "Pasadena Legal Consultants",
    "Seattle Pacific University", "San Diego Zoo", "Portland Art Museum",
    "Boston Medical Group", "Chicago Tribune", "Dallas Cowboys Football Club",
    "Los Angeles Philharmonic Orchestra", "New York University", "Houston Community College",
    "Phoenix Solar Power", "Denver Public Library", "Miami International Airport",
    "Atlanta Symphony Orchestra", "San Francisco Opera", "Orlando City Soccer Club",
    "Nashville Symphony", "Baltimore Ravens Football Team", "Cleveland Clinic",
    "Pittsburgh Steelers Football Team", "Detroit Institute of Arts",
    "Tampa Bay Buccaneers Football Club", "St. Louis Cardinals Baseball Team",
    "Indianapolis Colts Football Team", "Austin Film Society", "Seattle Sounders Football Club",
    "Minneapolis Institute of Art", "Charlotte Hornets Basketball Club", "Portland Trail Blazers Basketball Team",
    "Las Vegas Convention and Visitors Authority", "New Orleans Saints Football Club",
    "San Antonio Spurs Basketball Club", "Philadelphia Eagles Football Club",
    "Kansas City Chiefs Football Team", "Cincinnati Reds Baseball Club",
    "Memphis Grizzlies Basketball Team", "Washington Wizards Basketball Club",
    "Milwaukee Bucks Basketball Club", "Sacramento Kings Basketball Team",
    "Salt Lake City Ballet", "Boise State University", "Albuquerque International Balloon Fiesta",
    "Raleigh-Durham International Airport", "Richmond Symphony", "Fresno Pacific University",
    "Spokane Transit Authority", "Henderson Engineering", "Mesa Public Schools",
    "Scottsdale Museum of Contemporary Art", "Chandler Regional Medical Center", "Glendale Unified School District",
    "Riverside Community Hospital", "Aurora Public Schools", "Anaheim Ducks Hockey Team",
    "Santa Ana College", "Stockton Unified School District", "Irvine Company", "San Bernardino Community College District",
    "Modesto Junior College", "Bakersfield Condors Hockey Team", "Fresno State University",
    "Chesapeake Energy Corporation", "Omaha World-Herald", "Tucson Medical Center",
    "Virginia Beach Public Schools", "Norfolk Naval Shipyard", "Newark Beth Israel Medical Center",
    "Fort Wayne Mad Ants Basketball Team", "Fremont High School", "Shreveport Regional Airport",
    "Mobile Public Library", "Huntsville Hospital", "Knoxville Symphony Orchestra",
    "Dayton International Airport", "Grand Rapids Symphony", "Winston-Salem Dash Baseball Team",
    "Fayetteville Technical Community College", "Springfield Cardinals Baseball Team",
    "Augusta National Golf Club", "Salem Health", "Pasadena Playhouse", "Yonkers Public Schools",
    "Boulder Community Health", "Naperville North High School", "Lansing Community College",
    "Reno-Tahoe International Airport", "Columbia University Medical Center", "Albany Law School",
    "Buffalo Sabres Hockey Team", "Syracuse University", "Toledo Museum of Art", "Akron Public Schools",
    "Daytona International Speedway", "Des Moines Public Library", "Rochester Philharmonic Orchestra",
    "Flint Institute of Arts", "Lincoln Memorial University", "Baton Rouge Community College",
    "Chattanooga Symphony and Opera", "Greenville Technical College", "Cedar Rapids Opera Theatre",
    "Pensacola Naval Air Station",
]

# Filler vocabularies for templates that mention no city or state.
# Each list holds 20 entries; one entry is substituted per generated query.

# {product}
products = [
    "iPhone", "Samsung Galaxy", "MacBook", "PlayStation 5", "Nike shoes",
    "AirPods", "Xbox Series X", "Canon DSLR", "GoPro", "Adidas sneakers",
    "Fitbit", "Google Pixel", "Kindle", "Bose headphones", "Sony TV",
    "Dyson vacuum", "KitchenAid mixer", "Surface Pro", "Roomba", "Apple Watch",
]

# {country}
countries = [
    "USA", "France", "Japan", "Germany", "Canada",
    "Australia", "Mexico", "China", "Brazil", "India",
    "Italy", "Spain", "South Korea", "Russia", "Netherlands",
    "United Kingdom", "Sweden", "Norway", "Switzerland", "Argentina",
]

# {service}
services = [
    "Netflix", "Spotify", "Uber", "Amazon Prime", "Google Drive",
    "Zoom", "Dropbox", "Slack", "LinkedIn", "Disney+",
    "YouTube Premium", "Venmo", "DoorDash", "Postmates", "Hulu",
    "Skype", "Grubhub", "Twitch", "Instacart", "Lyft",
]

# {car}
cars = [
    "Tesla Model S", "Ford Mustang", "Chevrolet Camaro", "Toyota Corolla", "Honda Civic",
    "BMW 3 Series", "Audi A4", "Mercedes-Benz C-Class", "Jeep Wrangler", "Ford F-150",
    "Hyundai Elantra", "Mazda CX-5", "Chevrolet Tahoe", "Nissan Altima", "Kia Sorento",
    "Volkswagen Golf", "Subaru Outback", "Tesla Model 3", "Dodge Charger", "Volvo XC90",
]

# {gadget}
gadgets = [
    "smartwatch", "Bluetooth headphones", "fitness tracker", "smart speaker", "tablet",
    "laptop", "gaming mouse", "wireless charger", "VR headset", "noise-canceling headphones",
    "dashcam", "e-reader", "action camera", "portable hard drive", "gaming console",
    "mechanical keyboard", "4K monitor", "digital camera", "portable power bank", "USB-C hub",
]

# {stock} -- ticker symbols
stocks = [
    "AAPL", "GOOGL", "AMZN", "MSFT", "TSLA",
    "NFLX", "FB", "BABA", "NVDA", "JPM",
    "V", "PYPL", "BRK.A", "DIS", "INTC",
    "PFE", "NKE", "ORCL", "VZ", "BA",
]

# {money}
moneys = [
    "cryptocurrency", "cash", "PayPal", "credit card", "Bitcoin",
    "Ethereum", "bank transfer", "wire transfer", "Western Union", "Venmo",
    "debit card", "Zelle", "Apple Pay", "Google Pay", "Coinbase",
    "Tether", "Litecoin", "Dogecoin", "cash app", "Ripple",
]

# {finance}
finances = [
    "401(k)", "IRA", "mutual funds", "mortgage", "student loan",
    "savings account", "retirement fund", "bond", "annuity", "index fund",
    "Roth IRA", "tax-free savings account", "pension", "trust fund", "hedge fund",
    "credit score", "auto loan", "home equity loan", "personal loan", "debt consolidation",
]

# {travel}
travels = [
    "flights", "hotels", "car rentals", "vacation packages", "cruise trips",
    "road trips", "train tickets", "adventure tours", "guided tours", "backpacking trips",
    "honeymoon destinations", "beach resorts", "luxury travel", "budget travel", "camping gear",
    "family vacations", "ski trips", "all-inclusive resorts", "last-minute deals", "travel insurance",
]

# {food}
foods = [
    "pizza", "sushi", "burgers", "pasta", "salads",
    "vegan food", "barbecue", "fried chicken", "ramen", "tacos",
    "sandwiches", "noodles", "soups", "cakes", "ice cream",
    "steak", "seafood", "breakfast food", "brunch", "desserts",
]

# {restaurant}
restaurants = [
    "Italian restaurants", "Mexican restaurants", "Japanese restaurants", "Chinese restaurants", "Indian restaurants",
    "fast food chains", "fine dining", "vegan restaurants", "steakhouses", "seafood restaurants",
    "barbecue joints", "sushi bars", "cafes", "pizzerias", "buffet restaurants",
    "food trucks", "family-friendly restaurants", "gastropubs", "brunch spots", "diner",
]

## Additional partial terms
# Most entries are truncated on purpose (e.g. "footbal", "confere").

# {sports_term}
sports_terms_missing = [
    "footbal", "baske", "socce", "golf", "cricke", "rugby",
    "hocke", "tenni", "swimmin", "athleti", "fishi", "basebal",
    "volleybal", "badminto", "maratho", "skatin", "climbin", "racquetball",
    "bowlin", "darts", "gymnasti", "bikin",
]

# {location_and_landmark}
locations_and_landmarks = [
    "statue", "museum", "plaza", "zoo", "church", "theater",
    "stadium", "mountain", "park", "lake", "beach", "river",
    "palace", "cathedra", "mansion", "monument", "temple", "observato",
    "canyon", "garden", "conservato", "boardwal", "forest", "pier",
    "lighthouse", "arena",
]

# {activity_and_event}
activities_and_events = [
    "conc", "exhib", "meet", "parad", "festi", "tourn",
    "game", "sho", "even", "gala", "confere", "seminar",
    "webina", "worksho", "lectur", "symposiu", "screenin", "rall",
    "celebratio", "ceremon", "get-togethe", "perfor", "gatherin", "competitio",
    "maratho", "speec", "workout", "showcas",
]

# {food_m}
food_missing = [
    "sush", "pizz", "ramen", "bbq", "vega", "steak",
    "taco", "burg", "pasta", "brunc", "desse", "drink",
    "grill", "bake", "buffet", "sandwich", "noodle", "cafe",
    "taver", "gastro", "bistro", "del", "saloo", "barbecue",
    "snack", "confectio", "pub",
]

# {transport_and_direction}
transport_and_directions = [
    "direc", "map", "bus", "train", "car", "park",
    "taxi", "subwa", "fly", "plane", "ticke", "pass",
    "ferr", "bicycl", "scoote", "shuttl", "walkin", "rideshar",
    "transi", "toll", "metr", "road", "route", "stop",
]


In [10]:

def get_sample_from_cities(city_info, city_weights, actual_threshold=0.7):
    """Pick a city at random and return its name or one of its alternate names.

    Parameters
    ----------
    city_info : dict
        City name -> collection of alternate names for that city.
    city_weights : dict
        City name -> sampling weight; must contain every key of ``city_info``.
    actual_threshold : float, optional
        The canonical city name is returned when a uniform draw from
        ``random.random()`` is ``<= actual_threshold``; otherwise an alternate
        name is returned.

    Returns
    -------
    The sampled city name, or one of its alternate names. If the city has no
    usable alternates (empty, ``None`` or otherwise unsized), the city name
    is returned instead of raising.
    """
    cities = list(city_info.keys())
    weights = [city_weights[city] for city in cities]
    city_random = random.choices(cities, weights=weights, k=1)[0]
    rand_val = random.random()
    if rand_val <= actual_threshold:
        return city_random

    alternates = city_info[city_random]
    try:
        has_alternates = len(alternates) > 0
    except TypeError:
        # None / NaN placeholder where a row carried no alternate names.
        has_alternates = False
    if not has_alternates:
        # random.choice would raise on an empty sequence; fall back to the
        # canonical name so one sparse row cannot abort data generation.
        return city_random
    return random.choice(alternates)

def get_sample_from_states(state_info, actual_threshold=0.5):
    """Pick a state uniformly at random and return its code or its full name.

    ``state_info`` maps state code -> state name. The code is returned when a
    uniform draw from ``random.random()`` is ``<= actual_threshold``;
    otherwise the mapped name is returned.
    """
    picked_code = random.choice(list(state_info.keys()))
    wants_code = random.random() <= actual_threshold
    if not wants_code:
        # The one-element choice is kept as written so the sequence of calls
        # into the random module is unchanged.
        return random.choice([state_info[picked_code]])
    return picked_code

def get_sample_from_cities_and_states(city_state_code_info, city_state_name_info, state_code_threshold=0.8, comma_threshold=0.6):
    """Sample one city (weighted by ``city_weight``) and format it with its state.

    A single uniform draw is compared against both thresholds:

    * draw <= comma_threshold (and <= state_code_threshold): ``"City, ST"``
    * comma_threshold < draw <= state_code_threshold:        ``"City ST"``
    * draw > state_code_threshold:                           ``"City, State name"``

    ``city_state_code_info`` needs columns ``city_name``, ``state_code`` and
    ``city_weight``; ``city_state_name_info`` needs ``city_name``,
    ``state_name`` and ``city_weight``.
    """
    draw = random.random()
    if draw > state_code_threshold:
        frame, state_col, separator = city_state_name_info, 'state_name', ', '
    elif draw <= comma_threshold:
        frame, state_col, separator = city_state_code_info, 'state_code', ', '
    else:
        frame, state_col, separator = city_state_code_info, 'state_code', ' '

    sampled_row = frame.sample(1, weights='city_weight', replace=True)
    city_and_state = sampled_row[['city_name', state_col]].values.tolist()[0]
    return separator.join(city_and_state)

def get_random_choice_from_list(choices_list):
    """Return one element of ``choices_list`` chosen by ``random.choice``."""
    picked = random.choice(choices_list)
    return picked
    


In [11]:
# for _ in range(100):
#     print(get_sample_from_cities_and_states(city_state_code_info, city_state_name_info, state_code_threshold=0.8))

In [12]:
templates = [
    # Simple City-Based Queries
    "weather {city}",
    "{city} temperature",
    "sushi {city}",
    "ramen {city}",
    "pizza {city}",
    "plumber {city}",
    "electrician {city}",
    "roof repair {city}",
    "physio therapy {city}",
    "hospital {city}",
    "doctor {city}",
    "nurse {city}",
    "home improvement {city}",
    "home services {city}",
    "weather forecast {city}",
    "current weather {city}",
    "best restaurants {city}",
    "top yelp reviews {city}",
    "places to visit in {city}",
    "best cafes in {city}",
    "emergency services {city}",
    "gyms in {city}",
    "car repair {city}",
    "florist {city}",
    "lawyers in {city}",
    "real estate agents {city}",
    "hiking trails {city}",
    "parks in {city}",
    "movie theaters {city}",
    "top hotels in {city}",
    "events in {city} this weekend",
    "pharmacies {city}",

    # State-Based Queries
    "home services in {state}",
    "best restaurants in {state}",
    "real estate agents {state}",
    "roof repair services {state}",
    "hospitals in {state}",
    "weather {state}",
    "temperature {state}",
    "physio therapy {state}",
    "doctors in {state}",
    "top-rated plumbers {state}",
    "electricians {state}",
    "emergency services {state}",
    "sushi {state}",
    "ramen {state}",
    "pizza {state}",
    "parks in {state}",
    "hiking trails {state}",
    "pharmacies in {state}",
    "best cafes {state}",
    "movie theaters {state}",

    # City-State Combination Queries (Now using {city_state})
    "weather {city_state}",
    "{city_state} temperature",
    "sushi {city_state}",
    "plumber {city_state}",
    "best restaurants in {city_state}",
    "top-rated roof repair {city_state}",
    "hospital {city_state}",
    "physio therapy {city_state}",
    "doctor {city_state}",
    "events in {city_state} this weekend",
    "lawyers in {city_state}",
    "home improvement services {city_state}",
    "florist {city_state}",
    "best cafes in {city_state}",
    "parks in {city_state}",
    "movie theaters {city_state}",
    "top hotels in {city_state}",
    "emergency services {city_state}",
    "car repair {city_state}",
    "pharmacies {city_state}",

    "sushi {city_state}",
    "ramen {city_state}",
    "pizza {city_state}",
    "parks {city_state}",
    "hiking trails {city_state}",
    "pharmacies {city_state}",
    "best cafes {city_state}",
    "movie theaters {city_state}",
    "hamburgers {city_state}",
    "burgers {city_state}",
    "pasta {city_state}",
    "salads {city_state}",
    "vegan food {city_state}",
    "fried chicken {city_state}",
    "ramen {city_state}",
    "tacos {city_state}",
    "sandwiches {city_state}",
    "noodles {city_state}",
    "soups {city_state}",
    "cakes {city_state}",
    "ice cream {city_state}",
    "steak {city_state}",
    "seafood {city_state}",
    "breakfast food {city_state}",
    "brunch {city_state}",
    "desserts {city_state}",
    
    # CITY state order swapped
    "{city_state} sushi",
    "{city_state} ramen",
    "{city_state} pizza",
    "{city_state} parks",
    "{city_state} hiking trails",
    "{city_state} pharmacies",
    "{city_state} best cafes",
    "{city_state} movie theaters",
    "{city_state} hamburgers",
    "{city_state} burgers",
    "{city_state} pasta",
    "{city_state} salads",
    "{city_state} vegan food",
    "{city_state} fried chicken",
    "{city_state} ramen",
    "{city_state} tacos",
    "{city_state} sandwiches",
    "{city_state} noodles",
    "{city_state} soups",
    "{city_state} cakes",
    "{city_state} ice cream",
    "{city_state} steak",
    "{city_state} seafood",
    "{city_state} breakfast food",
    "{city_state} brunch",
    "{city_state} desserts",
    
    # Organization-Based Queries
    "{organization} in {city_state}",
    "contact {organization} in {city}",
    "locations of {organization} in {state}",
    "does {organization} provide home repair services in {city}?",
    "can I book a doctor appointment at {organization} in {state}?",
    "does {organization} offer roof repair in {city_state}?",
    "hours of {organization} in {city}",
    "{organization} reviews in {state}",
    "best rated {organization} in {city_state}",
    "nearest branch of {organization} in {city}",
    
    # Person-Based Queries
    "Where is {person} hosting an event?",
    "Can I meet {person} in {city_state}?",
    "Is {person} available for an appointment in {city}?",
    "Is {person} traveling to {state} next week?",
    "Does {person} have a speech in {city_state}?",
    
    # Mixed and Specialized Queries
    "roof repair near {city}",
    "best sushi in {city_state}",
    "what's the weather forecast for {city}?",
    "who are the top doctors in {city_state}?",
    "restaurants near {city} with good reviews",
    "plumbing services in {city_state}",
    "upcoming events in {city} this weekend",
    "find hiking trails in {city_state}",
    "local electricians in {city_state}",
    "ramen places in {city}",
    "home improvement contractors near {city_state}",
    "best pizza near {city}",
    "does {organization} operate in {city_state}?",
    "find top-rated hospitals in {city_state}",
    "home maintenance services in {city_state}",
    "weather forecast for {city} this weekend",
    "roof repair specialists in {city}",
    "top-rated movie theaters in {city_state}",

    # City-State Queries
    "Best {restaurant} in {city_state}",
    "Top-rated {restaurant} in {city_state}",
    "Affordable {restaurant} in {city_state}",
    "Where to find the best {food} in {city_state}?",
    "Popular {food} places in {city_state}",
    "Top destinations for {travel} in {city_state}",
    "Best deals on {travel} in {city_state}",
    "Where to eat {food} in {city_state}?",
    "What are the most famous {restaurant} in {city_state}?",
    "Top {food} restaurants in {city_state} this weekend",

    # Non-City/State Queries
    "Best {restaurant} in the country",
    "Where to find the best {food} near me?",
    "Top destinations for {travel} this summer",
    "Best deals on {travel} packages",
    "Where to find cheap {travel} options?",
    "Popular {food} dishes in the USA",
    "Best {restaurant} chains in the country",
    "What are the healthiest {food} options?",
    "How to book affordable {travel} for families?",
    "Most popular {restaurant} for takeout",

    # Additional Templates
    "What is the best {food} to eat for dinner?",
    "Where to order {food} online?",
    "Best {restaurant} for date night",
    "Top {travel} websites for booking vacations",
    "Where to find {restaurant} reviews?",
    "What are the top-rated {travel} apps?",
    "Best {restaurant} near tourist attractions",
    "What is the most popular {food} in the USA?",
    "Best deals on {travel} for students",
    "Top {restaurant} for family gatherings",
    "Most affordable {food} delivery services",
    "What are the best {travel} insurance options?",
    "How to find luxury {restaurant} reservations",
    "Where to get authentic {food} near me?",
    "Top {restaurant} for business lunches",
    "How to plan a {travel} adventure?",
    "Best {restaurant} for weekend brunch",
    "What are the most popular {food} trends?",
    "Best {restaurant} for a large group",
    "How to get discounts on {travel} bookings?"

    # Product-Based Queries
    "Where to buy {product} online?",
    "Best deals on {product}",
    "How to repair a {product}?",
    "Latest reviews of {product}",
    "When will the next {product} be released?",
    "Top features of {product}",
    "Is {product} worth buying in 2024?",
    "User reviews of {product}",
    "Alternatives to {product}",
    "What is the price of {product}?",

    # Country-Based Queries
    "How to travel to {country}?",
    "Best tourist destinations in {country}",
    "Top hotels to stay in {country}",
    "Do I need a visa to visit {country}?",
    "Cultural traditions in {country}",
    "What is the official language of {country}?",
    "How to do business in {country}?",
    "What are the top exports of {country}?",
    "Current political situation in {country}",
    "Famous landmarks in {country}",

    # Service-Based Queries
    "How to cancel my {service} subscription?",
    "Is {service} worth the price?",
    "How does {service} compare to competitors?",
    "User reviews of {service}",
    "How to get a discount on {service}?",
    "What are the benefits of {service}?",
    "Best alternatives to {service}",
    "How to troubleshoot issues with {service}?",
    "Does {service} have a free trial?",
    "Is {service} available internationally?",

    # Cars-Based Queries
    "What is the top speed of {car}?",
    "User reviews of {car}",
    "How to finance a {car}?",
    "Fuel efficiency of {car}",
    "How to buy a second-hand {car}?",
    "What are the safety features of {car}?",
    "Maintenance costs of owning a {car}",
    "What is the resale value of {car}?",
    "Is {car} electric or gas-powered?",
    "Best upgrades for {car}",

    # Gadgets-Based Queries
    "What are the best apps for {gadget}?",
    "How to set up a {gadget}?",
    "User reviews of {gadget}",
    "Best accessories for {gadget}",
    "What are the health benefits of using a {gadget}?",
    "What is the battery life of {gadget}?",
    "How to sync {gadget} with my phone?",
    "Alternatives to {gadget}",
    "What are the best productivity apps for {gadget}?",
    "Is {gadget} waterproof?",

    # Stocks-Based Queries
    "What is the latest price of {stock}?",
    "How to buy shares of {stock}?",
    "Is {stock} a good investment in 2024?",
    "What are analysts saying about {stock}?",
    "Current stock performance of {stock}",
    "What is the market cap of {stock}?",
    "How to invest in {stock}?",
    "Latest earnings report of {stock}",
    "What are the dividend yields of {stock}?",
    "How to trade {stock} on the stock market?",

    # Money-Based Queries
    "How to convert {money} to another currency?",
    "Best ways to transfer {money} internationally",
    "What are the risks of using {money}?",
    "How to save {money} for the future?",
    "What is the best way to invest {money}?",
    "How to protect {money} from fraud?",
    "What are the fees for using {money}?",
    "Is {money} safe for online transactions?",
    "Best apps for managing {money}",
    "How to track spending with {money}?",

    # Finance-Based Queries
    "How to invest in a {finance}?",
    "What are the benefits of having a {finance}?",
    "How to calculate the returns on {finance}?",
    "What are the risks of investing in {finance}?",
    "How to get advice for managing my {finance}?",
    "How to apply for a {finance}?",
    "What are the tax benefits of {finance}?",
    "What are the best options for a {finance}?",
    "How to open a {finance} account?",
    "What is the interest rate on {finance}?",

    # sports_term, location_and_landmark, activity_and_event, food_m, transport_and_direction
    # incomplete or misspelled sport/activity names
    "{sports_term} near me", 
    "find {sports_term}", 
    "{sports_term} schedule", 
    "{sports_term} news", 
    "book {sports_term} tickets", 
    "{sports_term} team", 
    "{sports_term} game time", 
    "when is the {sports_term} game", 
    "top {sports_term} players", 
    "local {sports_term} clubs", 
    "where to play {sports_term}", 
    "best {sports_term} venues", 
    "{sports_term} tournament",

    # Generic landmarks and location queries
    "{location_and_landmark} nearby", 
    "famous {location_and_landmark}", 
    "{location_and_landmark} open now", 
    "visit {location_and_landmark}", 
    "{location_and_landmark} directions", 
    "how to get to {location_and_landmark}", 
    "nearest {location_and_landmark}", 
    "{location_and_landmark} address", 
    "top-rated {location_and_landmark}", 
    "{location_and_landmark} hours", 
    "find {location_and_landmark} near me", 
    "{location_and_landmark} entry fee", 
    "best {location_and_landmark} in {city}",

    # Food and dining queries
    "{food_m} place", 
    "find {food_m}", 
    "best {food_m} spot", 
    "{food_m} delivery", 
    "{food_m} open near me", 
    "order {food_m}", 
    "{food_m} deals", 
    "{food_m} options", 
    "{food_m} near me", 
    "{food_m} reservation", 
    "top-rated {food_m} restaurants", 
    "{food_m} reviews", 
    "{food_m} menu", 
    "popular {food_m} dishes", 
    "where to eat {food_m}",

    # activities_and_events
    "{activity_and_event} tickets", 
    "nearest {activity_and_event}", 
    "{activity_and_event} today", 
    "upcoming {activity_and_event}", 
    "book {activity_and_event}", 
    "{activity_and_event} in {city}", 
    "find {activity_and_event}", 
    "{activity_and_event} schedule", 
    "{activity_and_event} near me", 
    "top-rated {activity_and_event} venues", 
    "{activity_and_event} details", 
    "how to attend {activity_and_event}", 
    "{activity_and_event} location", 
    "{activity_and_event} opening hours",

    # Single-word incomplete or ambiguous queries (standalone)
    # Sports and Games (single or incomplete)
    "footbal", "baske", "golf", "sush", "pizz", "zoo", "conc", "direc", 
    "theate", "stadiu", "brunc", "tourn", "parad", "swimmin", "train", "taxi", 
    "game", "meet", "mountain", "beac", "lake", "forest", "ligh", "restauran", 
    "parki", "stor", "monumen", "aren", "boardwal",
    # Locations and Landmarks (single or incomplete)
    "statue", "museum", "plaza", "zoo", "church", "theater", "stadium", "mountain", 
    "park", "lake", "beach", "river", "palace", "cathedra", "mansion", "monument", 
    "temple", "observato", "canyon", "garden", "conservato", "boardwal", "forest", 
    "pier", "lighthouse", "arena", "campgroun", "arch", "reservoi", "dam", "fountai", 
    "waterfal", "galleri", "amphitheate", "sculptur", "trail", "cliff", "tower", "islan",
    # Activities and Events (single or incomplete)
    "conc", "exhib", "meet", "parad", "festi", "tourn", "game", "sho", "even", "gala", 
    "confere", "seminar", "webina", "worksho", "lectur", "symposiu", "screenin", 
    "rall", "celebratio", "ceremon", "get-togethe", "perfor", "gatherin", "competitio", 
    "maratho", "speec", "workout", "exercis", "demonstratio", "ceremony", "readin", 
    "daytrip", "lectur", "social", "activit", "performanc", "worksho", "openin", 
    "finale", "comedy", "poetr", "talent", "match",
    # Restaurants and Food Types (single or incomplete)
    "sush", "pizz", "ramen", "bbq", "vega", "steak", "taco", "burg", "pasta", "brunc", 
    "desse", "drink", "grill", "bake", "buffet", "sandwich", "noodle", "cafe", 
    "taver", "gastro", "bistro", "deli", "saloo", "barbecue", "snack", "confectio", 
    "pub", "salad", "cuisine", "fries", "wings", "pantr", "meatbal", "sub", "omel", 
    "crepe", "wrap", "beverag", "dessert", "smoothie", "juice", "shake", "frappe", "coffee",
    # Transport and Directions (single or incomplete)
    "direc", "map", "bus", "train", "car", "park", "taxi", "subwa", "fly", "plane", 
    "ticke", "pass", "ferr", "bicycl", "scoote", "shuttl", "walkin", "rideshar", 
    "transi", "toll", "metr", "road", "route", "stop", "junctio", "termina", "highwa", 
    "pathwa", "drivewa", "loop", "intersectio", "trailhead", "tub", "sidestro", 
    "crosswal", "rout", "navigatio", "crossing", "pave", "deck", "lane",
    # Technology and Gadgets (single or incomplete)
    "lapt", "smartphon", "comput", "tablet", "earbuds", "bluetooth", "charg", "cabl", 
    "headset", "monitor", "consol", "keyboard", "drive", "storag", "gaming", "mouse", 
    "projector", "flashdriv", "powerban", "adapter", "webcam", "router", "modem", 
    "camcorder", "printer", "copier", "recorde", "remote", "surge", "extend", "plug", 
    "portabl", "backu", "networ", "recharge", "uplo", "downlo", "strea", "screencas", 
    "googl", "apple", "micros", "andr",
]

In [13]:
# Number of query templates available for random sampling.
len(templates)

570

In [33]:
# Placeholder strings that can appear inside a query template. Each value must
# be spelled exactly like the `{name}` field used in the templates (and like the
# keyword passed to `template.format(...)`); a mismatch means the placeholder is
# never detected and no value is sampled for it.
PERSON_ENTITY = "{person}"
ORG_ENTITY = "{organization}"
CITY_ENTITY = "{city}"
STATE_ENTITY = "{state}"
CITY_STATE_ENTITY = "{city_state}"
PRODUCT_ENTITY = "{product}"
COUNTRY_ENTITY = "{country}"
# The templates use "{service}" (singular); the former "{services}" matched none of them.
SERVICE_ENTITY = "{service}"
CAR_ENTITY = "{car}"
GADGET_ENTITY = "{gadget}"
STOCK_ENTITY = "{stock}"
MONEY_ENTITY = "{money}"
FINANCE_ENTITY = "{finance}"
TRAVEL_ENTITY = "{travel}"
FOOD_ENTITY = "{food}"
RESTAURANT_ENTITY = "{restaurant}"
SPORTS_TERMS_MISSING_ENTITY = "{sports_term}"
LOCATIONS_AND_LANDMARKS_ENTITY = "{location_and_landmark}"
# NOTE: the constant name is misspelled ("ACTIVTIES") but is kept as-is because
# generate_queries refers to it by this name.
ACTIVTIES_AND_EVENTS_ENTITY = "{activity_and_event}"
FOOD_MISSING_ENTITY = "{food_m}"
TRANSPORT_AND_DIRECTIONS_ENTITY = "{transport_and_direction}"


def detect_entity(entity_name, template):
    """Report whether the placeholder `entity_name` occurs in `template`.

    This is a plain containment test (`in`); the template's format fields are
    not parsed.
    """
    is_present = entity_name in template
    return is_present

def tokenize(text):
    """Split `text` into word tokens and single-character punctuation tokens.

    Each run of word characters becomes one token, every other non-whitespace
    character becomes a token of its own, and whitespace is discarded.
    """
    token_pattern = r'\w+|[^\w\s]'
    return [match.group(0) for match in re.finditer(token_pattern, text)]

def tokenize_and_label(query, city, state, city_state, organization, person):
    """Tokenize `query` and assign an integer NER tag id to every token.

    All tokens start as 0 ('O'). For each non-empty entity string, its tokens
    are located in the query with `find_token_index` (first occurrence only);
    the first matched token receives the entity's begin tag and the remaining
    matched tokens receive its inside tag.

    Entities are applied in the order city_state, city, state, organization,
    person, so a later entity overwrites tags written by an earlier one where
    their token spans overlap.

    Returns:
        A `(tokens, ner_labels)` pair of equal-length lists.
    """
    tokens = tokenize(query)
    ner_labels = [0 for _ in tokens]

    # (entity text, begin tag id, inside tag id) in application order.
    tagging_plan = (
        (city_state, 9, 10),    # B-CITYSTATE / I-CITYSTATE
        (city, 5, 6),           # B-CITY / I-CITY
        (state, 7, 8),          # B-STATE / I-STATE
        (organization, 3, 4),   # B-ORG / I-ORG
        (person, 1, 2),         # B-PER / I-PER
    )

    for entity_text, begin_tag, inside_tag in tagging_plan:
        if not entity_text:
            continue
        entity_tokens = tokenize(entity_text)
        start_idx = find_token_index(tokens, entity_tokens)
        if start_idx is None:
            continue
        ner_labels[start_idx] = begin_tag
        for offset in range(1, len(entity_tokens)):
            ner_labels[start_idx + offset] = inside_tag

    return tokens, ner_labels

def find_token_index(tokens, entity_tokens):
    """Return the index at which `entity_tokens` first occurs inside `tokens`.

    Args:
        tokens: List of query tokens to search.
        entity_tokens: List of tokens that must appear contiguously.

    Returns:
        The start index of the first contiguous match, or None when there is
        no match or `entity_tokens` is empty.
    """
    span = len(entity_tokens)
    if span == 0:
        # An empty slice compares equal to an empty list, so without this guard
        # an empty entity would "match" at index 0 and the caller would tag an
        # unrelated token as the beginning of an entity.
        return None
    for i in range(len(tokens) - span + 1):
        if tokens[i:i + span] == entity_tokens:
            return i
    return None

def generate_queries(templates, n_queries=10000):
    """Generate `n_queries` distinct synthetic queries with token-level NER tags.

    Each iteration picks a random template, draws a random value for every
    placeholder that template contains, formats the query, and keeps it only
    if that exact query string has not been produced before.

    Args:
        templates: Sequence of format strings using the `{name}` placeholders
            handled below.
        n_queries: Number of unique queries to collect.

    Returns:
        A list of `(query, tokens, ner_labels)` tuples, where tokens and labels
        come from `tokenize_and_label`.

    Note:
        The loop only ends once `n_queries` unique queries exist, so it does
        not terminate if the templates and value pools cannot yield that many
        distinct strings.
    """
    queries_with_labels = []
    seen_queries = set()
    last_reported = None  # last `cnt` value a progress line was printed for
    cnt = 0
    while cnt < n_queries:
        # Report each 10000-query milestone once; rejected duplicates leave
        # `cnt` unchanged and must not re-print the same line.
        if cnt % 10000 == 0 and cnt != last_reported:
            print(f"complted generating {cnt} queries")
            last_reported = cnt

        template = random.choice(templates)
        person, organization, city, state, city_state = (None,) * 5
        product, country, service, car, gadget, stock, money, finance, travel, food, restaurant = (None,) * 11
        sports_term, location_and_landmark, activity_and_event, food_m, transport_and_direction = (None,) * 5

        if detect_entity(PERSON_ENTITY, template):
            person = get_random_choice_from_list(persons)
        if detect_entity(ORG_ENTITY, template):
            organization = get_random_choice_from_list(organizations)
        if detect_entity(PRODUCT_ENTITY, template):
            product = get_random_choice_from_list(products)
        if detect_entity(COUNTRY_ENTITY, template):
            country = get_random_choice_from_list(countries)
        # Detect the "{service}" placeholder exactly as it is spelled in the
        # templates and in the format() call below (this check previously
        # tested the country placeholder, so service templates rendered "None").
        if detect_entity("{service}", template):
            service = get_random_choice_from_list(services)
        if detect_entity(CAR_ENTITY, template):
            car = get_random_choice_from_list(cars)
        if detect_entity(GADGET_ENTITY, template):
            gadget = get_random_choice_from_list(gadgets)
        if detect_entity(STOCK_ENTITY, template):
            stock = get_random_choice_from_list(stocks)
        if detect_entity(MONEY_ENTITY, template):
            money = get_random_choice_from_list(moneys)
        if detect_entity(FINANCE_ENTITY, template):
            finance = get_random_choice_from_list(finances)
        if detect_entity(TRAVEL_ENTITY, template):
            travel = get_random_choice_from_list(travels)
        if detect_entity(FOOD_ENTITY, template):
            food = get_random_choice_from_list(foods)
        if detect_entity(RESTAURANT_ENTITY, template):
            restaurant = get_random_choice_from_list(restaurants)
        if detect_entity(SPORTS_TERMS_MISSING_ENTITY, template):
            sports_term = get_random_choice_from_list(sports_terms_missing)
        if detect_entity(LOCATIONS_AND_LANDMARKS_ENTITY, template):
            location_and_landmark = get_random_choice_from_list(locations_and_landmarks)
        if detect_entity(ACTIVTIES_AND_EVENTS_ENTITY, template):
            activity_and_event = get_random_choice_from_list(activities_and_events)
        if detect_entity(FOOD_MISSING_ENTITY, template):
            food_m = get_random_choice_from_list(food_missing)
        if detect_entity(TRANSPORT_AND_DIRECTIONS_ENTITY, template):
            transport_and_direction = get_random_choice_from_list(transport_and_directions)

        # Location entities use dedicated samplers rather than a uniform pick.
        if detect_entity(CITY_ENTITY, template):
            city = get_sample_from_cities(city_info, city_weights, actual_threshold=0.7)
        if detect_entity(STATE_ENTITY, template):
            state = get_sample_from_states(state_info, actual_threshold=0.5)
        if detect_entity(CITY_STATE_ENTITY, template):
            city_state = get_sample_from_cities_and_states(city_state_code_info, city_state_name_info, state_code_threshold=0.8)

        query = template.format(person=person,
                                organization=organization,
                                city=city,
                                state=state,
                                city_state=city_state,
                                product=product,
                                country=country,
                                service=service,
                                car=car,
                                gadget=gadget,
                                stock=stock,
                                money=money,
                                finance=finance,
                                travel=travel,
                                food=food,
                                restaurant=restaurant,
                                sports_term=sports_term,
                                location_and_landmark=location_and_landmark,
                                activity_and_event=activity_and_event,
                                food_m=food_m,
                                transport_and_direction=transport_and_direction
                               )

        # Skip duplicates before tokenizing so rejected queries cost nothing more.
        if query in seen_queries:
            continue
        seen_queries.add(query)
        tokens, ner_labels = tokenize_and_label(query, city, state, city_state, organization, person)
        queries_with_labels.append((query, tokens, ner_labels))
        cnt += 1
    return queries_with_labels

In [34]:
# Build 300,000 distinct synthetic queries, each paired with its tokens and NER tag ids.
queries_with_labels = generate_queries(templates, n_queries=300000)

complted generating 0 queries
complted generating 10000 queries
complted generating 10000 queries
complted generating 10000 queries
complted generating 20000 queries
complted generating 30000 queries
complted generating 40000 queries
complted generating 40000 queries
complted generating 40000 queries
complted generating 40000 queries
complted generating 40000 queries
complted generating 40000 queries
complted generating 40000 queries
complted generating 40000 queries
complted generating 40000 queries
complted generating 50000 queries
complted generating 50000 queries
complted generating 50000 queries
complted generating 60000 queries
complted generating 60000 queries
complted generating 60000 queries
complted generating 60000 queries
complted generating 60000 queries
complted generating 60000 queries
complted generating 60000 queries
complted generating 60000 queries
complted generating 60000 queries
complted generating 60000 queries
complted generating 60000 queries
complted generatin

In [35]:
# Sanity check: should equal the n_queries requested from generate_queries.
len(queries_with_labels)

300000

In [36]:
# queries_with_labels[:10]
# One row per generated query: the raw query string, its tokens, and the
# per-token NER tag ids.
df_ner_examples = pd.DataFrame(queries_with_labels, columns=['query', 'tokens', 'ner_tags'])
df_ner_examples

Unnamed: 0,query,tokens,ner_tags
0,Best upgrades for Volkswagen Golf,"[Best, upgrades, for, Volkswagen, Golf]","[0, 0, 0, 0, 0]"
1,top-rated conc venues,"[top, -, rated, conc, venues]","[0, 0, 0, 0, 0]"
2,nearest arena,"[nearest, arena]","[0, 0]"
3,saloo,[saloo],[0]
4,tourn in San Diego,"[tourn, in, San, Diego]","[0, 0, 5, 6]"
...,...,...,...
299995,"movie theaters Albany, OR","[movie, theaters, Albany, ,, OR]","[0, 0, 9, 10, 10]"
299996,home improvement Meridian,"[home, improvement, Meridian]","[0, 0, 5]"
299997,ice cream The Villages FL,"[ice, cream, The, Villages, FL]","[0, 0, 9, 10, 10]"
299998,weather Caldwell ID,"[weather, Caldwell, ID]","[0, 9, 10]"


In [37]:
# Distribution of how many tokens per query carry a tag id above 4
# (in label_map, ids 5-10 are the CITY / STATE / CITYSTATE tags).
(
    df_ner_examples['ner_tags']
    .apply(lambda tags: sum(1 for tag in tags if tag > 4))
    .value_counts()
)

ner_tags
3    107543
1     73502
2     57480
4     48084
5      7978
0      4194
6       995
7       134
8        90
Name: count, dtype: int64

In [38]:
# Show the tag-id -> label-name mapping that the ner_tags values refer to.
label_map

{0: 'O',
 1: 'B-PER',
 2: 'I-PER',
 3: 'B-ORG',
 4: 'I-ORG',
 5: 'B-CITY',
 6: 'I-CITY',
 7: 'B-STATE',
 8: 'I-STATE',
 9: 'B-CITYSTATE',
 10: 'I-CITYSTATE'}

In [39]:
# Every query should occur exactly once, because generate_queries discards duplicates.
df_ner_examples['query'].value_counts()

query
Best upgrades for Volkswagen Golf                     1
Rochester Philharmonic Orchestra in Haverhill, MA     1
Where to find the best noodles in Springfield, MA?    1
home improvement services Kenosha WI                  1
home services Ahwatukee Foothills                     1
                                                     ..
burgers Peoria AZ                                     1
real estate agents nashville                          1
florist Largo, FL                                     1
hours of Newark Data Solutions in Cicero              1
Flagstaff, Arizona salads                             1
Name: count, Length: 300000, dtype: int64

In [40]:
# Persist the generated examples (query, tokens, ner_tags) as CSV without the index column.
df_ner_examples.to_csv("../data/df_ner_examples_v3.csv", index=False)

In [41]:
def build_lookup(dataframe):
    """Map every lower-cased alternate city name to its canonical `city_name`.

    Useful in post-processing to standardize city names.

    Args:
        dataframe: DataFrame with a `city_name` column and an `alternate_names`
            column holding an iterable of strings per row.

    Returns:
        dict of `alternate_name.lower() -> city_name`. When the same alternate
        name appears in several rows, the last row wins.
    """
    name_pairs = zip(dataframe['city_name'], dataframe['alternate_names'])
    return {
        alt_name.lower(): city_name
        for city_name, alternate_names in name_pairs
        for alt_name in alternate_names
    }

# Lookup from lower-cased alternate name to canonical city name for the GeoNames cities.
city_alternate_to_city_lkp = build_lookup(city_states_data)

In [42]:
# Number of distinct lower-cased alternate names in the lookup.
len(city_alternate_to_city_lkp)

1356

In [43]:
# !python -m pip install onnxruntime

In [44]:
# !python -m pip freeze| grep  onnxruntime

In [45]:
# !mkdir ../models

In [46]:
import os

import onnxruntime as ort
import numpy as np
from transformers import AutoTokenizer, BertTokenizer

# ONNX model to evaluate. Alternative model URLs / paths are kept commented out;
# switch the URL and the local path together.
# model_url = "https://huggingface.co/Xenova/bert-base-NER/resolve/main/onnx/model_quantized.onnx"
# model_url = "https://huggingface.co/Mozilla/distilbert-NER-LoRA/resolve/main/onnx/model_quantized.onnx"
model_url = "https://huggingface.co/Mozilla/distilbert-uncased-NER-LoRA/resolve/main/onnx/model_quantized.onnx"
# model_url = "https://huggingface.co/chidamnat2002/distilbert-uncased-NER-LoRA/resolve/main/onnx/model_quantized.onnx"
# model_path = "../models/distilbert-NER-LoRA.onnx"
model_path = "../models/distilbert-uncased-NER-LoRA.onnx"

# Download the ONNX model only if it is not already present locally
# (delete the file to force a fresh download).
if not os.path.exists(model_path):
    # Make sure the target directory exists before writing into it.
    os.makedirs(os.path.dirname(model_path), exist_ok=True)
    response = requests.get(model_url, timeout=60)
    # Fail loudly on an HTTP error instead of saving an error page as the model file.
    response.raise_for_status()
    with open(model_path, 'wb') as f:
        f.write(response.content)

# Load the ONNX model using ONNX Runtime
session = ort.InferenceSession(model_path)

# Load the tokenizer that belongs to the same checkpoint as the ONNX model
# tokenizer = BertTokenizer.from_pretrained("Mozilla/distilbert-NER-LoRA")
tokenizer = AutoTokenizer.from_pretrained("Mozilla/distilbert-uncased-NER-LoRA")

In [47]:
def compute_model_inputs_and_outputs(session, tokenizer, query):
    """Tokenize `query` and run it through the ONNX session.

    The tokenizer is asked for NumPy tensors with truncation enabled and
    padding to a fixed `max_length` of 64. Only `input_ids` and
    `attention_mask` are fed to the model (both cast to int64);
    `token_type_ids` is not passed.

    Args:
        session: Object exposing `run(output_names, input_feed)`.
        tokenizer: Callable returning a mapping with `input_ids` and
            `attention_mask` arrays.
        query: Text to encode.

    Returns:
        `(inputs, outputs)`: the tokenizer's result and whatever
        `session.run(None, input_feed)` returns.
    """
    encoded = tokenizer(
        query,
        return_tensors="np",
        truncation=True,
        padding='max_length',
        max_length=64,
    )
    feed_keys = ('input_ids', 'attention_mask')
    input_feed = {key: encoded[key].astype(np.int64) for key in feed_keys}
    return encoded, session.run(None, input_feed)


In [48]:
# Re-display the tag-id -> label-name mapping for reference.
label_map

{0: 'O',
 1: 'B-PER',
 2: 'I-PER',
 3: 'B-ORG',
 4: 'I-ORG',
 5: 'B-CITY',
 6: 'I-CITY',
 7: 'B-STATE',
 8: 'I-STATE',
 9: 'B-CITYSTATE',
 10: 'I-CITYSTATE'}

In [49]:
## With Xenova/bert-base-NER
# Number of examples = 349
# #hits = 135; #hit rate = 0.3868194842406877

## After finetuning the Mozilla/distilbert-NER-LoRA
#hits = 220; #hit rate = 0.6303724928366762

## After finetuning the chidamnat2002/distilbert-uncased-NER-LoRA
#hits = 207; #hit rate = 0.5931232091690545

## After finetuning the Mozilla/distilbert-uncased-NER-LoRA
#hits = 252; #hit rate = 0.7220630372492837

In [50]:
# len(missing_locations)

In [51]:
# print(missing_locations)

#### Looking into the CoNLL-2003 dataset

In [52]:
from datasets import load_dataset, Dataset
import re

# Load the CoNLL-2003 dataset
dataset = load_dataset("conll2003")

# Second name bound to the same object as `dataset`; nothing is copied or filtered here.
loc_examples = dataset

In [53]:
# dataset['train'].to_pandas()

In [54]:
# dataset['train']

In [55]:
# Convert the synthetic examples to a `datasets.Dataset`, dropping the raw
# query string so only the tokens and ner_tags columns remain.
synthetic_loc_dataset = Dataset.from_pandas(df_ner_examples.drop('query', axis=1))
print(synthetic_loc_dataset)

# Peek at the first example.
print(synthetic_loc_dataset[0])

Dataset({
    features: ['tokens', 'ner_tags'],
    num_rows: 300000
})
{'tokens': ['Best', 'upgrades', 'for', 'Volkswagen', 'Golf'], 'ner_tags': [0, 0, 0, 0, 0]}


In [56]:
# loc_dataset = dataset['train'].filter(lambda example: 5 in example['ner_tags'])
# Use the full CoNLL training split (the tag-5 filter above is disabled) and
# drop the POS and chunk columns.
loc_dataset = dataset['train']
loc_dataset_filtered = loc_dataset.remove_columns(['pos_tags', 'chunk_tags'])

# Inspect the first example; the remaining columns are 'id', 'tokens' and 'ner_tags'.
# NOTE(review): these ner_tags use the CoNLL-2003 label ids (in this example
# 'German' is tagged 7), which presumably do not mean the same thing as ids
# 5-10 in label_map; confirm before mixing this data with the synthetic set.
loc_dataset_filtered[0]

{'id': '0',
 'tokens': ['EU',
  'rejects',
  'German',
  'call',
  'to',
  'boycott',
  'British',
  'lamb',
  '.'],
 'ner_tags': [3, 0, 7, 0, 0, 0, 7, 0, 0]}

In [57]:
# loc_dataset_filtered[-1]

In [58]:
from datasets import concatenate_datasets

from datasets import Sequence, ClassLabel, Value

# Step 1: Get the full feature schema from synthetic_loc_dataset
features = synthetic_loc_dataset.features

# Step 2: Make 'ner_tags' a sequence of ClassLabel whose names come from label_map
# (the alternative below would reuse the CoNLL label names instead)
# features['ner_tags'] = Sequence(feature=ClassLabel(names=loc_dataset_filtered.features['ner_tags'].feature.names))
features['ner_tags'] = Sequence(feature=ClassLabel(names=list(label_map.values())))

# Step 3: Cast synthetic_loc_dataset to the updated feature schema
synthetic_loc_dataset = synthetic_loc_dataset.cast(features)

# Check the updated features to confirm
print(synthetic_loc_dataset.features)

# Concatenation with the CoNLL data is currently disabled
# combined_dataset = concatenate_datasets([loc_dataset_filtered, synthetic_loc_dataset])

# Show the first example of the cast synthetic dataset
print(synthetic_loc_dataset[0])


Casting the dataset:   0%|          | 0/300000 [00:00<?, ? examples/s]

{'tokens': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'ner_tags': Sequence(feature=ClassLabel(names=['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-CITY', 'I-CITY', 'B-STATE', 'I-STATE', 'B-CITYSTATE', 'I-CITYSTATE'], id=None), length=-1, id=None)}
{'tokens': ['Best', 'upgrades', 'for', 'Volkswagen', 'Golf'], 'ner_tags': [0, 0, 0, 0, 0]}


In [59]:
# ClassLabel(names=loc_dataset_filtered.features['ner_tags'].feature.names)

In [60]:
# ClassLabel(names=list(label_map.values()))

In [61]:
# Row count of the synthetic dataset.
len(synthetic_loc_dataset)

300000

In [62]:
# Spot-check a single example by index.
synthetic_loc_dataset[3]

{'tokens': ['saloo'], 'ner_tags': [0]}

In [63]:
# Add an 'id' column whose value is each example's row index.
synthetic_loc_dataset = synthetic_loc_dataset.map(
    lambda example, idx: {'id': idx},  # Assign running count as the new 'id'
    with_indices=True  # Ensures we get an index for each example
)

Map:   0%|          | 0/300000 [00:00<?, ? examples/s]

In [64]:
# View the dataset as a DataFrame to confirm the new 'id' column.
synthetic_loc_dataset.to_pandas()

Unnamed: 0,tokens,ner_tags,id
0,"[Best, upgrades, for, Volkswagen, Golf]","[0, 0, 0, 0, 0]",0
1,"[top, -, rated, conc, venues]","[0, 0, 0, 0, 0]",1
2,"[nearest, arena]","[0, 0]",2
3,[saloo],[0],3
4,"[tourn, in, San, Diego]","[0, 0, 5, 6]",4
...,...,...,...
299995,"[movie, theaters, Albany, ,, OR]","[0, 0, 9, 10, 10]",299995
299996,"[home, improvement, Meridian]","[0, 0, 5]",299996
299997,"[ice, cream, The, Villages, FL]","[0, 0, 9, 10, 10]",299997
299998,"[weather, Caldwell, ID]","[0, 9, 10]",299998


In [65]:
# Inspect the last example.
synthetic_loc_dataset[-1]

{'tokens': ['Flagstaff', ',', 'Arizona', 'salads'],
 'ner_tags': [9, 10, 10, 0],
 'id': 299999}

In [66]:
# Persist the synthetic dataset (tokens, ner_tags, id) as Parquet.
synthetic_loc_dataset.to_parquet("../data/synthetic_loc_dataset_v3.parquet")

Creating parquet from Arrow format:   0%|          | 0/300 [00:00<?, ?ba/s]

36125163

In [67]:
from transformers import AutoTokenizer, AutoModelForTokenClassification
from transformers import pipeline

# Load the tokenizer and the token-classification model from the same checkpoint.
# Note that this rebinds the name `tokenizer`.
tokenizer = AutoTokenizer.from_pretrained("Mozilla/distilbert-uncased-NER-LoRA")
model = AutoModelForTokenClassification.from_pretrained("Mozilla/distilbert-uncased-NER-LoRA")

# Build an NER pipeline; no `device` argument is passed.
nlp = pipeline("ner", model=model, tokenizer=tokenizer)
# "New York" can be read as a city or a state, so it is a useful probe input.
example = "New York"

ner_results = nlp(example)
print(ner_results)


Hardware accelerator e.g. GPU is available in the environment, but no `device` argument is passed to the `Pipeline` object. Model will be on CPU.


[{'entity': 'B-STATE', 'score': np.float32(0.827376), 'index': 1, 'word': 'new', 'start': 0, 'end': 3}, {'entity': 'I-STATE', 'score': np.float32(0.69074583), 'index': 2, 'word': 'york', 'start': 4, 'end': 8}]
