fulltext_search/index_docs.py
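"""Index FineWeb documents into a local full-text search server.

The endpoints used below (/bulk, /sql?mode=raw and /search on 127.0.0.1:9308)
match Manticore Search's HTTP API, so a Manticore daemon is assumed to be
running locally. Documents are sharded across 64 local tables
(fineweb0 ... fineweb63) behind a single distributed table named "fineweb".
"""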
import json
import time
import sys
import random
import requests
from datasets import load_dataset
def insert_batch(batch):
    ndjson = ""
    # Send the whole batch to one randomly chosen local shard.
    index_name = f"fineweb{random.randint(0, 63)}"
    for text, _id, url, language_score, token_count in zip(
        batch["text"],
        batch["id"],
        batch["url"],
        batch["language_score"],
        batch["token_count"],
    ):
        doc = {
            "insert": {
                "index": index_name,
                # FineWeb ids look like "<urn:uuid:...>"; keep only the trailing part.
                "_id": _id.split(":")[-1].strip(">"),
                "doc": {
                    "content": text,
                    "fw_id": _id.split(":")[-1].strip(">"),
                    "url": url,
                    "language_score": language_score,
                    "token_count": token_count,
                },
            }
        }
        ndjson += json.dumps(doc) + "\n"
    # Retry the bulk insert until the server accepts the connection.
    response = None
    while response is None:
        try:
            response = requests.post(
                "http://127.0.0.1:9308/bulk",
                headers={"Content-Type": "application/x-ndjson"},
                data=ndjson,
            )
        except requests.exceptions.ConnectionError as e:
            print(e, file=sys.stderr)
            time.sleep(1)
    return {"response": [response.status_code]}
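
# main() rebuilds the tables from scratch, streams two FineWeb dumps into them,
# then compacts and freezes every shard and runs a test query.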
def main():
    sql_url = "http://127.0.0.1:9308/sql?mode=raw"

    print("Removing table", file=sys.stderr)
    # Keep retrying until the search daemon is reachable.
    while True:
        try:
            requests.post(sql_url, data={"query": "drop table if exists fineweb"})
            break
        except requests.exceptions.ConnectionError as e:
            print(e, file=sys.stderr)
            time.sleep(5)

    print("Creating table", file=sys.stderr)
    # Recreate the 64 local shards with English stemming, English stopwords and
    # the non-CJK charset table.
    for i in range(64):
        response = requests.post(
            sql_url, data={"query": f"drop table if exists fineweb{i}"}
        )
        print(response.text, file=sys.stderr)
        local_query = f"create table fineweb{i}(content text, fw_id string, url string, language_score float, token_count int) charset_table='non_cjk' stopwords='en' morphology='stem_en'"
        response = requests.post(sql_url, data={"query": local_query})
        print(response.text, file=sys.stderr)

    # One distributed table fanning out to all 64 local shards.
    distributed_query = "create table fineweb type='distributed'"
    for i in range(64):
        distributed_query += f" local='fineweb{i}'"
    response = requests.post(sql_url, data={"query": distributed_query})
    print(response.text, file=sys.stderr)
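
    # Index two FineWeb CommonCrawl snapshots. With batched=True and num_proc=64,
    # map() feeds insert_batch 10k-document batches from 64 worker processes, so up
    # to 64 bulk requests are in flight at once; the status codes it returns end up
    # in a throwaway "response" column.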
    for dump in ["CC-MAIN-2024-10", "CC-MAIN-2023-50"]:
        print("Loading dataset", file=sys.stderr)
        dataset = load_dataset(
            "HuggingFaceFW/fineweb",
            dump,
            split="train",
            num_proc=64,
            cache_dir="/scratch/cosmo/.cache",
        )
        dataset = dataset.select_columns(
            ["text", "id", "url", "language_score", "token_count"]
        )
        dataset = dataset.map(
            insert_batch,
            batched=True,
            batch_size=10000,
            remove_columns=["text", "id", "url", "language_score", "token_count"],
            num_proc=64,
        )
        for _ in dataset:
            pass

    time.sleep(30)
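
    # Per-shard maintenance: FLUSH writes the RAM chunk to disk, OPTIMIZE merges
    # disk chunks down to the cutoff, and FREEZE locks the table files (presumably
    # so they can be copied or backed up safely afterwards).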
    for i in range(64):
        print(f"Optimizing table fineweb{i}", file=sys.stderr)
        response = requests.post(
            sql_url,
            data={"query": f"FLUSH TABLE fineweb{i}"},
            timeout=600,
        )
        print(response.text, file=sys.stderr)
        response = requests.post(
            sql_url,
            data={"query": f"OPTIMIZE TABLE fineweb{i} OPTION cutoff=16, sync=1"},
            timeout=600,
        )
        print(response.text, file=sys.stderr)
        response = requests.post(
            sql_url,
            data={"query": f"FREEZE fineweb{i}"},
            timeout=600,
        )
        print(response.text, file=sys.stderr)
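
    # Sanity-check query against the distributed "fineweb" table, which fans the
    # search out to all 64 local shards.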
    response = requests.post(
        "http://127.0.0.1:9308/search",
        data='{"index":"fineweb","query":{"match":{"*":"hello world"}}}',
    )
    print(response.text, file=sys.stderr)
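
    # Optional physical backup of the frozen tables (left disabled).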
# print("Backing up the index", file=sys.stderr)
# time.sleep(30)
# response = requests.post(
# sql_url,
# data={"query": "BACKUP TO /tmp/backups"},
# )
# print(response.text, file=sys.stderr)

if __name__ == "__main__":
    main()