def main()

in build_and_train_models/sm-distributed_model_parallel_v2/shared-scripts/data/prep/_prepare_nemo_megatron_dataset.py [0:0]


def main():
    args = get_args()
    startup_start = time.time()
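    # Gather inputs: either scan a folder for .json/.json.gz files or use a single input file.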
    if args.preproc_folder:
        print("Searching folder for .json or .json.gz files...")
        assert os.path.exists(args.input), f"Folder does not exist: {args.input}"
        json_files = (str(f) for f in pathlib.Path(args.input).glob(args.files_filter))
        json_files = [f for f in json_files if f.endswith(".json") or f.endswith(".json.gz")]
        if len(json_files) == 0:
            raise FileNotFoundError("No .json or .json.gz files found in folder.")
        else:
            print(f"Found {len(json_files)} .json or .json.gz files.")
    else:
        assert os.path.exists(args.input), f"File does not exist: {args.input}"
        json_files = [args.input]

    if nltk_available and args.split_sentences:
        nltk.download("punkt", quiet=True)

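    # The Encoder tokenizes each JSON line inside worker processes; the tokenizer
    # supplies the vocab size and pad id used by the indexed-dataset builders below.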
    encoder = Encoder(args)

    if args.dataset_impl == "retmmap":
        assert args.need_pad_id, "retmmap requires the --need_pad_id flag"
    tokenizer = get_tokenizer(args)

    level = "document"
    if args.split_sentences:
        level = "sentence"

    print(f"Vocab size: {tokenizer.vocab_size}")
    print(f"Output prefix: {args.output_prefix}")
    output_bin_files = {}
    output_idx_files = {}
    builders = {}
    for key in args.json_keys:
        output_bin_files[key] = "{}_{}_{}.bin".format(args.output_prefix, key, level)
        output_idx_files[key] = "{}_{}_{}.idx".format(args.output_prefix, key, level)
        builders[key] = indexed_dataset.make_builder(
            output_bin_files[key],
            impl=args.dataset_impl,
            chunk_size=args.chunk_size,
            pad_id=tokenizer.pad_id if hasattr(tokenizer, "pad_id") else 0,
            retrieval_db=args.retrieval_db,
            vocab_size=tokenizer.vocab_size,
            stride=args.chunk_stride_size,
        )

    startup_end = time.time()
    proc_start = time.time()
    total_bytes_processed = 0
    print("Time to startup:", startup_end - startup_start)

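    # Each worker runs encoder.initializer once, then encodes JSON lines in parallel
    # (imap hands each worker chunks of 25 lines at a time).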
    pool = multiprocessing.Pool(args.workers, initializer=encoder.initializer)

    for idx, json_file in enumerate(json_files):
        print(f"Processing file {json_file} {idx + 1}/{len(json_files)}")
        if json_file.endswith(".gz"):
            fin = gzip.open(json_file, "r")
        else:
            fin = open(json_file, "r", encoding="utf-8")

        encoded_docs = pool.imap(encoder.encode, fin, 25)

        for i, (doc, bytes_processed) in enumerate(encoded_docs, start=1):
            total_bytes_processed += bytes_processed
            for key, sentences in doc.items():
                if len(sentences) == 0:
                    continue
                for sentence in sentences:
                    builders[key].add_item(torch.IntTensor(sentence))
                builders[key].end_document()
            if i % args.log_interval == 0:
                current = time.time()
                elapsed = current - proc_start
                mbs = total_bytes_processed / elapsed / 1024 / 1024
                print(
                    f"Processed {i} documents",
                    f"({i/elapsed} docs/s, {mbs} MB/s).",
                    file=sys.stderr,
                )

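    # Write the .idx metadata so each index matches the data appended to its .bin file.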
    for key in args.json_keys:
        builders[key].finalize(output_idx_files[key])
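
The script consumes newline-delimited JSON, one document per line, with the fields named by args.json_keys. A minimal sketch for generating a small gzipped test input is shown below; the file name, the sample records, and the "text" key are illustrative assumptions, not part of the script.

import gzip
import json

# Illustrative sample documents; each line must be a self-contained JSON object
# whose keys match the json_keys requested when running the preprocessing script.
docs = [
    {"text": "First tiny document for a smoke test."},
    {"text": "Second tiny document. It has two sentences."},
]

# Write newline-delimited JSON, gzip-compressed so the .gz branch above is exercised.
with gzip.open("sample_corpus.json.gz", "wt", encoding="utf-8") as fout:
    for doc in docs:
        fout.write(json.dumps(doc) + "\n")

Pointing the script at the folder holding this file in preproc_folder mode, with a files filter matching *.json*, would then exercise the folder-search branch above.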