# sagemaker/25_pytorch_fsdp_model_parallelism/scripts/run_clm.py
import argparse

import torch
import torch.distributed as dist
from datasets import load_from_disk
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
    default_data_collator,
    set_seed,
)

def safe_save_model_for_hf_trainer(trainer: Trainer, tokenizer: AutoTokenizer, output_dir: str):
    """Gather the full (unsharded) model state dict and save model + tokenizer for the HF Trainer.

    See: https://github.com/tatsu-lab/stanford_alpaca/issues/65
    """
    from torch.distributed.fsdp import (
        FullStateDictConfig,
        FullyShardedDataParallel as FSDP,
        StateDictType,
    )

    model = trainer.model
    # Consolidate the sharded parameters onto CPU; with rank0_only=True only rank 0
    # materializes the full state dict, the other ranks get an empty dict.
    save_policy = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, save_policy):
        cpu_state_dict = model.state_dict()
    # Only the main process writes the checkpoint and tokenizer to disk.
    if trainer.args.should_save:
        trainer._save(output_dir, state_dict=cpu_state_dict)  # noqa
        tokenizer.save_pretrained(output_dir)

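# Illustrative note (not part of the original script): because the checkpoint above is a
# regular Hugging Face checkpoint, it can later be reloaded for inference with
#   AutoModelForCausalLM.from_pretrained("/opt/ml/model/")
#   AutoTokenizer.from_pretrained("/opt/ml/model/")
# once SageMaker has uploaded /opt/ml/model/ to S3 and the artifact has been extracted again.
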
def str2bool(v) -> bool:
    """Parse truthy strings so that boolean flags like `--bf16 False` behave as expected."""
    if isinstance(v, bool):
        return v
    return str(v).lower() in ("yes", "true", "t", "1")


def parse_args():
    """Parse the arguments."""
    parser = argparse.ArgumentParser()
    # model id and dataset path
    parser.add_argument(
        "--model_id",
        type=str,
        default="google/flan-t5-xl",
        help="Model id to use for training.",
    )
    parser.add_argument("--dataset_path", type=str, default="lm_dataset", help="Path to dataset.")
    # training hyperparameters: epochs, batch size, learning rate, and seed
    parser.add_argument("--epochs", type=int, default=3, help="Number of epochs to train for.")
    parser.add_argument(
        "--max_steps",
        type=int,
        default=None,
        help="Maximum number of training steps; overrides epochs when set.",
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=1,
        help="Batch size to use for training.",
    )
    parser.add_argument("--lr", type=float, default=3e-5, help="Learning rate to use for training.")
    parser.add_argument("--optimizer", type=str, default="adamw_hf", help="Optimizer to use for training.")
    parser.add_argument("--seed", type=int, default=42, help="Seed to use for training.")
    parser.add_argument(
        "--gradient_checkpointing",
        type=str2bool,
        default=True,
        help="Whether to use gradient checkpointing to save memory.",
    )
    parser.add_argument(
        "--bf16",
        type=str2bool,
        default=torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8,
        help="Whether to use bf16 (requires an Ampere or newer GPU).",
    )
    parser.add_argument(
        "--fsdp",
        type=str,
        default=None,
        help="FSDP policy to use, e.g. 'full_shard auto_wrap'; FSDP is disabled when unset.",
    )
    parser.add_argument(
        "--fsdp_transformer_layer_cls_to_wrap",
        type=str,
        default=None,
        help="Which transformer layer class to wrap with FSDP.",
    )
    args = parser.parse_known_args()
    return args

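# Illustrative only (assumed values, not part of this script): on SageMaker these flags are
# typically supplied via the HuggingFace estimator's `hyperparameters` dict, which the
# training toolkit turns into command line arguments, e.g.
#   hyperparameters = {
#       "model_id": "EleutherAI/gpt-neox-20b",
#       "dataset_path": "/opt/ml/input/data/training",
#       "fsdp": "full_shard auto_wrap",
#       "fsdp_transformer_layer_cls_to_wrap": "GPTNeoXLayer",
#   }
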
def training_function(args):
    # set seed for reproducibility
    set_seed(args.seed)

    # load the pre-tokenized dataset from disk
    dataset = load_from_disk(args.dataset_path)

    # load model from the hub
    model = AutoModelForCausalLM.from_pretrained(
        args.model_id,
        use_cache=False if args.gradient_checkpointing else True,  # caching is incompatible with gradient checkpointing
    )
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)

    # define training args
    output_dir = "/tmp"
    training_args = TrainingArguments(
        output_dir=output_dir,
        overwrite_output_dir=True,
        per_device_train_batch_size=args.per_device_train_batch_size,
        bf16=args.bf16,  # use bf16 if available
        learning_rate=args.lr,
        num_train_epochs=args.epochs,
        max_steps=args.max_steps if args.max_steps is not None else -1,  # -1 means "train for num_train_epochs"
        gradient_checkpointing=args.gradient_checkpointing,
        # logging strategies
        logging_dir=f"{output_dir}/logs",
        logging_strategy="steps",
        logging_steps=10,
        save_strategy="no",  # checkpoints are written manually via safe_save_model_for_hf_trainer
        optim=args.optimizer,
        ddp_timeout=7200,
        fsdp=args.fsdp,
        fsdp_transformer_layer_cls_to_wrap=args.fsdp_transformer_layer_cls_to_wrap,
    )

    # create Trainer instance
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset,
        data_collator=default_data_collator,
    )

    # start training
    trainer.train()
    print("Training done!")

    # save model and tokenizer to /opt/ml/model/ so SageMaker uploads them to S3
    safe_save_model_for_hf_trainer(trainer, tokenizer, "/opt/ml/model/")

    # wait for all ranks to finish before exiting
    dist.barrier()

def main():
    args, _ = parse_args()
    training_function(args)


if __name__ == "__main__":
    main()
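
# Example launch (illustrative; the flags and model below are assumptions, not part of this
# script). SageMaker's torch_distributed launcher effectively runs the script via torchrun,
# e.g. on a single node with 8 GPUs:
#
#   torchrun --nproc_per_node 8 run_clm.py \
#     --model_id EleutherAI/gpt-neox-20b \
#     --dataset_path lm_dataset \
#     --per_device_train_batch_size 1 \
#     --fsdp "full_shard auto_wrap" \
#     --fsdp_transformer_layer_cls_to_wrap GPTNeoXLayer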