# build_and_train_models/sm-model_trainer/distributed-training/scripts/run_clm_no_trainer.py
def training_function(args):
    # set seed
    set_seed(args['seed'])
    from huggingface_hub.hf_api import HfFolder

    print(f"Loading dataset from {args['dataset_path']}")
    dataset = load_from_disk(f"file://{args['dataset_path']}")
    dist.barrier()
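    # Every rank loads the dataset from disk; the barrier above keeps the ranks in
    # sync before the model is loaded.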
    # load model from the hub
    model = AutoModelForCausalLM.from_pretrained(
        args['model_id'],
        cache_dir=args['cache_dir'],
        # The KV cache is incompatible with gradient checkpointing, so disable it
        # whenever checkpointing is enabled.
        use_cache=not args['gradient_checkpointing'],
    )
    tokenizer = AutoTokenizer.from_pretrained(args['model_id'])

    train_dataset = dataset['train']
    eval_dataset = dataset['validation']

    # Create dataloaders for training and evaluation
    # (the train batch size is reused for evaluation here).
    train_dataloader, eval_dataloader = create_dataloaders(
        train_dataset,
        eval_dataset,
        args['rank'],
        args['world_size'],
        args['seed'],
        args['per_device_train_batch_size'],
        args['per_device_train_batch_size'],
    )
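    # Auto-wrap policy: wrap every transformer block of the given class in its own
    # FSDP unit, so parameters are sharded and gathered per block rather than for
    # the whole model at once.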
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            get_module_class_from_name(model, args['fsdp_transformer_layer_cls_to_wrap'])
        },
    )

    torch.cuda.set_device(args['local_rank'])
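    # bf16 mixed precision: keep sharded parameters, gradient reduction, and
    # buffers in bfloat16 to cut memory use and communication volume.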
    dtype = torch.bfloat16
    mixed_precision_policy = MixedPrecision(
        param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype
    )

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mixed_precision_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,  # FULL_SHARD ~ ZeRO-3; use SHARD_GRAD_OP for ZeRO-2
        cpu_offload=CPUOffload(offload_params=True),
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,  # alternative: BACKWARD_POST
        forward_prefetch=args['forward_prefetch'],
        limit_all_gathers=args['limit_all_gathers'],
        device_id=torch.cuda.current_device(),
    )
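    # Activation (gradient) checkpointing: re-wrap each transformer block with a
    # non-reentrant checkpoint wrapper that also offloads saved activations to CPU.
    # It is applied after FSDP wrapping so the two compose.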
    non_reentrant_wrapper = functools.partial(
        checkpoint_wrapper, offload_to_cpu=True, checkpoint_impl=CheckpointImpl.NO_REENTRANT
    )
    check_fn_gpt = lambda submodule: isinstance(
        submodule, get_module_class_from_name(model, args['fsdp_transformer_layer_cls_to_wrap'])
    )
    apply_activation_checkpointing(
        model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn_gpt
    )
    # Optimizer
    # Split weights in two groups, one with weight decay and the other not.
    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)
            ],
            "weight_decay": args['weight_decay'],
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args['learning_rate'])
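    # The optimizer is created after FSDP wrapping, so it holds references to the
    # sharded (flattened) parameters owned by this rank.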
    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args['gradient_accumulation_steps']
    )
    if args['rank'] == 0:
        print(f"Number of update steps per epoch {num_update_steps_per_epoch}")
    if args['max_train_steps'] is None:
        args['max_train_steps'] = args['num_train_epochs'] * num_update_steps_per_epoch
        overrode_max_train_steps = True

    lr_scheduler = get_scheduler(
        name=args['lr_scheduler_type'],
        optimizer=optimizer,
        num_warmup_steps=args['num_warmup_steps'] * args['gradient_accumulation_steps'],
        num_training_steps=args['max_train_steps'] * args['gradient_accumulation_steps'],
    )
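    # Warmup and total steps are expressed in micro-batches (hence the multiplication
    # by gradient_accumulation_steps) because lr_scheduler.step() is called once per
    # batch in the training loop below.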
    start = time.time()
    device = torch.device(f"cuda:{args['local_rank']}")

    # Training loop over num_train_epochs epochs
    for epoch in range(args['num_train_epochs']):
        model.train()
        total_steps = 0
        fsdp_loss = torch.zeros(2).to(args['local_rank'])
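        # fsdp_loss[0] accumulates the per-batch loss values and fsdp_loss[1] the
        # number of sequences seen on this rank; both are all-reduced after the epoch.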
        # Use train_dataloader to get the batch data
        for _, batch in enumerate(tqdm(train_dataloader, disable=args['rank'] != 0)):
            batch = {k: v.to(device) for k, v in batch.items()}
            output = model(**batch)
            loss = output["loss"]
            loss.backward()
            fsdp_loss[0] += loss.item()
            fsdp_loss[1] += len(batch["input_ids"])

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            total_steps += 1
            if total_steps > args['max_steps']:
                break
        # Reduce the loss across all processes
        torch.distributed.all_reduce(fsdp_loss, op=torch.distributed.ReduceOp.SUM)
        train_loss = fsdp_loss[0] / fsdp_loss[1]
        train_ppl = torch.exp(train_loss)
        if args['rank'] == 0:
            print(f"******{epoch=}: {train_ppl=} {train_loss=}******")
        model.eval()
        eval_loss = 0
        fsdp_eval_loss = torch.zeros(2).to(args['local_rank'])
        for steps, batch in enumerate(tqdm(eval_dataloader, disable=args['rank'] != 0)):
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = model(**batch)
                loss = outputs["loss"]
            fsdp_eval_loss[0] += loss.item()
            fsdp_eval_loss[1] += len(batch["input_ids"])
            if steps > args['max_steps']:
                break
        torch.distributed.all_reduce(fsdp_eval_loss, op=torch.distributed.ReduceOp.SUM)
        eval_loss = fsdp_eval_loss[0] / fsdp_eval_loss[1]
        eval_ppl = torch.exp(eval_loss)
        if args['rank'] == 0:
            print(f"*******{epoch=}: {eval_ppl=} {eval_loss=}*******")
    save_model(model, tokenizer, args['model_dir'], args['rank'])
    if args['rank'] == 0:
        print("Training done!")
    dist.barrier()