notebooks/text-generation/scripts/run_clm.py:

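"""Fine-tune a causal language model on AWS Trainium with optimum-neuron.

Expects a dataset that was tokenized and packed upstream and saved via
`datasets.Dataset.save_to_disk`. A minimal launch sketch is below; the
torchrun process count and all flag values are assumptions to adapt to
your instance and optimum-neuron version:

    torchrun --nproc_per_node=2 run_clm.py \
        --model_id gpt2 \
        --dataset_path ./tokenized_dataset \
        --output_dir ./model \
        --per_device_train_batch_size 1 \
        --bf16 True
"""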
from dataclasses import dataclass, field

from datasets import load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    default_data_collator,
    set_seed,
)

from optimum.neuron import NeuronHfArgumentParser as HfArgumentParser
from optimum.neuron import NeuronTrainer as Trainer
from optimum.neuron import NeuronTrainingArguments as TrainingArguments


def training_function(script_args, training_args):
    # Load the preprocessed, tokenized dataset from disk
    dataset = load_from_disk(script_args.dataset_path)

    # Load tokenizer and model from the Hugging Face Hub
    tokenizer = AutoTokenizer.from_pretrained(script_args.model_id)
    model = AutoModelForCausalLM.from_pretrained(
        script_args.model_id,
        torch_dtype="auto",
        low_cpu_mem_usage=True,
        # KV caching is incompatible with gradient checkpointing during training
        use_cache=False if training_args.gradient_checkpointing else True,
    )

    # Create Trainer instance
    trainer = Trainer(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=dataset,
        data_collator=default_data_collator,  # no special collator needed since we stacked the dataset
    )

    # Start training
    trainer.train()
    trainer.save_model()  # saves the tokenizer too for easy upload

    # Consolidate sharded checkpoint files into a single file when TP degree > 1
    # perrysc@amazon.com
    # import os  # needed if this block is re-enabled
    # if (int(os.environ.get("RANK", -1)) == 0) and int(training_args.tensor_parallel_size) > 1:
    #     print("Converting sharded checkpoint to consolidated format")
    #     from optimum.neuron.models.training.checkpointing import (
    #         consolidate_model_parallel_checkpoints_to_unified_checkpoint,
    #     )
    #     from shutil import rmtree
    #
    #     consolidate_model_parallel_checkpoints_to_unified_checkpoint(
    #         training_args.output_dir, training_args.output_dir, "pytorch"
    #     )
    #     # remove the now-redundant sharded checkpoint files
    #     rmtree(os.path.join(training_args.output_dir, "tensor_parallel_shards"))


@dataclass
class ScriptArguments:
    model_id: str = field(
        metadata={
            "help": "The model that you want to train from the Hugging Face hub. E.g. gpt2, gpt2-xl, bert, etc."
        },
    )
    dataset_path: str = field(
        metadata={"help": "Path to the preprocessed and tokenized dataset."},
        default=None,
    )


def main():
    parser = HfArgumentParser([ScriptArguments, TrainingArguments])
    script_args, training_args = parser.parse_args_into_dataclasses()

    # set seed for reproducibility
    set_seed(training_args.seed)

    # run the training function
    training_function(script_args, training_args)


if __name__ == "__main__":
    main()