fulltext_search/search_sharded.py

import argparse
import json
import sys
import time

import requests
from datasets import load_dataset


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_dataset", type=str, default="HuggingFaceTB/bisac_expanded_final"
    )
    parser.add_argument("--n_pages", type=int, default=2000)
    parser.add_argument(
        "--output_dataset",
        type=str,
        default="HuggingFaceTB/bisac_boosted_new_index_2000",
    )
    parser.add_argument("--shard", type=int, required=True)
    parser.add_argument("--num_shards", type=int, required=True)
    return parser.parse_args()


# Wait until the local search server is up before doing any work.
while True:
    try:
        requests.post(
            "http://127.0.0.1:9308/search",
            data='{"index": "fineweb", "query": {"match": {"content": "ping"}}}',
        )
        break
    except requests.exceptions.ConnectionError:
        time.sleep(10)

args = get_args()
data = load_dataset(
    args.input_dataset, split="train", cache_dir="/scratch/cosmo/.cache"
)
# Keep only the rows assigned to this shard.
data = data.filter(lambda x, i: i % args.num_shards == args.shard, with_indices=True)
data = data.select_columns(["top_category", "subcategory", "subtopic"])


# Query the full-text index, retrying on connection errors and non-200 responses.
def run_query(query, n_pages):
    while True:
        try:
            max_pages = 4_000
            response = requests.post(
                "http://127.0.0.1:9308/search",
                data=json.dumps(
                    {
                        "index": "fineweb",
                        "size": n_pages,
                        "query": query,
                        "max_matches": max_pages,
                    }
                ),
                timeout=1000,
            )
            if response.status_code != 200:
                print(response.text, file=sys.stderr)
                time.sleep(5)
                continue
            else:
                hits = response.json()["hits"]["hits"]
                return hits
        except requests.exceptions.ConnectionError as e:
            print(e, file=sys.stderr)
            time.sleep(5)
            continue


def search_topic(sample):
    # batch_size is 1, so each batch holds a single topic row.
    top_category = sample["top_category"][0].strip()
    subcategory = sample["subcategory"][0].strip()
    subtopic = sample["subtopic"][0].strip()
    # Strip characters that have a special meaning in the query syntax.
    for c in ["!", '"', "$", "'", "(", ")", "/", "<", "@", "\\", "^", "|", "~"]:
        top_category = top_category.replace(c, " ")
        subcategory = subcategory.replace(c, " ")
        subtopic = subtopic.replace(c, " ")
    # Boost the IDF score of subtopic tokens.
    boosted_subtopic = " ".join([w + "^2" for w in subtopic.split()])
    match_query = " ".join([top_category, subcategory, subtopic])
    boosted_query = " ".join([top_category, subcategory, boosted_subtopic])

    boosted_hits = run_query({"query_string": boosted_query}, args.n_pages)
    print(f"Boosted hits: {len(boosted_hits)} for {boosted_query}", file=sys.stderr)
    if len(boosted_hits) < args.n_pages:
        # Not enough boosted hits: top up with a plain match query on the content field.
        match_hits = run_query(
            {"match": {"content": match_query}}, args.n_pages + len(boosted_hits)
        )
        print(f"Match hits: {len(match_hits)} for {match_query}", file=sys.stderr)
    else:
        match_hits = []

    # Deduplicate by document id, keeping boosted hits first, and cap at n_pages.
    hit_ids = set()
    hits = []
    for hit in boosted_hits + match_hits:
        if hit["_id"] not in hit_ids:
            hits.append(hit)
            hit_ids.add(hit["_id"])
    hits = hits[: args.n_pages]

    # Repeat the topic columns so every returned hit keeps its topic metadata.
    results = {
        "top_category": sample["top_category"] * len(hits),
        "subcategory": sample["subcategory"] * len(hits),
        "subtopic": sample["subtopic"] * len(hits),
        "topic_hits": hits,
        "num_hits": [len(hits)] * len(hits),
    }
    return results


data = data.map(search_topic, batched=True, batch_size=1, num_proc=2)
# Each shard pushes its results to its own private dataset repo.
data.push_to_hub(
    f"{args.output_dataset}_{args.shard}", private=True, max_shard_size="4096MB"
)
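
# Example usage (a sketch, not from the original script): it assumes a full-text
# search server exposing the JSON /search API with a "fineweb" index on
# 127.0.0.1:9308, and that you are authenticated with the Hugging Face Hub so
# push_to_hub can create the per-shard output dataset. The shard count of 8 is
# arbitrary; each shard processes every 8th topic and uploads
# <output_dataset>_<shard>.
#
#   python fulltext_search/search_sharded.py \
#       --input_dataset HuggingFaceTB/bisac_expanded_final \
#       --n_pages 2000 \
#       --shard 0 --num_shards 8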