msmarco-v2-vector/_tools/parse_documents.py (68 lines of code) (raw):

import json import sys import numpy import vg from datasets import DownloadMode, load_dataset DATASET_NAME: str = f"Cohere/msmarco-v2-embed-english-v3" DATASET_DL_PROCS: int = 6 OUTPUT_FILENAME: str = "cohere-documents" DEFAULT_MAX_DOCS = -1 TOTAL_DOCS = 138364198 MAX_DOCS_PER_FILE = 3_000_000 TOTAL_PAGES = 47 PROGRESS_EVERY = 100 def progress_bar(count, total): bar_length = 100 filled_length = int(round(bar_length * count / float(total))) percentage = round(100.0 * count / float(total), 1) bar = "=" * filled_length + "-" * (bar_length - filled_length) sys.stdout.write("[{}] {}{} ... {:,}/{:,}\r".format(bar, percentage, "%", count, total)) sys.stdout.flush() def output_pages(start_page, end_page): for page in range(start_page, end_page + 1): start_index = (page - 1) * MAX_DOCS_PER_FILE end_index = start_index + MAX_DOCS_PER_FILE if end_index > TOTAL_DOCS: end_index = TOTAL_DOCS output_filename = f"{OUTPUT_FILENAME}-{page:02d}.json" print(f"Outputing page {page} documents to {output_filename}") with open(output_filename, "w") as documents_file: output_documents(documents_file, start_index, end_index) def output_documents(docs_file, start_index, end_index): doc_count = 0 dataset_size = end_index - start_index print(f"Parsing {dataset_size} documents from {DATASET_NAME} [{start_index}:{end_index}]") docs = load_dataset( DATASET_NAME, split=f"train[{start_index}:{end_index}]", num_proc=DATASET_DL_PROCS, download_mode=DownloadMode.REUSE_DATASET_IF_EXISTS, ) progress_bar(doc_count, dataset_size) for doc in docs: normalized_embed = vg.normalize(numpy.array(doc["emb"])).tolist() docs_file.write( json.dumps( {"docid": doc["_id"], "title": doc["title"], "text": doc["text"], "emb": normalized_embed}, ensure_ascii=True, ) ) docs_file.write("\n") doc_count += 1 if doc_count % PROGRESS_EVERY == 0: progress_bar(doc_count, dataset_size) print(f"Wrote {doc_count} documents to output file.") def parse_arguments(): if len(sys.argv) >= 3: return (DEFAULT_MAX_DOCS, int(sys.argv[1]), int(sys.argv[2])) if len(sys.argv) >= 2: return (int(sys.argv[1]), 1, TOTAL_PAGES) return (DEFAULT_MAX_DOCS, 1, TOTAL_PAGES) if __name__ == "__main__": (max_documents, start_page, end_page) = parse_arguments() if max_documents == DEFAULT_MAX_DOCS: output_pages(start_page, end_page) else: print("Outputing documents to {}.json".format(OUTPUT_FILENAME)) with open(f"{OUTPUT_FILENAME}.json", "w") as documents_file: output_documents(documents_file, 0, max_documents)