def main()

in build_and_train_models/sm-distributed_model_parallel_v2/shared-scripts/train_lib.py [0:0]


def main(args):
    """Main function to train GPT."""
    global_start_time = time.time()

    # Sanity check for args.
    # - Checkpoints.
    ckpt_lens = (
        len(args.checkpoint_dir),
        len(args.checkpoint_freq),
        len(args.num_kept_checkpoints),
    )
    if len(set(ckpt_lens)) != 1:
        raise ValueError(f"Len mismtach for checkpoint dir, freq vs num to keep:  {ckpt_lens}.")

    if args.distributed_backend == "smddp":
        import smdistributed.dataparallel.torch.torch_smddp  # pylint: disable=unused-import

    dist.init_process_group(args.distributed_backend, timeout=datetime.timedelta(seconds=7200))
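    # One process per GPU: map this process to a local CUDA device via its global rank.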
    global_rank = dist.get_rank()
    device = global_rank % torch.cuda.device_count()
    world_size = dist.get_world_size()
    # Reset all SMP related args if use_smp_implementation=0
    if args.use_smp_implementation == 0:
        tsm.state.tensor_parallel_degree = 1
        tsm.state.expert_parallel_degree = 1
        tsm.state.context_parallel_degree = 1
        args.moe = 0
        args.fp8 = 0
        print_dict = {
            "tensor_parallel_degree": tsm.state.tensor_parallel_degree,
            "expert_parallel_degree": tsm.state.expert_parallel_degree,
            "context_parallel_degree": tsm.state.context_parallel_degree,
            "moe": args.moe,
            "fp8": args.fp8,
        }
        if global_rank == 0:
            logger.warn(f"use_smp_implementation is set to 0. Resetting these params to default values: {print_dict}")

    if args.tensorboard_dir and global_rank == 0:
        from torch.utils.tensorboard import SummaryWriter

        logger.info("Writing metrics for tensorboard to %s.", args.tensorboard_dir)
        writers = tuple(SummaryWriter(log_dir=tb_dir) for tb_dir in args.tensorboard_dir)
        table_str = create_args_table(args.__dict__)
        for writer in writers:
            writer.add_text("Arguments", table_str)
    else:
        writers = ()

    if args.nccl_test_log:
        report = utils.get_nccl_test_report(utils.parse_nccl_test_log(args.nccl_test_log))
        if report is not None and global_rank == 0:
            write_nccl_test_stats(writers, report)

    tsm.init()

    if args.use_smp_implementation:
        # The Transformer Engine (TE) memory-usage fix requires use_orig_params to be enabled.
        args.use_orig_params = 1

    if args.use_synthetic_data and args.validation_freq is not None:
        # Override validation_freq to None since synthetic data has no validation split.
        args.validation_freq = None

    show_env_vars(0)

    if global_rank == 0:
        for index, (key, value) in enumerate(sorted(args.__dict__.items()), 1):
            logger.info("Arguments [%03d/%03d] %-30s: %s", index, len(args.__dict__), key, value)
        logger.info("Transformers version: %s", transformers.__version__)
        logger.info("World size = %d: # nodes = %d.", world_size, world_size / 8)

        gbs = (
            world_size
            * args.max_context_width
            * args.train_batch_size
            / tsm.state.tensor_parallel_degree
            / tsm.state.context_parallel_degree
        )
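        # Tokens per optimizer step: world_size / (tp_degree * cp_degree) data-parallel
        # ranks, each consuming train_batch_size sequences of max_context_width tokens.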
        logger.info("Global batch size in tokens: %10d (%5.2fM).", gbs, gbs / 1024 ** 2)

    set_seed(args.seed)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(tag="Before model creation", writers=writers)

    if args.bf16:
        dtype = torch.bfloat16
    else:
        dtype = torch.get_default_dtype()

    if finetune_check(args):
        from transformers import AutoConfig

        # Finetune mode builds the model from the pretrained HF config; otherwise the config is constructed from args.
        model_config = AutoConfig.from_pretrained(args.hf_pretrained_model_name_or_dir)
        # Disable KV cache for HF models
        if hasattr(model_config, "use_cache"):
            model_config.use_cache = False
    else:
        model_config = get_model_config(args)

    delayed_param_initer = None
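    # With delayed_param, parameters are created on the meta device and materialized later
    # (inside the FSDP constructor via param_init_fn); when finetuning from pretrained
    # weights, only rank 0 builds the model with real weights.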
    with tsm_utils.timeit(True, "Model creation", global_rank):
        if args.delayed_param:
            model_config.delayed_param = True
            if finetune_with_pretrained_weights_check(args) and dist.get_rank() == 0:
                # Create the model with pretrained weights on rank 0 even when delayed param
                # is requested; parameter init on the other ranks is still delayed.
                model = create_model(
                    args,
                    model_config=model_config,
                    dtype=dtype,
                    pretrained_model_weights=args.hf_pretrained_model_name_or_dir
                    if finetune_with_pretrained_weights_check(args)
                    else None,
                )
                num_params = compute_num_params(model)
            else:
                with init_empty_weights():
                    model = create_model(
                        args,
                        model_config=model_config,
                        dtype=dtype,
                    )
                num_params = compute_num_params(model)
            if finetune_check(args):
                dist.barrier()
        else:
            model_config.delayed_param = False
            model = create_model(
                args,
                model_config=model_config,
                dtype=dtype,
                pretrained_model_weights=args.hf_pretrained_model_name_or_dir
                if finetune_with_pretrained_weights_check(args) and dist.get_rank() == 0
                else None,
            )
            num_params = compute_num_params(model)

        if args.use_smp_implementation:
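            # transform() swaps in torch.sagemaker's distributed (tensor/expert/context
            # parallel) module implementations; when finetuning, only rank 0 holds the
            # real pretrained weights (load_state_dict_from_rank0).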
            if args.moe:
                from torch.sagemaker.moe.moe_config import MoEConfig
                moe_config = MoEConfig(
                    smp_moe=args.use_smp_implementation > 0,
                    moe_load_balancing=args.moe_load_balancing,
                    global_token_shuffle=args.global_token_shuffle > 0,
                    moe_all_to_all_dispatcher=args.moe_all_to_all_dispatcher > 0,
                    use_cpu_initialization=finetune_with_pretrained_weights_check(args) and dist.get_rank() == 0
                )
            else:
                moe_config = None
            load_state_dict_from_rank0 = finetune_with_pretrained_weights_check(args)
            if args.moe and args.delayed_param and (not load_state_dict_from_rank0 or dist.get_rank() != 0):
                with init_empty_weights():
                    model = transform(model, config=moe_config, load_state_dict_from_rank0=load_state_dict_from_rank0, cp_comm_type=args.cp_comm_type)
            else:
                model = transform(model, config=moe_config, load_state_dict_from_rank0=load_state_dict_from_rank0, cp_comm_type=args.cp_comm_type)

        if args.delayed_param:
            # param init fn for delayed param creation
            if finetune_with_pretrained_weights_check(args):
                if dist.get_rank() != 0:
                    delayed_param_initer = DelayedParamIniter(model)
            else:
                delayed_param_initer = DelayedParamIniter(model)

    assert set(x.dtype for x in model.parameters()) == set(
        [torch.float32]
    ), "Model parameters should be in fp32 for FSDP mixed precision"

    if global_rank == 0:
        logger.info(
            "Created model with total parameters: %d (%.2f B)", num_params, num_params * 1e-9
        )

    transformer_layer = get_transformer_layer(args.model_type, args.use_smp_implementation,
                                              args.moe)

    if args.auto_wrap_policy == "transformer_auto_wrap_policy":
        gpt_auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                transformer_layer,
            },
        )
    elif args.auto_wrap_policy == "size_based_auto_wrap_policy":
        gpt_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
        )
    else:
        raise ValueError(f"Unsupported auto_wrap_policy: {args.auto_wrap_policy}.")

    torch.cuda.set_device(device)
    if args.bf16:
        # Keep buffers in fp32: some HF models (e.g. Llama) hard-code fp32 buffers,
        # so we match that behavior here.
        buffer_dtype = torch.float32 if args.use_smp_implementation else dtype
        mixed_precision_policy = MixedPrecision(
            param_dtype=dtype, reduce_dtype=dtype, buffer_dtype=buffer_dtype
        )
    else:
        mixed_precision_policy = None

    if args.enable_memory_profiling > 0:
        memory_status_cpu(tag="Before FSDP wrapper", writers=writers)

    sharding_strategy = get_sharding_strategy(args.sharding_strategy)

    with (
        delayed_param_initer.validate_params_and_buffers_inited()
        if (delayed_param_initer and not finetune_with_pretrained_weights_check(args))
        else nullcontext(),
        tsm_utils.timeit(True, "FSDP constructor", global_rank),
    ):
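        # Delayed (meta) parameters are materialized here via param_init_fn, and
        # sync_module_states broadcasts rank 0's pretrained weights to the other
        # ranks when finetuning.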
        model = FSDP(  # pylint: disable=unexpected-keyword-arg
            model,
            auto_wrap_policy=gpt_auto_wrap_policy,
            mixed_precision=mixed_precision_policy,
            sharding_strategy=sharding_strategy,
            backward_prefetch=get_backward_fetch_policy(args.backward_fetch_policy),
            forward_prefetch=args.forward_prefetch,
            limit_all_gathers=args.limit_all_gathers,
            device_id=torch.cuda.current_device(),
            use_orig_params=args.use_orig_params > 0,
            param_init_fn=delayed_param_initer.get_param_init_fn()
            if delayed_param_initer
            else None,
            post_param_init_fn=delayed_param_initer.get_post_param_init_fn()
            if delayed_param_initer
            else None,
            sync_module_states=finetune_with_pretrained_weights_check(args),
        )
    # Barrier is a workaround to reduce extra memory usage with SMDDP backend
    # after the broadcast that happens when we use sync_module_states
    # This can be removed once the SMDDP issue is fixed
    dist.barrier()

    if global_rank == 0:
        logger.info("Wrapped model with FSDP")

    if args.enable_memory_profiling > 0:
        memory_status(tag="After FSDP wrapper", writers=writers)

    fp8_recipe = None
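    # Transformer Engine delayed-scaling FP8: Format.HYBRID uses E4M3 for the forward
    # pass and E5M2 for gradients, with scaling factors derived from a rolling amax history.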
    if args.fp8 == 1 and args.use_smp_implementation == 1:
        fp8_recipe = DelayedScaling(
            fp8_format=Format.HYBRID,
            amax_history_len=args.fp8_amax_history_len,
            amax_compute_algo=args.fp8_amax_compute_algo,
        )

    if args.activation_checkpointing > 0:
        apply_activation_checkpoint(args, model=model)

    if tsm.state.sm_activation_offloading > 0:
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import offload_wrapper

        model = offload_wrapper(model)

        # Patch RoPE for GPT-NeoX: the rotary embeddings are created on the host, so move them to the device.
        if args.use_smp_implementation == 0 and args.model_type == "gpt_neox" and args.patch_neox_rope > 0:
            patch_neox_rope(model)

    param_groups = get_param_groups_by_weight_decay(model)

    optimizer = optim.AdamW(
        param_groups, betas=(args.beta1, args.beta2), lr=args.lr, weight_decay=args.weight_decay
    )

    if global_rank == 0:
        logger.info("Created optimizer")

    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    checkpointing_pg_metadata = (
        model.process_group,
        get_coordinator_rank(model.process_group),
        is_action_rank(global_rank),
    )

    if args.resume_from_checkpoint:
        (
            model,
            optimizer,
            lr_scheduler,
            epoch,
            total_steps,
            start_train_path_index,
            resume_from_sequence_number,
            val_resume_from_sequence_number,
        ) = load_checkpoint(
            args,
            model,
            optimizer,
            lr_scheduler,
            args.resume_from_checkpoint,
            sharding_strategy,
            checkpointing_pg_metadata,
            tensor_parallel_degree=int(tsm.state.tensor_parallel_degree),
            expert_parallel_degree=int(tsm.state.expert_parallel_degree),
            checkpoint_type=args.checkpoint_type,
        )
        torch.cuda.empty_cache()

    else:
        total_steps = 0
        epoch = 0
        start_train_path_index = 0
        resume_from_sequence_number = 0
        val_resume_from_sequence_number = 0

    train_start_time = time.time()
    # train() returns the final total_steps; throughput and loss are tracked inside train().
    total_steps = train(
        model,
        optimizer,
        lr_scheduler,
        writers,
        model_config,
        epoch,
        start_train_path_index,
        resume_from_sequence_number,
        val_resume_from_sequence_number,
        num_params,
        total_steps,
        args,
        global_rank,
        world_size,
        checkpointing_pg_metadata,
        fp8_recipe,
    )
    time_now = time.time()
    total_sec = time_now - global_start_time
    train_sec = time_now - train_start_time

    dist.barrier()

    if args.save_final_model:
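        # Export a single consolidated (FULL) checkpoint of the final weights to
        # model_dir, falling back to the first checkpoint_dir.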
        save_checkpoint(
            model,
            None,
            None,
            {"model_config": model_config},
            None,
            args.model_dir if args.model_dir is not None else args.checkpoint_dir[0],
            "" if args.model_dir is not None else "model",
            1,
            None,
            int(tsm.state.tensor_parallel_degree),
            int(tsm.state.expert_parallel_degree),
            checkpoint_type=CheckpointingMethod.FULL,
        )

    if global_rank == 0:
        train_min = train_sec / 60.0
        total_min = total_sec / 60.0

        for writer in writers:
            runtime = {
                "total": total_min,
                "train": train_min,
            }
            writer.add_scalars("Perf/runtime", runtime, total_steps - 1)

        logger.info(
            "FSDP training finished successfully: train time %fs (%fmin), total runtime %fmin.",
            train_sec, train_min, total_min
        )

    dist.destroy_process_group()
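
A minimal invocation sketch (not part of the function above): the training entry point is expected to build an argparse namespace and pass it to main(). The parse_args helper named below is an assumption standing in for the repository's argument-parsing module.

if __name__ == "__main__":
    # Hypothetical entry point: parse_args() is assumed to return the argparse
    # namespace consumed by main(); it is not defined in this file.
    from arguments import parse_args

    main(parse_args())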