in build_and_train_models/sm-distributed_model_parallel_v2/shared-scripts/train_lib.py [0:0]
def main(args):
"""Main function to train GPT."""
global_start_time = time.time()
# Sanity check for args.
# - Checkpoints.
ckpt_lens = (
len(args.checkpoint_dir),
len(args.checkpoint_freq),
len(args.num_kept_checkpoints),
)
if len(set(ckpt_lens)) != 1:
raise ValueError(f"Len mismtach for checkpoint dir, freq vs num to keep: {ckpt_lens}.")
if args.distributed_backend == "smddp":
import smdistributed.dataparallel.torch.torch_smddp # pylint: disable=unused-import
dist.init_process_group(args.distributed_backend, timeout=datetime.timedelta(seconds=7200))
global_rank = dist.get_rank()
device = global_rank % torch.cuda.device_count()
world_size = dist.get_world_size()
# Reset all SMP related args if use_smp_implementation=0
if args.use_smp_implementation == 0:
tsm.state.tensor_parallel_degree = 1
tsm.state.expert_parallel_degree = 1
tsm.state.context_parallel_degree = 1
args.moe = 0
args.fp8 = 0
print_dict = {
"tensor_parallel_degree": tsm.state.tensor_parallel_degree,
"expert_parallel_degree": tsm.state.expert_parallel_degree,
"context_parallel_degree": tsm.state.context_parallel_degree,
"moe": args.moe,
"fp8": args.fp8,
}
if global_rank == 0:
logger.warn(f"use_smp_implementation is set to 0. Resetting these params to default values: {print_dict}")
if args.tensorboard_dir and global_rank == 0:
from torch.utils.tensorboard import SummaryWriter
logger.info("Writing metrics for tensorboard to %s.", args.tensorboard_dir)
writers = tuple(SummaryWriter(log_dir=tb_dir) for tb_dir in args.tensorboard_dir)
table_str = create_args_table(args.__dict__)
for writer in writers:
writer.add_text("Arguments", table_str)
else:
writers = ()
if args.nccl_test_log:
report = utils.get_nccl_test_report(utils.parse_nccl_test_log(args.nccl_test_log))
if report is not None and global_rank == 0:
write_nccl_test_stats(writers, report)
tsm.init()
if args.use_smp_implementation:
# Our memory-usage fix for Transformer Engine (TE) requires use_orig_params to be enabled
args.use_orig_params = 1
if args.use_synthetic_data and args.validation_freq is not None:
# Disable validation when training on synthetic data
args.validation_freq = None
show_env_vars(0)
if global_rank == 0:
for index, (key, value) in enumerate(sorted(args.__dict__.items()), 1):
logger.info("Arguments [%03d/%03d] %-30s: %s", index, len(args.__dict__), key, value)
logger.info("Transformers version: %s", transformers.__version__)
logger.info("World size = %d: # nodes = %d.", world_size, world_size / 8)
gbs = (
world_size
* args.max_context_width
* args.train_batch_size
/ tsm.state.tensor_parallel_degree
/ tsm.state.context_parallel_degree
)
logger.info("Global batch size in tokens: %10d (%5.2fM).", gbs, gbs / 1024 ** 2)
set_seed(args.seed)
if args.enable_memory_profiling > 0:
memory_status_cpu(tag="Before model creation", writers=writers)
if args.bf16:
dtype = torch.bfloat16
else:
dtype = torch.get_default_dtype()
if finetune_check(args):
from transformers import AutoConfig
# In finetune mode, build the model config from the pretrained model; otherwise build it from args
model_config = AutoConfig.from_pretrained(args.hf_pretrained_model_name_or_dir)
# Disable KV cache for HF models
if hasattr(model_config, "use_cache"):
model_config.use_cache = False
else:
model_config = get_model_config(args)
delayed_param_initer = None
with tsm_utils.timeit(True, "Model creation", global_rank):
if args.delayed_param:
model_config.delayed_param = True
if finetune_with_pretrained_weights_check(args) and dist.get_rank() == 0:
# Create the model with pretrained weights on rank 0 only; even with delayed param
# enabled, parameter init on the other ranks remains delayed.
model = create_model(
args,
model_config=model_config,
dtype=dtype,
pretrained_model_weights=args.hf_pretrained_model_name_or_dir
if finetune_with_pretrained_weights_check(args)
else None,
)
num_params = compute_num_params(model)
else:
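# Build the model on the meta device (no real memory allocated); the FSDP
# param_init_fn set up below materializes these parameters later.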
with init_empty_weights():
model = create_model(
args,
model_config=model_config,
dtype=dtype,
)
num_params = compute_num_params(model)
if finetune_check(args):
dist.barrier()
else:
model_config.delayed_param = False
model = create_model(
args,
model_config=model_config,
dtype=dtype,
pretrained_model_weights=args.hf_pretrained_model_name_or_dir
if finetune_with_pretrained_weights_check(args) and dist.get_rank() == 0
else None,
)
num_params = compute_num_params(model)
if args.use_smp_implementation:
if args.moe:
from torch.sagemaker.moe.moe_config import MoEConfig
moe_config = MoEConfig(
smp_moe=args.use_smp_implementation > 0,
moe_load_balancing=args.moe_load_balancing,
global_token_shuffle=args.global_token_shuffle > 0,
moe_all_to_all_dispatcher=args.moe_all_to_all_dispatcher > 0,
use_cpu_initialization=finetune_with_pretrained_weights_check(args) and dist.get_rank() == 0
)
else:
moe_config = None
load_state_dict_from_rank0 = finetune_with_pretrained_weights_check(args)
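# transform() converts the model to the SMP (tensor/expert/context parallel) implementation.
# When finetuning, only rank 0 holds the real pretrained weights (load_state_dict_from_rank0),
# so the remaining ranks can run the transform on empty (meta-device) weights.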
if args.moe and args.delayed_param and (not load_state_dict_from_rank0 or dist.get_rank() != 0):
with init_empty_weights():
model = transform(model, config=moe_config, load_state_dict_from_rank0=load_state_dict_from_rank0, cp_comm_type=args.cp_comm_type)
else:
model = transform(model, config=moe_config, load_state_dict_from_rank0=load_state_dict_from_rank0, cp_comm_type=args.cp_comm_type)
if args.delayed_param:
# Build the initializer for delayed parameter creation; when finetuning from pretrained
# weights, rank 0 already holds real parameters and is skipped.
if finetune_with_pretrained_weights_check(args):
if dist.get_rank() != 0:
delayed_param_initer = DelayedParamIniter(model)
else:
delayed_param_initer = DelayedParamIniter(model)
assert set(x.dtype for x in model.parameters()) == set(
[torch.float32]
), "Model parameters should be in fp32 for FSDP mixed precision"
if global_rank == 0:
logger.info(
"Created model with total parameters: %d (%.2f B)", num_params, num_params * 1e-9
)
transformer_layer = get_transformer_layer(args.model_type, args.use_smp_implementation, args.moe)
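# Choose how FSDP groups modules into flat-parameter units: wrap each transformer block
# (transformer_auto_wrap_policy) or group modules by parameter count (size_based_auto_wrap_policy).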
if args.auto_wrap_policy == "transformer_auto_wrap_policy":
gpt_auto_wrap_policy = functools.partial(
transformer_auto_wrap_policy,
transformer_layer_cls={
transformer_layer,
},
)
elif args.auto_wrap_policy == "size_based_auto_wrap_policy":
gpt_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy)
torch.cuda.set_device(device)
if args.bf16:
# Keep buffers in fp32: some HF models (e.g. Llama) hard-code buffers to fp32,
# so the SMP implementation matches that behavior.
buffer_dtype = torch.float32 if args.use_smp_implementation else dtype
mixed_precision_policy = MixedPrecision(
param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=buffer_dtype
)
else:
mixed_precision_policy = None
if args.enable_memory_profiling > 0:
memory_status_cpu(tag="Before FSDP wrapper", writers=writers)
sharding_strategy = get_sharding_strategy(args.sharding_strategy)
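# Wrap the model with FSDP. param_init_fn/post_param_init_fn materialize delayed
# (meta-device) parameters shard by shard, and sync_module_states broadcasts rank 0's
# pretrained weights to the other ranks when finetuning.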
with (
delayed_param_initer.validate_params_and_buffers_inited()
if (delayed_param_initer and not finetune_with_pretrained_weights_check(args))
else nullcontext(),
tsm_utils.timeit(True, "FSDP constructor", global_rank),
):
model = FSDP( # pylint: disable=unexpected-keyword-arg
model,
auto_wrap_policy=gpt_auto_wrap_policy,
mixed_precision=mixed_precision_policy,
sharding_strategy=sharding_strategy,
backward_prefetch=get_backward_fetch_policy(args.backward_fetch_policy),
forward_prefetch=args.forward_prefetch,
limit_all_gathers=args.limit_all_gathers,
device_id=torch.cuda.current_device(),
use_orig_params=args.use_orig_params > 0,
param_init_fn=delayed_param_initer.get_param_init_fn()
if delayed_param_initer
else None,
post_param_init_fn=delayed_param_initer.get_post_param_init_fn()
if delayed_param_initer
else None,
sync_module_states=finetune_with_pretrained_weights_check(args),
)
# Barrier is a workaround to reduce extra memory usage with SMDDP backend
# after the broadcast that happens when we use sync_module_states
# This can be removed once the SMDDP issue is fixed
dist.barrier()
if global_rank == 0:
logger.info("Wrapped model with FSDP")
if args.enable_memory_profiling > 0:
memory_status(tag="After FSDP wrapper", writers=writers)
fp8_recipe = None
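# FP8 training uses Transformer Engine's delayed-scaling recipe: the HYBRID format keeps
# E4M3 for the forward pass and E5M2 for gradients in the backward pass, with scaling
# factors derived from a rolling amax history.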
if args.fp8 == 1 and args.use_smp_implementation == 1:
fp8_recipe = DelayedScaling(
fp8_format=Format.HYBRID,
amax_history_len=args.fp8_amax_history_len,
amax_compute_algo=args.fp8_amax_compute_algo,
)
if args.activation_checkpointing > 0:
apply_activation_checkpoint(args, model=model)
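# Optionally offload checkpointed activations to CPU between forward and backward to
# further reduce GPU memory (torch's offload_wrapper).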
if tsm.state.sm_activation_offloading > 0:
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper
model = offload_wrapper(model)
# Patch RoPE buffers for GPT-NeoX: they are created on the host, so move them to the device
if args.use_smp_implementation == 0 and args.model_type == "gpt_neox" and args.patch_neox_rope > 0:
patch_neox_rope(model)
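# Build AdamW parameter groups; get_param_groups_by_weight_decay presumably separates
# parameters that should not receive weight decay (e.g. biases and norm weights).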
param_groups = get_param_groups_by_weight_decay(model)
optimizer = optim.AdamW(
param_groups, betas=(args.beta1, args.beta2), lr=args.lr, weight_decay=args.weight_decay
)
if global_rank == 0:
logger.info("Created optimizer")
lr_scheduler = get_learning_rate_scheduler(optimizer, args)
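# Checkpointing metadata: the FSDP process group, its coordinator rank, and whether this
# rank takes part in checkpoint I/O (is_action_rank).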
checkpointing_pg_metadata = (
model.process_group,
get_coordinator_rank(model.process_group),
is_action_rank(global_rank),
)
if args.resume_from_checkpoint:
(
model,
optimizer,
lr_scheduler,
epoch,
total_steps,
start_train_path_index,
resume_from_sequence_number,
val_resume_from_sequence_number,
) = load_checkpoint(
args,
model,
optimizer,
lr_scheduler,
args.resume_from_checkpoint,
sharding_strategy,
checkpointing_pg_metadata,
tensor_parallel_degree=int(tsm.state.tensor_parallel_degree),
expert_parallel_degree=int(tsm.state.expert_parallel_degree),
checkpoint_type=args.checkpoint_type,
)
torch.cuda.empty_cache()
else:
total_steps = 0
epoch = 0
start_train_path_index = 0
resume_from_sequence_number = 0
val_resume_from_sequence_number = 0
train_start_time = time.time()
# train() runs the training loop and returns the last completed global step
total_steps = train(
model,
optimizer,
lr_scheduler,
writers,
model_config,
epoch,
start_train_path_index,
resume_from_sequence_number,
val_resume_from_sequence_number,
num_params,
total_steps,
args,
global_rank,
world_size,
checkpointing_pg_metadata,
fp8_recipe,
)
time_now = time.time()
total_sec = time_now - global_start_time
train_sec = time_now - train_start_time
dist.barrier()
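# Optionally export a consolidated (FULL) checkpoint of the final model, either to
# model_dir or, if unset, to the first checkpoint directory.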
if args.save_final_model:
save_checkpoint(
model,
None,
None,
{"model_config": model_config},
None,
args.model_dir if args.model_dir is not None else args.checkpoint_dir[0],
"" if args.model_dir is not None else "model",
1,
None,
int(tsm.state.tensor_parallel_degree),
int(tsm.state.expert_parallel_degree),
checkpoint_type=CheckpointingMethod.FULL,
)
if global_rank == 0:
train_min = train_sec / 60.0
total_min = total_sec / 60.0
for writer in writers:
runtime = {
"total": total_min,
"train": train_min,
}
writer.add_scalars("Perf/runtime", runtime, total_steps - 1)
logger.info(
"FSDP training finished successfully %fs (%fmin) out of (%fmin).",
train_sec, train_min, total_min
)
dist.destroy_process_group()