# archived/fully_sharded_data_parallel-falcon/scripts/train.py
def training_function(args):
    # set seed for reproducibility
    set_seed(args.seed)

    # load the preprocessed dataset from disk
    dataset = load_from_disk(args.dataset_path)

    # build the Falcon model from an explicit config (weights are randomly initialized, not loaded from the Hub)
    config = FalconConfig(
        vocab_size=65024,
        use_cache=True,
        parallel_attn=True,
        num_hidden_layers=32,
        num_attention_heads=71,
        new_decoder_architecture=False,
        multi_query=True,
        layer_norm_epsilon=1e-05,
        initializer_range=0.02,
        hidden_size=4544,
        hidden_dropout=0.0,
        eos_token_id=11,
        bos_token_id=11,
        bias=False,
    )
    model = AutoModelForCausalLM.from_config(config)
    tokenizer = AutoTokenizer.from_pretrained(args.model_id)

    train_dataset = dataset["train"]
    eval_dataset = dataset["validation"]
    train_dataloader, eval_dataloader = create_dataloaders(
        train_dataset,
        eval_dataset,
        args.rank,
        args.world_size,
        args.seed,
        args.per_device_train_batch_size,
        args.per_device_train_batch_size,
    )
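    # FSDP auto-wrap policy: wrap each FalconDecoderLayer in its own FSDP unit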
    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            FalconDecoderLayer,
        },
    )

    torch.cuda.set_device(args.local_rank)
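    # keep parameters, gradient reduction, and buffers in bfloat16 under FSDP mixed precision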
    dtype = torch.bfloat16
    mixed_precision_policy = MixedPrecision(param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=dtype)

    model = FSDP(
        model,
        auto_wrap_policy=auto_wrap_policy,
        mixed_precision=mixed_precision_policy,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
        forward_prefetch=args.forward_prefetch,
        limit_all_gathers=args.limit_all_gathers,
        device_id=torch.cuda.current_device(),
    )
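    # apply non-reentrant activation checkpointing to every decoder layer, offloading activations to CPU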
    non_reentrant_wrapper = functools.partial(
        checkpoint_wrapper,
        offload_to_cpu=True,
        checkpoint_impl=CheckpointImpl.NO_REENTRANT,
    )
    check_fn_gpt = lambda submodule: isinstance(submodule, FalconDecoderLayer)
    apply_activation_checkpointing(model, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn_gpt)
    # Optimizer
    # Split the weights into two groups: one with weight decay and the other without.
    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = torch.optim.AdamW(optimizer_grouped_parameters, lr=args.learning_rate)
    # Scheduler and math around the number of training steps.
    overrode_max_train_steps = False
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
    if args.rank == 0:
        print(f"Number of update steps per epoch: {num_update_steps_per_epoch}")
    if args.max_train_steps is None:
        args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
        overrode_max_train_steps = True
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
    )

    start = time.time()
    device = torch.device(f"cuda:{args.local_rank}")
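    # training loop: one pass over the training data per epoch, followed by evaluation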
    for epoch in range(args.num_train_epochs):
        model.train()
        total_steps = 0
        # fsdp_loss[0] accumulates the summed batch loss on this rank, fsdp_loss[1] the number of samples
        fsdp_loss = torch.zeros(2).to(args.local_rank)
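        # forward/backward pass; the optimizer steps after every batch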
        for _, batch in enumerate(tqdm(train_dataloader, disable=not (args.rank == 0))):
            batch = {k: v.to(device) for k, v in batch.items()}
            output = model(**batch)
            loss = output["loss"]
            loss.backward()
            fsdp_loss[0] += loss.item()
            fsdp_loss[1] += len(batch["input_ids"])

            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            total_steps += 1
            if args.max_steps is not None and total_steps > args.max_steps:
                break
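        # sum the per-rank loss totals and sample counts across all ranks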
        torch.distributed.all_reduce(fsdp_loss, op=torch.distributed.ReduceOp.SUM)
        train_loss = fsdp_loss[0] / fsdp_loss[1]
        train_ppl = torch.exp(train_loss)
        if args.rank == 0:
            print(f"******{epoch=}: {train_ppl=} {train_loss=}******")
        model.eval()
        eval_loss = 0
        fsdp_eval_loss = torch.zeros(2).to(args.local_rank)
        for steps, batch in enumerate(tqdm(eval_dataloader, disable=not (args.rank == 0))):
            batch = {k: v.to(device) for k, v in batch.items()}
            with torch.no_grad():
                outputs = model(**batch)
                loss = outputs["loss"]
            fsdp_eval_loss[0] += loss.item()
            fsdp_eval_loss[1] += len(batch["input_ids"])
            if args.max_steps is not None and steps > args.max_steps:
                break
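        # aggregate the evaluation statistics across ranks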
        torch.distributed.all_reduce(fsdp_eval_loss, op=torch.distributed.ReduceOp.SUM)
        eval_loss = fsdp_eval_loss[0] / fsdp_eval_loss[1]
        eval_ppl = torch.exp(eval_loss)
        if args.rank == 0:
            print(f"*******{epoch=}: {eval_ppl=} {eval_loss=}*******")

        if args.max_steps is not None and total_steps > args.max_steps:
            break
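    # save_model (defined elsewhere in this script) is expected to gather the sharded weights and write the model and tokenizer to args.model_dir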
    save_model(model, tokenizer, args.model_dir, args.rank)
    if args.rank == 0:
        print("Training done!")
    dist.barrier()