in training/distributed_training/pytorch/model_parallel/gpt2/train_gpt_simple.py [0:0]
def main():
    args = parse_args()

    if args.shard_optimizer_state > 0 and not args.skip_full_optimizer:
        raise ValueError(
            "If shard_optimizer_state is enabled, skip_full_optimizer must also be enabled. "
            "Full optimizer saving is currently not supported under optimizer state sharding."
        )
    # any value here is overridden by the config set in the notebook when launching the SageMaker job
    smp_config = {
        "ddp": True,
        "tensor_parallel_degree": args.tensor_parallel_degree,
        "pipeline_parallel_degree": args.pipeline_parallel_degree,
        "microbatches": args.microbatches,
        # if activation_checkpointing is enabled, the transformer layers are checkpointed
        # below instead, so skip checkpointing attention here
        "checkpoint_attentions": not args.activation_checkpointing,
        "shard_optimizer_state": args.shard_optimizer_state > 0,
        "prescaled_batch": args.prescaled_batch > 0,
        "_match_weights": args.match_weights > 0,
        "fp16_params": args.fp16 > 0,
        "offload_activations": args.offload_activations > 0,
        "optimize": args.optimize,
        "auto_partition": not args.manual_partition,
        "default_partition": 0,
        "_fp32_grad_accumulation": args.fp32_grad_accumulation > 0,
        "static_mode": args.static_mode > 0,
        "fast_mode": args.fast_mode > 0,
    }

    if args.active_microbatches is not None:
        smp_config["active_microbatches"] = args.active_microbatches
    smp.init(smp_config)

    if smp.rank() == 0:
        print("Arguments:", args.__dict__)
        print(f"Transformers version: {transformers.__version__}")
        print(f"smdistributed.modelparallel version: {smdistributed.modelparallel.__version__}")
        print(f"smdistributed config: {smp_config}")

    if args.save_final_full_model and smp.rank() == 0:
        print(
            "[Warning] Note that save_final_full_model only saves the final model at the end of all steps. "
            "It does not save the optimizer state. The optimizer state is only saved with partial models, "
            "which are saved at the configured checkpointing_freq during training. If you want to restart "
            "training, you need partial checkpoints."
        )

    if smp.local_rank() == 0:
        for path in [args.model_dir, args.checkpoint_dir]:
            if not os.path.exists(path):
                os.makedirs(path, exist_ok=True)
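
    # Define the GPT-2 architecture directly from the command-line hyperparameters
    # (context width, hidden width, layer/head counts); the vocabulary size and
    # special-token ids are fixed to the standard GPT-2 values.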
    model_config = GPT2Config(
        vocab_size=50257,
        n_positions=args.max_context_width,
        n_ctx=args.max_context_width,
        n_embd=args.hidden_width,
        n_layer=args.num_layers,
        n_head=args.num_heads,
        n_inner=None,
        activation_function="gelu_new",
        resid_pdrop=args.resid_pdrop,
        embd_pdrop=args.embd_pdrop,
        attn_pdrop=args.attn_pdrop,
        layer_norm_epsilon=1e-05,
        initializer_range=0.02,
        summary_type="cls_index",
        summary_use_proj=True,
        summary_activation=None,
        summary_proj_to_labels=True,
        summary_first_dropout=args.summary_first_pdrop,
        # gradient_checkpointing=args.gradient_checkpointing > 0,
        use_cache=True,
        bos_token_id=50256,
        eos_token_id=50256,
        return_dict=True,
    )

    # The following improves start-up time by skipping proper initialization of weights
    # in the original model. This is not a problem because DistributedModel will override
    # those weights anyway when tensor_parallel_degree > 1.
    if smp.tp_size() > 1 and args.match_weights < 1:
        from transformers.modeling_utils import PreTrainedModel

        PreTrainedModel.init_weights = lambda x: None

    set_seed(args.seed)

    if args.fp16:
        torch.set_default_dtype(torch.float16)
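    # Construct the model under the SMP context managers below: tensor_parallelism() marks
    # supported modules for tensor-parallel distribution, and delay_param_initialization()
    # delays allocating the full parameters until the model is actually partitioned.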
    with smp.tensor_parallelism(
        enabled=smp.tp_size() > 1, attention_in_fp32=args.attention_in_fp32 > 0
    ):
        with smp.delay_param_initialization(
            enabled=(smp.tp_size() > 1 and args.match_weights < 1 and args.delayed_param > 0)
        ):
            model = AutoModelForCausalLM.from_config(model_config)
    torch.set_default_dtype(torch.float32)

    if args.fp16:
        model = FP16_Module(model)

    num_params = sum([np.prod(p.size()) for p in model.parameters()])
    if smp.rank() == 0:
        print(f"# total parameters: {num_params}")

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())
    device = torch.device("cuda")

    if not args.same_seed:
        # Set seed by tp_rank to prevent weights from being the same on different tp_ranks
        set_seed(args.seed + smp.tp_rank())

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    if args.fp16:
        torch.set_default_dtype(torch.float16)
    model = smp.DistributedModel(model, trace_device="gpu")

    if args.fp16:
        m = model.module
    else:
        m = model
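
    # Locate the container that holds the transformer blocks inside the wrapped model: with
    # tensor parallelism SMP exposes them as a sequential container (seq_layers); otherwise
    # the original Hugging Face block list (h) is used.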
    if smp.tp_size() > 1:
        transformer_layers = m.module.module.transformer.seq_layers
    else:
        transformer_layers = m.module.module.transformer.h

    if args.manual_partition:
        print("Manual partition enabled")
        # Evenly distribute the layers across all pipeline partitions; the last `rem`
        # partitions receive one extra layer each when the split is not exact.
        div, rem = divmod(args.num_layers, smp.pp_size())
        get_num_layers = lambda x: (div + 1 if x >= smp.pp_size() - rem else div)
        assignments = []
        for pp_rank in range(smp.pp_size()):
            nl = get_num_layers(pp_rank)
            print(f"{nl} layers assigned to partition {pp_rank}")
            assignments += [pp_rank for _ in range(nl)]

        for i, c in enumerate(transformer_layers.children()):
            smp.set_partition(c, assignments[i])

    torch.set_default_dtype(torch.float32)

    # Build parameter groups (weight decay and non-decay).
    iter_model = model
    while isinstance(iter_model, (DistributedDataParallel, FP16_Module)):
        iter_model = iter_model.module
    param_groups = get_param_groups_by_weight_decay(iter_model)

    if args.use_adamw > 0:
        optimizer = optim.AdamW(
            param_groups, betas=(args.beta1, args.beta2), lr=args.lr, weight_decay=args.weight_decay
        )
    else:
        optimizer = optim.Adam(
            param_groups, betas=(args.beta1, args.beta2), lr=args.lr, weight_decay=args.weight_decay
        )
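
    # Optionally checkpoint the transformer layers to trade recomputation for memory. When the
    # layers are an nn.Sequential (tensor-parallel case), SMP additionally needs the inputs
    # packed as a tuple and an explicit checkpointing strategy.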
    if args.activation_checkpointing:
        kwargs = {}
        if isinstance(transformer_layers, nn.Sequential):
            kwargs["pack_args_as_tuple"] = True
            kwargs["strategy"] = args.activation_strategy
        smp.set_activation_checkpointing(transformer_layers, **kwargs)
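
    # In fp16 mode, wrap the optimizer for dynamic loss scaling; the wrapper can also keep
    # gradients in fp32 and shard the optimizer state, mirroring the flags passed to smp.init above.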
    if args.fp16:
        optimizer = FP16_Optimizer(
            model,
            optimizer,
            static_loss_scale=None,
            dynamic_loss_scale=True,
            use_smp=True,
            dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2},
            params_have_main_grad=args.fp32_grad_accumulation > 0,
            shard_optimizer_state=args.shard_optimizer_state > 0,
        )

    optimizer = smp.DistributedOptimizer(optimizer)
    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    if args.fp16:
        model.register_post_step_hook(lambda model, optimizer: optimizer.init_master_params())

    # load after wrapping model and optimizer with smp Distributed...
    if args.load_full or args.load_partial:
        if args.load_partial and args.load_full:
            print(
                "Since both --load_partial and --load_full are set, will try to load from the full checkpoint. "
                "If the intention is to load from a partial checkpoint, please don't set --load_full."
            )
        partial = not args.load_full
        path = args.checkpoint_dir if partial else args.model_dir
        translate_from_hf = not partial
        model, optimizer, total_steps, start_train_path_index, start_batch_index = load_model_and_optimizer(
            path,
            model,
            optimizer,
            lr_scheduler,
            partial,
            args,
            translate_from_hf=translate_from_hf,
            seq_length=args.max_context_width,
            load_model=True,
            load_optimizer=args.load_partial > 0,
            num_params=num_params,
        )
        if args.save_or_verify_ckptsum:
            filename = "saved_sum" if args.load_full else "saved_partial_sum"
            load_and_verify_ckptsum(
                args, model, optimizer, filename=os.path.join(args.model_dir, filename)
            )
    else:
        total_steps = 0
        start_train_path_index = 0
        start_batch_index = 0
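
    # Run the training loop; it returns the final step count together with the measured
    # throughput and loss, which are reported (and optionally checked) below.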
    start = time.time()
    total_steps, throughput, loss = train(
        model,
        optimizer,
        lr_scheduler,
        model_config,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        args,
    )
    time_to_train = time.time() - start
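
    # In CI runs, emit the metrics in a parseable form and enforce regression thresholds,
    # but only for fresh runs that did not resume from a checkpoint.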
    if args.ci:
        print(f"[SMP_METRIC]__GPT2__Time_to_train__{time_to_train}")
        print(f"[SMP_METRIC]__GPT2__samples/second__{throughput}")
        print(f"[SMP_METRIC]__GPT2__Loss__{loss}")

        if not args.load_partial and not args.load_full:
            assert time_to_train < args.time_to_train
            assert throughput > args.throughput
            if args.loss:
                assert loss < args.loss

    if args.save_final_full_model:
        # saves full model at the end
        base_path = f"trained_gpt_nparams-{num_params}_steps-{total_steps}.pt"
        out_path = os.path.join(args.model_dir, base_path)
        if args.save_or_verify_ckptsum:
            # Save optimizer and model tensor sums and scalars before saving
            save_ckptsum(args, model, optimizer, filename=os.path.join(args.model_dir, "saved_sum"))
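
        # Only one rank per reduced-data-parallel group writes the final artifact; when tensor
        # parallelism was used, the weights are translated back to the Hugging Face layout.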
        if smp.rdp_rank() == 0:
            save(
                out_path,
                model,
                optimizer,
                lr_scheduler,
                model_config,
                num_params,
                total_steps,
                -1,
                args,
                partial=False,
                translate_to_hf=smp.tp_size() > 1,
                seq_length=args.max_context_width,
            )

    smp.barrier()
    if smp.rank() == 0:
        print("SMP training finished successfully")