in build_and_train_models/sm-distributed_model_parallel_v2/shared-scripts/data/prep/_prepare_nemo_megatron_dataset.py [0:0]
def main():
    args = get_args()
    startup_start = time.time()
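
    # Build the list of input files: every matching .json / .json.gz file under
    # args.input when args.preproc_folder is set, otherwise the single file args.input.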
    if args.preproc_folder:
        print("Searching folder for .json or .json.gz files...")
        assert os.path.exists(args.input), f"Folder does not exist: {args.input}"
        json_files = (str(f) for f in pathlib.Path(args.input).glob(args.files_filter))
        json_files = [f for f in json_files if f.endswith(".json") or f.endswith(".json.gz")]
        if len(json_files) == 0:
            raise FileNotFoundError("No .json or .json.gz files found in folder.")
        else:
            print(f"Found {len(json_files)} .json or .json.gz files.")
    else:
        assert os.path.exists(args.input), f"File does not exist: {args.input}"
        json_files = [args.input]
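
    # Sentence splitting relies on NLTK's punkt models; fetch them up front if NLTK is available.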
    if nltk_available and args.split_sentences:
        nltk.download("punkt", quiet=True)

    encoder = Encoder(args)
    if args.dataset_impl == "retmmap":
        assert args.need_pad_id, "retmmap requires the --need_pad_id flag"
    tokenizer = get_tokenizer(args)
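
    # Output granularity: whole documents by default, individual sentences when
    # args.split_sentences is set.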
level = "document"
if args.split_sentences:
level = "sentence"
print(f"Vocab size: {tokenizer.vocab_size}")
print(f"Output prefix: {args.output_prefix}")
    output_bin_files = {}
    output_idx_files = {}
    builders = {}
    for key in args.json_keys:
        output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix, key, level)
        output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix, key, level)
        builders[key] = indexed_dataset.make_builder(
            output_bin_files[key],
            impl=args.dataset_impl,
            chunk_size=args.chunk_size,
            pad_id=tokenizer.pad_id if hasattr(tokenizer, "pad_id") else 0,
            retrieval_db=args.retrieval_db,
            vocab_size=tokenizer.vocab_size,
            stride=args.chunk_stride_size,
        )

    startup_end = time.time()
    proc_start = time.time()
    total_bytes_processed = 0
    print("Time to startup:", startup_end - startup_start)
    pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)

    for idx, json_file in enumerate(json_files):
        print(f"Processing file {json_file} {idx + 1}/{len(json_files)}")
        if json_file.endswith(".gz"):
            fin = gzip.open(json_file, "r")
        else:
            # Open the current file, not args.input (which is the folder path
            # when args.preproc_folder is set).
            fin = open(json_file, "r", encoding="utf-8")
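
        # Stream lines to the workers; imap's chunksize of 25 sends 25 lines per
        # task to keep inter-process communication overhead low.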
        encoded_docs = pool.imap(encoder.encode, fin, 25)
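
        # Each encoded doc maps a JSON key to a list of token-id sequences; append
        # every sequence to that key's builder and mark the document boundary.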
        for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
            total_bytes_processed += bytes_processed
            for key, sentences in doc.items():
                if len(sentences) == 0:
                    continue
                for sentence in sentences:
                    builders[key].add_item(torch.IntTensor(sentence))
                builders[key].end_document()
            if i % args.log_interval == 0:
                current = time.time()
                elapsed = current - proc_start
                mbs = total_bytes_processed / elapsed / 1024 / 1024
                print(
                    f"Processed {i} documents",
                    f"({i/elapsed} docs/s, {mbs} MB/s).",
                    file=sys.stderr,
                )
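
    # Write the .idx index file for each key via the builder's finalize().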
    for key in args.json_keys:
        builders[key].finalize(output_idx_files[key])