in sagemaker/25_pytorch_fsdp_model_parallelism/scripts/run_clm.py [0:0]
import torch.distributed as dist

from datasets import load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)


def training_function(args):
    # set seed
    set_seed(args.seed)
    dataset = load_from_disk(args.dataset_path)
    # load model from the hub
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        use_cache=False if args.gradient_checkpointing else True,  # this is needed for gradient checkpointing
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)
    # Define training args
    output_dir = "/tmp"  # local scratch dir; the final model is written to /opt/ml/model below
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        per_device_train_batch_size=args.per_device_train_batch_size,
        bf16=args.bf16,  # Use BF16 if available
        learning_rate=args.lr,
        num_train_epochs=args.epochs,
        gradient_checkpointing=args.gradient_checkpointing,
        # logging strategies
        logging_dir=f"{output_dir}/logs",
        logging_strategy="steps",
        logging_steps=10,
        save_strategy="no",
        optim=args.optimizer,
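        # raise the process-group timeout to 2 h (default is 30 min) so that slow
        # steps such as loading the model on every node do not abort the job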
        ddp_timeout=7200,
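        # FSDP config forwarded from the launcher: `fsdp` holds the sharding strategy
        # (e.g. "full_shard auto_wrap") and `fsdp_transformer_layer_cls_to_wrap` the
        # transformer block class to wrap; the example values here are illustrative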
        fsdp=args.fsdp,
        fsdp_transformer_layer_cls_to_wrap=args.fsdp_transformer_layer_cls_to_wrap,
    )
    # Create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=default_data_collator,
    )
    # Start training
    trainer.train()

    print("Training done!")

    # save model and tokenizer for easy inference
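    # safe_save_model_for_hf_trainer is a helper in run_clm.py (not shown in this
    # excerpt); /opt/ml/model is the directory SageMaker uploads to S3 as the model
    # artifact when the training job finishes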
    safe_save_model_for_hf_trainer(trainer, tokenizer, "/opt/ml/model/")
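    # wait for all ranks to finish saving before the process exits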
    dist.barrier()
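

# --- Illustrative sketch only (not part of the original script) ---
# A minimal entry point showing how training_function might be driven; run_clm.py
# defines its own argument parsing, so the argument names and defaults below are
# assumptions that simply mirror the attributes accessed above.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_id", type=str, required=True)
    parser.add_argument("--dataset_path", type=str, default="/opt/ml/input/data/training")
    parser.add_argument("--per_device_train_batch_size", type=int, default=1)
    parser.add_argument("--lr", type=float, default=5e-5)
    parser.add_argument("--epochs", type=int, default=3)
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--optimizer", type=str, default="adamw_torch")
    parser.add_argument("--bf16", action="store_true")
    parser.add_argument("--gradient_checkpointing", action="store_true")
    parser.add_argument("--fsdp", type=str, default="full_shard auto_wrap")
    parser.add_argument("--fsdp_transformer_layer_cls_to_wrap", type=str, default="GPT2Block")
    args, _ = parser.parse_known_args()
    training_function(args)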