# toolkits/pretrain_data_preprocessing/preprocess_data_megatron.py (360 lines of code) (raw):
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Processing large data for pretraining."""
import argparse
import math
import json
import os
import sys
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__),
os.path.pardir)))
import time
import gzip
import glob
import torch
import numpy as np
#import ftfy
import multiprocessing
try:
import nltk
nltk_available = True
except ImportError:
nltk_available = False
from megatron.core.datasets import indexed_dataset
from megatron_patch.tokenizer import build_tokenizer
# NOTE(review): `manager` is not referenced anywhere else in the visible part
# of this file; creating it at import time is a module-level side effect.
# Confirm nothing outside this file relies on it before removing.
manager = multiprocessing.Manager()
# Queue shared between main() and the per-partition processes: each
# Partition.process_json_file call puts its token total here, and main()
# drains it after merging to print the grand total.
token_count_queue = multiprocessing.Queue()
class IdentitySplitter:
    """No-op sentence splitter used when sentence splitting is disabled."""

    def tokenize(self, *text):
        """Return the positional arguments unchanged, as a tuple.

        Each argument is treated as one already-split piece, so calling
        ``tokenize(doc)`` yields the one-element tuple ``(doc,)``.
        """
        pieces = text
        return pieces
class Encoder(object):
    """Sentence-splits and tokenizes single JSON lines inside pool workers.

    The tokenizer and splitter are stored as *class* attributes by
    ``initializer`` so that every call made in the same worker process
    shares them.
    """

    # Tokenizer types whose underlying tokenizer is reached through the
    # wrapper's ``.tokenizer`` attribute.
    _WRAPPED_TOKENIZER_TYPES = (
        "DeepSeekV2Tokenizer",
        "Qwen3Tokenizer",
        "Qwen2Tokenizer",
        "LLama3Tokenizer",
        "LLama2Tokenizer",
    )

    def __init__(self, args):
        """Store the parsed arguments and reset the running token counter."""
        self.args = args
        # Running total of tokens encoded through this instance.
        self.total_token_count = 0

    def initializer(self):
        """Build the tokenizer and sentence splitter for this process.

        Intended to be passed as ``initializer=`` to ``multiprocessing.Pool``.
        Exits the process if sentence splitting is requested but NLTK could
        not be imported.
        """
        # Use Encoder class as a container for global data
        Encoder.tokenizer = build_tokenizer(self.args)
        if self.args.split_sentences:
            if not nltk_available:
                print("NLTK is not available to split sentences.")
                exit()
            if os.environ.get("NLTK_DATA"):
                # Load the punkt model from the user-supplied NLTK data dir.
                library = os.path.join(os.environ.get("NLTK_DATA"), "tokenizers", "punkt", f"{self.args.lang}.pickle")
                url = f"file:{library}"
            else:
                library = os.path.join("tokenizers", "punkt", f"{self.args.lang}.pickle")
                url = f"nltk:{library}"
            splitter = nltk.load(url)
            Encoder.splitter = splitter
        else:
            Encoder.splitter = IdentitySplitter()

    def split(self, json_line):
        """Split every configured JSON field of one line into pieces.

        Args:
            json_line: One line of JSON text.

        Returns:
            A tuple ``(json_text, line_length, piece_length_total)`` where
            ``json_text`` is the re-serialized document with each field
            replaced by its list of pieces, ``line_length`` is
            ``len(json_line)`` and ``piece_length_total`` is the sum of
            ``len()`` over all pieces produced by the splitter.
        """
        data = json.loads(json_line)
        output = {}
        total_token_count = 0
        for key in self.args.json_keys:
            text = data[key]
            # Feed the splitter bounded slices so very long fields are
            # processed in chunks of at most ``max_len`` characters.
            max_len = 1000000
            tokens_list = [Encoder.splitter.tokenize(text[i:i+max_len]) for i in range(0, len(text), max_len)]
            output[key] = [tokens for partial in tokens_list for tokens in partial]
            total_token_count += sum(len(tokens) for partial in tokens_list for tokens in partial)
        return json.dumps(output), len(json_line), total_token_count

    def _tokenize(self, sentence):
        """Return the token ids of *sentence* for the configured tokenizer."""
        tokenizer_type = self.args.patch_tokenizer_type
        if tokenizer_type in self._WRAPPED_TOKENIZER_TYPES:
            return Encoder.tokenizer.tokenizer(sentence, add_special_tokens=False)['input_ids']
        if tokenizer_type == "GPT2BPETokenizer":
            return Encoder.tokenizer.tokenize(sentence)
        return Encoder.tokenizer(sentence, add_special_tokens=False)['input_ids']

    def encode(self, json_line):
        """Tokenize every configured JSON field of one line.

        Args:
            json_line: One line of JSON text. Each configured field is either
                a string or a list of strings (sentences).

        Returns:
            A tuple ``(ids, lens, line_length, token_count)``:
            ``ids[key]`` is the flat list of token ids for the field,
            ``lens[key]`` the per-sentence token counts (the last one
            includes the EOD token when ``append_eod`` is set),
            ``line_length`` is ``len(json_line)`` and ``token_count`` is the
            number of tokens produced by *this call only* (EOD tokens are not
            counted). A line that cannot be parsed yields ``({}, {}, 0, 0)``.
        """
        try:
            data = json.loads(json_line)
        except Exception:
            # Best effort: skip malformed lines instead of aborting the run,
            # but let KeyboardInterrupt/SystemExit propagate.
            return {}, {}, 0, 0
        ids = {}
        lens = {}
        # Tokens produced by this call. The caller sums the returned counts,
        # so returning the running total here would count earlier documents
        # again on every later call.
        doc_token_count = 0
        for key in self.args.json_keys:
            text = data[key]
            #text = ftfy.fix_text(text)
            if isinstance(text, list):
                sentences = text
            else:
                sentences = [text]
            doc_ids = []
            sentence_lens = []
            for sentence in sentences:
                sentence_ids = self._tokenize(sentence)
                if not sentence_ids:
                    print(f"tokenizer error sentence_ids is empty :\n {text} \n")
                    continue
                if max(sentence_ids) >= Encoder.tokenizer.vocab_size:
                    # Ids outside the vocabulary are dropped with a message.
                    print(f"tokenizer error max(sentence_ids) >= Encoder.tokenizer.vocab_size :\n {text}\n {max(sentence_ids)}")
                    continue
                doc_token_count += len(sentence_ids)
                doc_ids.extend(sentence_ids)
                sentence_lens.append(len(sentence_ids))
            if len(doc_ids) > 0 and self.args.append_eod:
                doc_ids.append(Encoder.tokenizer.eod)
                sentence_lens[-1] += 1
            ids[key] = doc_ids
            lens[key] = sentence_lens
        # Keep the instance-level running total available as before.
        self.total_token_count += doc_token_count
        return ids, lens, len(json_line), doc_token_count
class Partition(object):
    """Processes one input partition using its own pool of worker processes."""

    def __init__(self, args, workers):
        """Store the parsed arguments and the pool size for this partition."""
        self.args = args
        self.workers = workers

    def print_processing_stats(self, count, proc_start, total_bytes_processed, total_token_count):
        """Print throughput to stderr every ``args.log_interval`` documents.

        Args:
            count: Number of documents processed so far.
            proc_start: ``time.time()`` value taken when processing started.
            total_bytes_processed: Sum of the line lengths handled so far.
            total_token_count: Token total accumulated by the caller.
        """
        if count % self.args.log_interval == 0:
            current = time.time()
            elapsed = current - proc_start
            mbs = total_bytes_processed/elapsed/1024/1024
            print(f"Processed {count} documents",
                  f"({count/elapsed} docs/s, {mbs} MB/s). Total tokens: {total_token_count}.",
                  file=sys.stderr)

    def split_sentences(self, file_name):
        """Sentence-split one JSON-lines file into another.

        Args:
            file_name: ``(input_file_name, output_file_name)`` tuple. Each
                input line is passed to ``Encoder.split`` in a worker pool and
                the result is written as one line of the output file.
        """
        input_file_name, output_file_name = file_name
        print("Opening", input_file_name)
        encoder = Encoder(self.args)
        # Context managers make sure the files and the pool are released even
        # when a worker raises.
        with open(input_file_name, 'r', encoding='utf-8') as fin, \
                open(output_file_name, 'w', encoding='utf-8') as fout, \
                multiprocessing.Pool(self.workers, initializer=encoder.initializer) as pool:
            split_docs = pool.imap(encoder.split, fin, 32)

            proc_start = time.time()
            total_bytes_processed = 0
            total_token_count = 0
            for i, (doc, bytes_processed, current_token_count) in enumerate(split_docs, start=1):
                total_bytes_processed += bytes_processed
                total_token_count += current_token_count
                fout.write(doc + "\n")
                self.print_processing_stats(i, proc_start, total_bytes_processed, total_token_count)

    def process_json_file(self, file_name):
        """Tokenize one JSON-lines file into indexed ``.bin``/``.idx`` datasets.

        Args:
            file_name: ``(input_file_name, output_prefix)`` tuple. One dataset
                named ``{output_prefix}_{key}_{level}`` is built per entry of
                ``args.json_keys``; ``level`` is ``"sentence"`` when
                ``args.split_sentences`` is set and ``"document"`` otherwise.

        The token total accumulated from ``Encoder.encode`` is printed and
        put on the module-level ``token_count_queue``.
        """
        input_file_name, output_prefix = file_name
        print("Opening", input_file_name)

        level = "document"
        if self.args.split_sentences:
            level = "sentence"

        output_bin_files = {}
        output_idx_files = {}
        builders = {}

        with open(input_file_name, 'r', encoding='utf-8') as fin:
            startup_start = time.time()
            encoder = Encoder(self.args)
            tokenizer = build_tokenizer(self.args)
            with multiprocessing.Pool(self.workers, initializer=encoder.initializer) as pool:
                encoded_docs = pool.imap(encoder.encode, fin, 32)

                for key in self.args.json_keys:
                    output_bin_files[key] = "{}_{}_{}.bin".format(output_prefix,
                                                                  key, level)
                    output_idx_files[key] = "{}_{}_{}.idx".format(output_prefix,
                                                                  key, level)
                    builders[key] = indexed_dataset.IndexedDatasetBuilder(
                        output_bin_files[key],
                        dtype=indexed_dataset.DType.optimal_dtype(tokenizer.vocab_size),
                    )

                startup_end = time.time()
                proc_start = time.time()
                total_bytes_processed = 0
                total_token_count = 0  # add token count for process json file
                print("Time to startup:", startup_end - startup_start)
                for i, (doc, sentence_lens, bytes_processed, current_token_count) in enumerate(encoded_docs, start=1):
                    total_bytes_processed += bytes_processed
                    total_token_count += current_token_count  # update token count
                    # Lines that failed to parse come back as empty dicts and
                    # therefore add nothing here.
                    for key in doc.keys():
                        builders[key].add_document(doc[key], sentence_lens[key])
                    self.print_processing_stats(i, proc_start, total_bytes_processed, total_token_count)

        print(f"Total token count: {total_token_count}")  # print total token
        token_count_queue.put(total_token_count)

        # Finalize every builder, not only the one for the last key seen, so
        # that each configured JSON key gets its .idx file.
        for key in self.args.json_keys:
            builders[key].finalize(output_idx_files[key])
def get_args():
    """Parse the command line and attach the defaults the tokenizer expects.

    Returns:
        The ``argparse.Namespace`` from ``parse_args()``, extended with
        ``keep_empty``, ``rank``, ``make_vocab_size_divisible_by``,
        ``tensor_model_parallel_size`` and ``vocab_extra_ids``.
    """
    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title='input data')
    group.add_argument('--input', type=str, required=True,
                       help='Path to input JSON')
    group.add_argument('--json-keys', nargs='+', default=['text'],
                       help='space separate listed of keys to extract from json')
    group.add_argument('--split-sentences', action='store_true',
                       help='Split documents into sentences.')
    group.add_argument('--keep-newlines', action='store_true',
                       help='Keep newlines between sentences when splitting.')

    group = parser.add_argument_group(title='tokenizer')
    group.add_argument('--tokenizer-type', type=str, required=False, default='GPT2BPETokenizer',
                       choices=['BertWordPieceLowerCase','BertWordPieceCase',
                                'GPT2BPETokenizer', 'SentencePieceTokenizer',
                                'GPTSentencePieceTokenizer', 'LLama2Tokenizer',
                                'NullTokenizer'],
                       help='What type of tokenizer to use.')
    group.add_argument('--tokenizer-model', type=str, default=None,
                       help='YTTM tokenizer model.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file')
    # type=int so a value given on the command line is an int like the
    # default, rather than a str.
    group.add_argument('--vocab-size', type=int, default=786,
                       help='size of vocab for use with NullTokenizer')
    group.add_argument('--merge-file', type=str, default=None,
                       help='Path to the BPE merge file (if necessary).')
    group.add_argument('--append-eod', action='store_true',
                       help='Append an <eod> token to the end of a document.')
    group.add_argument('--lang', type=str, default='english',
                       help='Language to use for NLTK-powered sentence splitting.')

    group = parser.add_argument_group(title='output data')
    group.add_argument('--output-prefix', type=str, required=True,
                       help='Path to binary output file without suffix')

    group = parser.add_argument_group(title='runtime')
    group.add_argument('--workers', type=int, required=True,
                       help=('Number of worker processes to launch. '
                             'A good default for fast pre-processing '
                             'is: (workers * partitions) = available CPU cores.'))
    group.add_argument('--partitions', type=int, default=1,
                       help='Number of file partitions')
    group.add_argument('--log-interval', type=int, default=1000,
                       help='Interval between progress updates')
    group.add_argument('--keep-sequential-samples', action='store_true',
                       help='Ensure ordering of samples in .jsonl files is '
                            'preserved when using partitions>1.')
    group.add_argument(
        '--patch-tokenizer-type',
        type=str,
        required=True,
        choices=['Qwen3Tokenizer', 'Qwen2Tokenizer', 'LLamaTokenizer', 'DeepSeekV2Tokenizer', 'LLama3Tokenizer', 'LLama2Tokenizer', 'GPT2BPETokenizer'],
        help='What type of tokenizer to use.',
    )
    group.add_argument('--load',
                       type=str,
                       default=None,
                       help='path to tokenizer config file')
    group.add_argument('--seq-length',
                       type=int,
                       default=2048,
                       help='sequence length')
    group.add_argument('--extra-vocab-size',
                       type=int,
                       default=0,
                       help='extra_vocab_size')

    args = parser.parse_args()
    args.keep_empty = False

    if args.tokenizer_type.lower().startswith('bert') and not args.split_sentences:
        print("Are you sure you don't want to split sentences?")

    # some default/dummy values for the tokenizer
    args.rank = 1
    args.make_vocab_size_divisible_by = 128
    args.tensor_model_parallel_size = 1
    args.vocab_extra_ids = 0

    return args
def get_file_name(args, file_id):
    """Return the file names that belong to partition number *file_id*.

    The partition and sentence-split names are derived from ``args.input``
    by inserting ``_<id>`` / ``_ss_<id>`` before the extension; the output
    prefix is ``args.output_prefix`` with ``_<id>`` appended.

    Returns:
        A dict with the keys ``'partition'``, ``'sentence_split'`` and
        ``'output_prefix'``.
    """
    stem, ext = os.path.splitext(args.input)
    tag = "_" + str(file_id)
    return {
        'partition': stem + tag + ext,
        'sentence_split': stem + "_ss" + tag + ext,
        'output_prefix': args.output_prefix + tag,
    }
def check_files_exist(in_ss_out_names, key, num_partitions):
    """Tell whether the *key* file of each of the first partitions exists.

    Args:
        in_ss_out_names: Sequence of per-partition name dicts.
        key: Which entry of each dict to check (for example ``'partition'``).
        num_partitions: How many leading entries to inspect.

    Returns:
        ``True`` when every inspected path exists (also when
        ``num_partitions`` is 0), ``False`` as soon as one is missing.
    """
    return all(
        os.path.exists(in_ss_out_names[part][key])
        for part in range(num_partitions)
    )
def _open_input_file(path):
    """Open *path* for reading as UTF-8 text, using gzip for ``.gz`` files."""
    if path.endswith(".gz"):
        return gzip.open(path, 'rt', encoding='utf-8')
    return open(path, 'r', encoding='utf-8')


def main():
    """Run the preprocessing pipeline.

    Steps:
      1. With ``--partitions 1`` the input file is used directly. Otherwise
         ``--input`` is listed as a directory and its lines are distributed
         over ``partitions`` files, unless those files already exist.
      2. With ``--split-sentences`` each partition is sentence-split, unless
         the split files already exist.
      3. Each partition is tokenized in its own process.
      4. With more than one partition, the per-partition datasets are merged
         and the summed token count is printed.

    Raises:
        Exception: if sentence splitting is requested but NLTK is missing.
        ValueError: if ``--workers`` is not a multiple of ``--partitions``.
    """
    args = get_args()

    if args.split_sentences:
        if nltk_available:
            nltk.download("punkt", quiet=True, download_dir=os.environ.get("NLTK_DATA"))
        else:
            raise Exception(
                "nltk library required for sentence splitting is not available.")

    in_ss_out_names = []
    if args.partitions == 1:
        file_name, extension = os.path.splitext(args.input)
        sentence_split_file = file_name + "_ss" + extension
        file_names = {
            'partition': args.input,
            'sentence_split': sentence_split_file,
            'output_prefix': args.output_prefix}
        in_ss_out_names.append(file_names)
    else:
        file_list = os.listdir(args.input)
        in_file_names = [os.path.join(args.input, file) for file in file_list]

        # Count total number of lines across .jsonl files
        if args.keep_sequential_samples:
            total_sample_count = 0
            for filename in in_file_names:
                # Count with the same opener as the distribution pass below,
                # so .gz inputs and empty files are counted correctly.
                with _open_input_file(filename) as fin:
                    total_sample_count += sum(1 for _ in fin)
            partition_size = math.ceil(total_sample_count / args.partitions)

        # create .jsonl partition files
        for idx in range(args.partitions):
            in_ss_out_name = get_file_name(args, idx)
            in_ss_out_names.append(in_ss_out_name)

        # check to see if partitions were already created
        partitions_present = check_files_exist(in_ss_out_names, 'partition', args.partitions)

        # check to see if partitions with split sentences already created
        split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)

        if not partitions_present and not split_sentences_present:
            # populate .jsonl partition files from parent files
            partitioned_input_files = []
            try:
                for idx in range(args.partitions):
                    # Written as UTF-8 because the partitions are read back
                    # with encoding='utf-8' later on.
                    partitioned_input_file = open(in_ss_out_names[idx]['partition'], 'w', encoding='utf-8')
                    partitioned_input_files.append(partitioned_input_file)

                index = 0
                line_count = 0
                for in_file_name in in_file_names:
                    # support for gzip files
                    with _open_input_file(in_file_name) as fin:
                        for line in fin:
                            partitioned_input_files[index].write(line)
                            if args.keep_sequential_samples:
                                # Fill partitions one after another to keep
                                # the original sample order.
                                line_count += 1
                                if line_count % partition_size == 0:
                                    index += 1
                            else:
                                # Round-robin over the partitions.
                                index = (index + 1) % args.partitions
            finally:
                for partitioned_input_file in partitioned_input_files:
                    partitioned_input_file.close()

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if args.workers % args.partitions != 0:
        raise ValueError(
            f"--workers ({args.workers}) must be a multiple of "
            f"--partitions ({args.partitions})")
    partition = Partition(args, args.workers//args.partitions)

    # check to see if partitions with split sentences already created
    split_sentences_present = check_files_exist(in_ss_out_names, 'sentence_split', args.partitions)

    # split sentences in partition files
    if args.split_sentences and not split_sentences_present:
        processes = []
        for name in in_ss_out_names:
            p = multiprocessing.Process(target=partition.split_sentences,
                                        args=((name['partition'], name['sentence_split']),))
            p.start()
            processes.append(p)

        for p in processes:
            p.join()

        # NOTE(review): with a single partition the run stops after sentence
        # splitting; encoding then happens on a later invocation, once the
        # split file exists. Preserved as-is -- confirm this is intended.
        if args.partitions == 1:
            return

    # encode partition files in parallel
    processes = []
    input_key = 'sentence_split' if args.split_sentences else 'partition'
    for name in in_ss_out_names:
        p = multiprocessing.Process(target=partition.process_json_file,
                                    args=((name[input_key], name['output_prefix']),))
        p.start()
        processes.append(p)

    for p in processes:
        p.join()

    # A single partition was written straight to the final output prefix,
    # so there is nothing to merge.
    if args.partitions == 1:
        return

    # merge bin/idx partitions
    level = "document"
    if args.split_sentences:
        level = "sentence"

    output_bin_files = {}
    output_idx_files = {}
    builders = {}
    tokenizer = build_tokenizer(args)

    for key in args.json_keys:
        output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix,
                                                      key, level)
        output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix,
                                                      key, level)
        builders[key] = indexed_dataset.IndexedDatasetBuilder(
            output_bin_files[key],
            dtype=indexed_dataset.DType.optimal_dtype(tokenizer.vocab_size),
        )

        for name in in_ss_out_names:
            parition_output_prefix = name['output_prefix']
            full_partition_output_prefix = "{}_{}_{}".format(parition_output_prefix,
                                                             key, level)
            builders[key].add_index(full_partition_output_prefix)
        builders[key].finalize(output_idx_files[key])

    # count all process token num
    total_token_count = 0
    while not token_count_queue.empty():
        total_token_count += token_count_queue.get()
    print(f"Total tokens processed: {total_token_count}")
# Run the preprocessing pipeline only when executed as a script, not on import.
if __name__ == '__main__':
    main()