def main()

in archived/smp-gpt-sharded-data-parallel/train.py [0:0]


def main():  # pylint: disable=too-many-branches,too-many-locals,too-many-statements
    """Main function to train GPT."""
    args = parse_args()

    if args.partition_assignment != "" and args.manual_partition == 0:
        logging.warning("Partition_assignment is set, enable manual_partition.")
        args.manual_partition = 1

    # Any value here is overridden by the config set in the notebook when launching the SageMaker job
    smp_config = {
        "ddp": True,
        "tensor_parallel_degree": args.tensor_parallel_degree,
        "pipeline_parallel_degree": args.pipeline_parallel_degree,
        "microbatches": args.microbatches,
        "shard_optimizer_state": args.shard_optimizer_state > 0,
        "prescaled_batch": args.prescaled_batch > 0,
        "fp16": args.fp16 > 0,
        "bf16": args.bf16 > 0,
        "offload_activations": args.offload_activations > 0,
        "delayed_parameter_initialization": args.delayed_param > 0,
        "optimize": args.optimize,
        "placement_strategy": args.placement_strategy,
        "activation_loading_horizon": args.activation_loading_horizon,
        "skip_tracing": args.skip_tracing > 0,
        "auto_partition": not args.manual_partition,
        "default_partition": 0,
        "static_mode": args.static_mode > 0,
        "fast_mode": args.fast_mode > 0,
        "sharded_data_parallel_degree": args.sharded_data_parallel_degree,
        "ddp_dist_backend": args.ddp_dist_backend,
        "sdp_hierarchical_allgather": False,
        "sdp_gradient_clipping": args.grad_clip,
    }
    if args.active_microbatches is not None:
        smp_config["active_microbatches"] = args.active_microbatches
    if args.log_param_norms and args.use_distributed_transformer == 1:
        logging.warning(
            "Script currently doesn't support logging param norms when using the distributed transformer, disabling log_param_norms"  # pylint: disable=line-too-long
        )
        args.log_param_norms = 0
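    # Initialize the model parallel runtime; any config set on the SageMaker job
    # (see the note above) overrides the values assembled here.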
    smp.init(smp_config)

    _show_env_vars(0)

    if smp.rank() == 0:
        logging.info("Arguments: %s", args.__dict__)
        logging.info("Transformers version: %s", transformers.__version__)
        logging.info(
            "smdistributed.modelparallel version: %s", smdistributed.modelparallel.__version__
        )
        logging.info("smdistributed config: %s", smp_config)

    if args.save_final_full_model and smp.rank() == 0:
        logging.warning(
            "Note that save_final_full_model only saves the final model at the end "
            "of all steps. It does not save optimizer state; optimizer state is only "
            "saved in the partial checkpoints written every checkpointing_freq steps "
            "during training. If you want to restart training you need partial checkpoints."
        )

    if args.partition_assignment != "":
        partition_assignment = args.partition_assignment.split(",")
        if len(partition_assignment) != smp.pp_size():
            msg = (
                f"partition_assignment must have the same size as pipeline parallel degree, "
                f"but got {len(partition_assignment)} vs {smp.pp_size()}"
            )
            logging.fatal("Will fail with: %s.", msg)
            raise AssertionError(msg)

    model_config, args = model_config_lib.get_model_config_from_args(
        args.model_type, args.model_name, args, log=(smp.rank() == 0)
    )

    # The following improves start-up time by skipping proper initialization
    # of weights in the original model. This is not a problem because DistributedModel
    # will override those weights anyway when we use the distributed transformer.
    if args.use_distributed_transformer > 0:
        from transformers.modeling_utils import (  # pylint: disable=import-error,import-outside-toplevel
            PreTrainedModel,
        )

        PreTrainedModel.init_weights = lambda x: None

    set_seed(args.seed)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before model creation")

    if args.fp16 and args.bf16:
        raise ValueError("FP16 and BF16 cannot be simultaneously enabled.")

    if args.fp16:
        dtype = torch.float16  # pylint: disable=no-member
    elif args.bf16:
        dtype = torch.bfloat16  # pylint: disable=no-member
    else:
        dtype = torch.get_default_dtype()  # pylint: disable=no-member

    if args.fine_tune > 0 and args.delayed_param > 0 and smp.rank() == 0:
        pretrained_model = AutoModelForCausalLM.from_pretrained(
            args.model_name or args.model_dir
        )
        model_state_dict = pretrained_model.state_dict()
        path = os.path.join(args.model_dir, "fullmodel.pt")
        torch.save(model_state_dict, path)
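    # Make every rank wait until rank 0 has finished writing fullmodel.pt.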
    smp.barrier()

    # About zero_init:
    # We only want to zero-init the actual model used for training; in the distributed
    # transformer case that is handled inside the DistributedModel wrapper, and for the
    # other cases zero init is not needed. It is required only for param_id_to_offset.
    with smp.model_creation(
        tensor_parallelism=smp.tp_size() > 1 or args.use_distributed_transformer > 0,
        zero_init=args.use_distributed_transformer == 0,
        dtype=dtype,
        distribute_embedding=args.sharded_data_parallel_degree > 1 and smp.tp_size() > 1,
        use_alibi=args.alibi > 0,
        attention_in_fp32=args.attention_in_fp32 > 0,
        fp32_residual_addition=args.residual_addition_in_fp32 > 0,
        query_key_layer_scaling=args.query_key_layer_scaling > 0 and args.bf16 < 1,
        fused_softmax=args.fused_softmax > 0,
        fused_dropout=args.fused_dropout > 0,
        fused_bias_gelu=args.fused_bias_gelu > 0,
        flash_attention=args.flash_attention > 0,
    ):
        if args.fine_tune > 0 and args.delayed_param == 0:
            model = AutoModelForCausalLM.from_pretrained(
                args.model_name or args.model_dir
            )
        else:
            model = AutoModelForCausalLM.from_config(model_config)

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after model creation")

    # smdistributed: Set the device to the GPU ID used by the current process.
    # Input tensors should be transferred to this device.
    torch.cuda.set_device(smp.local_rank())

    if not args.same_seed:
        # Set seed by tp_rank to prevent weights from being the same on different tp_ranks
        set_seed(args.seed + smp.tp_rank())

    # smdistributed: Use the DistributedModel container to provide the model
    # to be partitioned across different ranks. For the rest of the script,
    # the returned DistributedModel object should be used in place of
    # the model provided for DistributedModel class instantiation.
    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="before dist model creation")

    model = smp.DistributedModel(
        model, trace_device="gpu", backward_passes_per_step=args.gradient_accumulation
    )

    if args.enable_memory_profiling > 0:
        memory_status_cpu(msg="after dist model creation")
    m = model.get_module()  # pylint: disable=invalid-name

    num_params = compute_num_params(m)
    if smp.rank() == 0:
        logging.info("# total parameters: %s", num_params)

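    # Locate the stack of transformer blocks; the attribute path depends on whether the
    # SMP distributed transformer or the stock Hugging Face implementation is used.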
    if args.use_distributed_transformer > 0:
        transformer_layers = m.transformer.seq_layers
    else:
        if args.model_type in ["gpt2", "bloom"]:
            transformer_layers = m.transformer.h
        elif args.model_type == "gpt_neox":
            transformer_layers = m.gpt_neox.layers

    if args.manual_partition:
        logging.debug("Manual partition enabled")
        if args.partition_assignment != "":
            get_num_layers = lambda x: int(  # pylint: disable=unnecessary-lambda-assignment
                partition_assignment[x]
            )
            total_layers = sum(get_num_layers(pp_rank) for pp_rank in range(smp.pp_size()))
            if total_layers != args.num_layers:
                msg = (
                    f"partition_assignment must have the same total transformer layers as model, "
                    f"but got {total_layers} vs {args.num_layers}"
                )
                logging.fatal("Will fail with: %s.", msg)
                raise AssertionError(msg)
        else:
            # evenly distribute layers across all partitions
            div, rem = divmod(args.num_layers, smp.pp_size())
            get_num_layers = lambda x: (  # pylint: disable=unnecessary-lambda-assignment
                div + 1 if x >= smp.pp_size() - rem else div
            )
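            # e.g. with num_layers=24 and pp_size=5: divmod(24, 5) == (4, 4), so partition 0
            # gets 4 layers and partitions 1-4 get 5 layers each (4 + 4 * 5 == 24).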

        assignments = []
        # TODO: this ordering is required for 175B, otherwise there is a hang for the
        # partition "8,17,17,18,18,18". Needs further investigation.
        # for pp_rank in reversed(range(smp.pp_size())):
        for pp_rank in range(smp.pp_size()):
            nl = get_num_layers(pp_rank)  # pylint: disable=invalid-name
            logging.debug("%s layers assigned to partition %d", nl, pp_rank)
            assignments += [pp_rank for _ in range(nl)]

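        # Pin each transformer layer module to its assigned pipeline partition.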
        for i, c in enumerate(transformer_layers.children()):  # pylint: disable=invalid-name
            smp.set_partition(c, assignments[i])

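    # Split parameters into weight-decay and no-weight-decay groups (typically biases
    # and norm parameters are excluded from weight decay).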
    param_groups = get_param_groups_by_weight_decay(m)

    if args.use_adamw > 0:
        optimizer = optim.AdamW(
            param_groups, betas=(args.beta1, args.beta2), lr=args.lr, weight_decay=args.weight_decay
        )
    else:
        optimizer = optim.Adam(
            param_groups, betas=(args.beta1, args.beta2), lr=args.lr, weight_decay=args.weight_decay
        )

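    # Activation checkpointing: recompute activations during the backward pass instead of
    # keeping them in memory. With the distributed transformer or tensor parallelism, whole
    # layers (or their attention/output sublayers) are checkpointed; otherwise the Hugging
    # Face submodules are checkpointed individually.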
    if args.activation_checkpointing:  # pylint: disable=too-many-nested-blocks
        if args.use_distributed_transformer or smp.tp_size() > 1:
            if args.checkpoint_sublayers:
                for c in transformer_layers.children():  # pylint: disable=invalid-name
                    smp.set_activation_checkpointing(c.attention)
                    smp.set_activation_checkpointing(c.output)
            else:
                smp.set_activation_checkpointing(
                    transformer_layers, strategy=args.activation_strategy
                )
        else:
            for c in transformer_layers.children():  # pylint: disable=invalid-name
                if args.checkpoint_sublayers:
                    if args.model_type == "gpt2":
                        smp.set_activation_checkpointing(c.attn)
                        smp.set_activation_checkpointing(c.mlp)
                    elif args.model_type in ["gpt_neox", "bloom"]:
                        if args.model_type == "gpt_neox":
                            smp.set_activation_checkpointing(c.attention)
                        elif args.model_type == "bloom":
                            smp.set_activation_checkpointing(c.self_attention)
                        smp.set_activation_checkpointing(c.input_layernorm)
                        smp.set_activation_checkpointing(c.post_attention_layernorm)
                        smp.set_activation_checkpointing(c.mlp)
                else:
                    smp.set_activation_checkpointing(c)

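    # For sharded data parallelism without the distributed transformer, record each
    # parameter's offset so its shard can be located in the flattened optimizer buffer.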
    if args.sharded_data_parallel_degree > 1 and args.use_distributed_transformer == 0:
        param_id_to_offset = build_param_id_to_offset(param_groups)

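    # Wrap the optimizer so SMP can shard/partition its state; dynamic loss scaling
    # is configured for fp16 training.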
    optimizer = smp.DistributedOptimizer(
        optimizer,
        static_loss_scale=None,
        dynamic_loss_scale=True,
        dynamic_loss_args={"scale_window": 1000, "min_scale": 1, "delayed_shift": 2},
    )

    if args.fine_tune > 0 and args.delayed_param > 0:
        smp.resume_from_checkpoint(args.model_dir, tag="fullmodel.pt", partial=False)

    if args.sharded_data_parallel_degree > 1 and args.use_distributed_transformer == 0:
        param_id_to_buffer = build_param_id_to_buffer(optimizer, param_id_to_offset)
    else:
        param_id_to_buffer = None

    lr_scheduler = get_learning_rate_scheduler(optimizer, args)

    if args.enable_memory_profiling > 0:
        model.register_post_partition_hook(
            lambda model, optimizer: memory_status(msg="After partition")
        )

    # load after wrapping model and optimizer with smp Distributed...
    if args.load_full or args.load_partial:
        if args.load_partial and args.load_full:
            logging.info(
                "Since both --load_partial and --load_full are set, will try to load from the "
                "full checkpoint. If the intention is to load from a partial checkpoint, "
                "please don't set --load_full"
            )
        partial = not args.load_full
        path = args.checkpoint_dir if partial else args.model_dir
        tag = None if partial else "fullmodel.pt"
        user_content = smp.resume_from_checkpoint(path, tag=tag, partial=partial)
        total_steps = user_content["total_steps"] if partial else 0
        start_train_path_index = user_content.get("start_train_path_index", 0)
        start_batch_index = user_content.get("start_batch_index", 0)
        if "lr_scheduler" in user_content:
            lr_scheduler.load_state_dict(user_content["lr_scheduler"])
    else:
        total_steps = 0
        start_train_path_index = 0
        start_batch_index = 0

    # Empty the CUDA cache to free memory when loading with partial checkpointing,
    # for SDP+TP and GPT NeoX
    torch.cuda.empty_cache()

    start = time.time()
    total_steps, throughput, loss = train(
        model,
        optimizer,
        lr_scheduler,
        model_config,
        start_train_path_index,
        start_batch_index,
        num_params,
        total_steps,
        args,
        param_id_to_buffer,
    )
    time_to_train = time.time() - start
    if args.ci:
        logging.info("[SMP_METRIC]__GPT2__Time_to_train__%s", time_to_train)
        logging.info("[SMP_METRIC]__GPT2__samples/second__%s", throughput)
        logging.info("[SMP_METRIC]__GPT2__Loss__%s", loss)
        if not args.load_partial and not args.load_full:
            if time_to_train >= args.time_to_train:
                msg = f"Time to train ({time_to_train}) >= threshold ({args.time_to_train})"
                logging.fatal("Will fail with: %s.", msg)
                raise AssertionError(msg)

            if throughput <= args.throughput:
                msg = f"Throughput ({throughput}) >= threshold ({args.throughput})"
                logging.fatal("Will fail with: %s.", msg)
                raise AssertionError(msg)

            if args.loss and loss >= args.loss:
                msg = f"Loss ({loss}) >= threshold ({args.loss})"
                logging.fatal("Will fail with: %s.", msg)
                raise AssertionError(msg)

    if args.save_final_full_model:
        # saves full model at the end
        user_content = {
            "cli_args": args.__dict__,
            "num_params": num_params,
            "total_steps": total_steps,
            "model_config": model_config,
        }
        smp.save_checkpoint(
            args.model_dir,
            tag="fullmodel.pt",
            partial=False,
            model=model,
            user_content=user_content,
        )

    smp.barrier()
    if smp.rank() == 0:
        logging.info("SMP training finished successfully")