ablations/tokenization/launch_tokenization.py (63 lines of code) (raw):

import argparse parser = argparse.ArgumentParser("Quickly launch thom's style of tokenization.") parser.add_argument( "data_path", type=str, help="Path to the data to tokenize." ) parser.add_argument( "output_name", type=str, help="Output name." ) parser.add_argument( "--n_tasks", type=int, help="nb of tokenization tasks", default=1000 ) parser.add_argument( "--max_toks", type=int, help="max tokens per file", default=1e8 ) parser.add_argument( "--tokenizer", type=str, help="tokenizer to use", default="google/gemma-2b" ) parser.add_argument( "--text_key", type=str, default="text" ) parser.add_argument( "--sample", type=float, default=1.0 ) parser.add_argument("--qos", type=str, default="normal", help="qos to use") parser.add_argument( "--jsonl_output", "-jo", type=str, default=None, help="Path to optionally save the sampled data jsonl" ) parser.add_argument("-d", help="dependency job", type=str, default=None) if __name__ == "__main__": args = parser.parse_args() from datatrove.executor import SlurmPipelineExecutor from datatrove.pipeline.filters import SamplerFilter from datatrove.pipeline.readers import JsonlReader from datatrove.pipeline.writers import JsonlWriter from datatrove.pipeline.tokens.tokenizer import DocumentTokenizer SlurmPipelineExecutor( # job_name=f"nd-{DUMP_NUMBER}-{len(DUMPS)}", job_name=f"tok-{args.output_name}", pipeline=[ JsonlReader( args.data_path, text_key=args.text_key, ), SamplerFilter(rate=args.sample), *([JsonlWriter(args.jsonl_output)] if args.jsonl_output else []), DocumentTokenizer( output_folder=f"/path/to/tokenized/{args.output_name}", local_working_dir=f"/scratch/$USER/multilingual/tok/{args.output_name}", tokenizer_name_or_path=args.tokenizer, eos_token=None, batch_size=10000, max_tokens_per_file=args.max_toks, # Max 1 GT per file (i.e. btw 5 et 300 tokenized files per dump et about 100 dump extracts per merged file) shuffle=True, ), ], tasks=args.n_tasks, time="2:00:00", partition="hopper-cpu", logging_dir=f"/path/to/logs/multilingual/toks/{args.output_name}", cpus_per_task=32, qos=args.qos, mem_per_cpu_gb=3, depends_job_id=args.d, ).run()