in archived/smp-gpt-sharded-data-parallel/train.py [0:0]
def main(): # pylint: disable=too-many-branches,too-many-locals,too-many-statements
"""Main function to train GPT."""
args = parse_args()
if args.partition_assignment != "" and args.manual_partition == 0:
logging.warning("Partition_assignment is set, enable manual_partition.")
args.manual_partition = 1
# any value here is overridden by the config set in the notebook when launching the SageMaker job
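# The "> 0" comparisons below convert the integer 0/1 command-line flags into the booleans
# expected by smp.init().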
smp_config = {
"ddp": True,
"tensor_parallel_degree": args.tensor_parallel_degree,
"pipeline_parallel_degree": args.pipeline_parallel_degree,
"microbatches": args.microbatches,
"shard_optimizer_state": args.shard_optimizer_state > 0,
"prescaled_batch": args.prescaled_batch > 0,
"fp16": args.fp16 > 0,
"bf16": args.bf16 > 0,
"offload_activations": args.offload_activations > 0,
"delayed_parameter_initialization": args.delayed_param > 0,
"optimize": args.optimize,
"placement_strategy": args.placement_strategy,
"activation_loading_horizon": args.activation_loading_horizon,
"skip_tracing": args.skip_tracing > 0,
"auto_partition": not args.manual_partition,
"default_partition": 0,
"static_mode": args.static_mode > 0,
"fast_mode": args.fast_mode > 0,
"sharded_data_parallel_degree": args.sharded_data_parallel_degree,
"ddp_dist_backend": args.ddp_dist_backend,
"sdp_hierarchical_allgather": False,
"sdp_gradient_clipping": args.grad_clip,
}
if args.active_microbatches is not None:
smp_config["active_microbatches"] = args.active_microbatches
if args.log_param_norms and args.use_distributed_transformer == 1:
    logging.warning(
        "Script currently doesn't support logging param norms when using the "
        "distributed transformer; disabling log_param_norms."
    )
    args.log_param_norms = 0  # actually disable it, as the warning states
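# smp.init() brings up the smdistributed.modelparallel runtime with the config above;
# the smp.rank()/smp.pp_size()/smp.tp_size() calls used below rely on this initialization.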
smp.init(smp_config)
_show_env_vars(0)
if smp.rank() == 0:
logging.info("Arguments: %s", args.__dict__)
logging.info("Transformers version: %s", transformers.__version__)
logging.info(
"smdistributed.modelparallel version: %s", smdistributed.modelparallel.__version__
)
logging.info("smdistributed config: %s", smp_config)
if args.save_final_full_model and smp.rank() == 0:
logging.warning(
"Note that save_final_full_model only saves the final model at the end "
"of all steps. It does not save optimizer state. Optimizer state is only "
"saved with partial models which are saved at checkpointing_freq during "
"training. If you want to restart training you need partial checkpoints."
)
if args.partition_assignment != "":
    partition_assignment = args.partition_assignment.split(",")
    if len(partition_assignment) != smp.pp_size():
        msg = (
            f"partition_assignment must have the same size as the pipeline parallel degree, "
            f"but got {len(partition_assignment)} vs {smp.pp_size()}"
        )
        logging.fatal("Will fail with: %s.", msg)
        raise AssertionError(msg)
model_config, args = model_config_lib.get_model_config_from_args(
args.model_type, args.model_name, args, log=(smp.rank() == 0)
)
# The following improves start-up time by skipping proper initialization of weights
# in the original model. This is not a problem because DistributedModel overrides
# those weights anyway when the distributed transformer is used.
if args.use_distributed_transformer > 0:
from transformers.modeling_utils import ( # pylint: disable=import-error,import-outside-toplevel
PreTrainedModel,
)
PreTrainedModel.init_weights = lambda x: None
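# The monkey-patch above turns Hugging Face weight initialization into a no-op for every
# PreTrainedModel created in this process.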
set_seed(args.seed)
if args.enable_memory_profiling > 0:
memory_status_cpu(msg="before model creation")
if args.fp16 and args.bf16:
raise ValueError("FP16 and BF16 cannot be simultaneously enabled.")
if args.fp16:
dtype = torch.float16 # pylint: disable=no-member
elif args.bf16:
dtype = torch.bfloat16 # pylint: disable=no-member
else:
dtype = torch.get_default_dtype() # pylint: disable=no-member
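# Fine-tuning with delayed parameter initialization: rank 0 dumps the pretrained weights to
# <model_dir>/fullmodel.pt here so that all ranks can load them via smp.resume_from_checkpoint
# once the model and optimizer have been wrapped (see below).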
if args.fine_tune > 0 and args.delayed_param > 0 and smp.rank() == 0:
pretrained_model = AutoModelForCausalLM.from_pretrained(
args.model_name or args.model_dir
)
model_state_dict = pretrained_model.state_dict()
path = os.path.join(args.model_dir, "fullmodel.pt")
torch.save(model_state_dict, path)
smp.barrier()
# About zero_init:
# Zero-initialize weights only for the actual model used in training. In the
# distributed-transformer case the module is only used inside the DistributedModel
# wrapper, so zero init is not needed there. Zero init is required only for
# building param_id_to_offset.
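# smp.model_creation is a context manager controlling how the Hugging Face model below is
# constructed: tensor parallelism, parameter dtype, and the fused/optimized kernel options
# selected by the command-line flags.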
with smp.model_creation(
tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0,
zero_init=args.use_distributed_transformer == 0,
dtype=dtype,
distribute_embedding=args.sharded_data_parallel_degree > 1 and smp.tp_size() > 1,
use_alibi=args.alibi > 0,
attention_in_fp32=args.attention_in_fp32 > 0,
fp32_residual_addition=args.residual_addition_in_fp32 > 0,
query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1,
fused_softmax=args.fused_softmax > 0,
fused_dropout=args.fused_dropout > 0,
fused_bias_gelu=args.fused_bias_gelu > 0,
flash_attention=args.flash_attention > 0,
):
if args.fine_tune > 0 and args.delayed_param == 0:
model = AutoModelForCausalLM.from_pretrained(
args.model_name or args.model_dir
)
else:
model = AutoModelForCausalLM.from_config(model_config)
if args.enable_memory_profiling > 0:
memory_status_cpu(msg="after model creation")
# smdistributed: Set the device to the GPU ID used by the current process.
# Input tensors should be transferred to this device.
torch.cuda.set_device(smp.local_rank())
if not args.same_seed:
# Set seed by tp_rank to prevent weights from being the same on different tp_ranks
set_seed(args.seed + smp.tp_rank())
# smdistributed: Use the DistributedModel container to provide the model
# to be partitioned across different ranks. For the rest of the script,
# the returned DistributedModel object should be used in place of
# the model provided for DistributedModel class instantiation.
if args.enable_memory_profiling > 0:
memory_status_cpu(msg="before dist model creation")
model = smp.DistributedModel(
model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation
)
if args.enable_memory_profiling > 0:
memory_status_cpu(msg="after dist model creation")
m = model.get_module() # pylint: disable=invalid-name
num_params = compute_num_params(m)
if smp.rank() == 0:
logging.info("# total parameters: %s", num_params)
if args.use_distributed_transformer > 0:
transformer_layers = m.transformer.seq_layers
else:
if args.model_type in ["gpt2", "bloom"]:
transformer_layers = m.transformer.h
elif args.model_type == "gpt_neox":
transformer_layers = m.gpt_neox.layers
if args.manual_partition:
logging.debug("Manual partition enabled")
if args.partition_assignment != "":
    get_num_layers = lambda x: int(  # pylint: disable=unnecessary-lambda-assignment
        partition_assignment[x]
    )
    total_layers = sum(get_num_layers(pp_rank) for pp_rank in range(smp.pp_size()))
    if total_layers != args.num_layers:
        msg = (
            f"partition_assignment must have the same total transformer layers as the model, "
            f"but got {total_layers} vs {args.num_layers}"
        )
        logging.fatal("Will fail with: %s.", msg)
        raise AssertionError(msg)
else:
    # evenly distribute layers across all partitions
    div, rem = divmod(args.num_layers, smp.pp_size())
    get_num_layers = lambda x: (  # pylint: disable=unnecessary-lambda-assignment
        div + 1 if x >= smp.pp_size() - rem else div
    )
assignments = []
# TODO: this assignment order is required for the 175B model; otherwise training hangs
# with partition assignment "8,17,17,18,18,18". Needs further investigation.
# for pp_rank in reversed(range(smp.pp_size())):
for pp_rank in range(smp.pp_size()):
nl = get_num_layers(pp_rank) # pylint: disable=invalid-name
logging.debug("%s layers assigned to partition %d", nl, pp_rank)
assignments += [pp_rank for _ in range(nl)]
for i, c in enumerate(transformer_layers.children()): # pylint: disable=invalid-name
smp.set_partition(c, assignments[i])
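# Build optimizer parameter groups; get_param_groups_by_weight_decay presumably separates
# parameters that should and should not receive weight decay (e.g. biases and LayerNorm weights).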
param_groups = get_param_groups_by_weight_decay(m)
if args.use_adamw > 0:
optimizer = optim.AdamW(
param_groups, betas=(args.beta1, args.beta2), lr=args.lr, weight_decay=args.weight_decay
)
else:
optimizer = optim.Adam(
param_groups, betas=(args.beta1, args.beta2), lr=args.lr, weight_decay=args.weight_decay
)
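# Activation checkpointing: either checkpoint whole transformer layers (optionally with a
# strategy) or, with --checkpoint_sublayers, only their attention / MLP / LayerNorm
# submodules, depending on the model type.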
if args.activation_checkpointing: # pylint: disable=too-many-nested-blocks
if args.use_distributed_transformer or smp.tp_size() > 1:
if args.checkpoint_sublayers:
for c in transformer_layers.children(): # pylint: disable=invalid-name
smp.set_activation_checkpointing(c.attention)
smp.set_activation_checkpointing(c.output)
else:
smp.set_activation_checkpointing(
transformer_layers, strategy=args.activation_strategy
)
else:
for c in transformer_layers.children(): # pylint: disable=invalid-name
if args.checkpoint_sublayers:
if args.model_type == "gpt2":
smp.set_activation_checkpointing(c.attn)
smp.set_activation_checkpointing(c.mlp)
elif args.model_type in ["gpt_neox", "bloom"]:
if args.model_type == "gpt_neox":
smp.set_activation_checkpointing(c.attention)
elif args.model_type == "bloom":
smp.set_activation_checkpointing(c.self_attention)
smp.set_activation_checkpointing(c.input_layernorm)
smp.set_activation_checkpointing(c.post_attention_layernorm)
smp.set_activation_checkpointing(c.mlp)
else:
smp.set_activation_checkpointing(c)
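# For sharded data parallelism without the distributed transformer, record each parameter's
# offset now; after the optimizer is wrapped, build_param_id_to_buffer turns these offsets
# into the param_id_to_buffer mapping that is passed to train().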
if args.sharded_data_parallel_degree > 1 and args.use_distributed_transformer == 0:
param_id_to_offset = build_param_id_to_offset(param_groups)
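# smp.DistributedOptimizer wraps the optimizer for the partitioned model; loss scaling is
# configured as dynamic with a 1000-step scale window (loss scaling only matters for fp16 runs).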
optimizer = smp.DistributedOptimizer(
optimizer,
static_loss_scale=None,
dynamic_loss_scale=True,
dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2},
)
if args.fine_tune > 0 and args.delayed_param > 0:
smp.resume_from_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False)
if args.sharded_data_parallel_degree > 1 and args.use_distributed_transformer == 0:
param_id_to_buffer = build_param_id_to_buffer(optimizer, param_id_to_offset)
else:
param_id_to_buffer = None
lr_scheduler = get_learning_rate_scheduler(optimizer, args)
if args.enable_memory_profiling > 0:
model.register_post_partition_hook(
lambda model, optimizer: memory_status(msg="After partition")
)
# Checkpoints must be loaded after the model and optimizer are wrapped with
# smp.DistributedModel / smp.DistributedOptimizer.
if args.load_full or args.load_partial:
if args.load_partial and args.load_full:
logging.info(
    "Since both --load_partial and --load_full are set, loading from the full "
    "checkpoint. If the intention is to load from a partial checkpoint, do not "
    "set --load_full."
)
partial = not args.load_full
path = args.checkpoint_dir if partial else args.model_dir
tag = None if partial else "fullmodel.pt"
user_content = smp.resume_from_checkpoint(path, tag=tag, partial=partial)
total_steps = user_content["total_steps"] if partial else 0
start_train_path_index = user_content.get("start_train_path_index", 0)
start_batch_index = user_content.get("start_batch_index", 0)
if "lr_scheduler" in user_content:
lr_scheduler.load_state_dict(user_content["lr_scheduler"])
else:
total_steps = 0
start_train_path_index = 0
start_batch_index = 0
# Empty the CUDA cache to clear memory after loading from a partial checkpoint
# (needed for SDPTP and GPT NeoX).
torch.cuda.empty_cache()
start = time.time()
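# train() returns the cumulative step count, the throughput in samples/second, and the final
# loss; these feed the CI metric logging and threshold checks below.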
total_steps, throughput, loss = train(
model,
optimizer,
lr_scheduler,
model_config,
start_train_path_index,
start_batch_index,
num_params,
total_steps,
args,
param_id_to_buffer,
)
time_to_train = time.time() - start
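# In CI mode, emit [SMP_METRIC] lines and, for runs that start from scratch, fail if
# time-to-train, throughput, or loss miss their configured thresholds.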
if args.ci:
logging.info("[SMP_METRIC]__GPT2__Time_to_train__%s", time_to_train)
logging.info("[SMP_METRIC]__GPT2__samples/second__%s", throughput)
logging.info("[SMP_METRIC]__GPT2__Loss__%s", loss)
if not args.load_partial and not args.load_full:
if time_to_train >= args.time_to_train:
msg = f"Time to train ({time_to_train}) >= threshold ({args.time_to_train})"
logging.fatal("Will fail with: %s.", msg)
raise AssertionError(msg)
if throughput <= args.throughput:
msg = f"Throughput ({throughput}) >= threshold ({args.throughput})"
logging.fatal("Will fail with: %s.", msg)
raise AssertionError(msg)
if args.loss and loss >= args.loss:
msg = f"Loss ({loss}) >= threshold ({args.loss})"
logging.fatal("Will fail with: %s.", msg)
raise AssertionError(msg)
if args.save_final_full_model:
# saves full model at the end
user_content = {
"cli_args": args.__dict__,
"num_params": num_params,
"total_steps": total_steps,
"model_config": model_config,
}
smp.save_checkpoint(
args.model_dir,
tag="fullmodel.pt",
partial=False,
model=model,
user_content=user_content,
)
smp.barrier()
if smp.rank() == 0:
logging.info("SMP training finished successfully")