def get_transformer_layer()

in src/hyperpod_nemo_adapter/utils/fsdp_utils.py


def get_transformer_layer(model_type="gpt2", use_smp_model=False, moe=False, peft_type=None):
    """Return the transformer layer class (or classes) to wrap for the given model type."""
    transformer_layer = None
    if use_smp_model:
        # For pt-2.1-tsm-2.1 releases and below, we can't checkpoint our
        # transformer.TransformerLayer class because it takes a tuple as input,
        # so we checkpoint te.TETransformerLayer directly instead.
        # In later versions, we patch the TransformerEngine activation checkpointing logic
        # in our containers with missing native PyTorch checkpoint logic and bug fixes to resolve this.
        # PT ref: https://github.com/pytorch/pytorch/blob/v2.2.0/torch/utils/checkpoint.py#L307-L319
        # TE ref: https://github.com/NVIDIA/TransformerEngine/blob/v1.2.1/transformer_engine/pytorch/distributed.py#L272
        if LooseVersion(torch.__version__) >= LooseVersion("2.2.0"):
            from torch.sagemaker.tensor_parallel.transformer import TransformerLayer

            transformer_layer = TransformerLayer
        else:
            from torch.sagemaker.tensor_parallel.transformer import TETransformerLayer

            transformer_layer = TETransformerLayer
    elif "llama_v2" in model_type or "llama_v3" in model_type:
        from transformers.models.llama.modeling_llama import LlamaDecoderLayer

        transformer_layer = LlamaDecoderLayer
    elif "llama_v4" in model_type:
        from transformers.models.llama4.modeling_llama4 import Llama4TextDecoderLayer

        transformer_layer = Llama4TextDecoderLayer
    elif model_type == "mistral":
        from transformers.models.mistral.modeling_mistral import MistralDecoderLayer

        transformer_layer = MistralDecoderLayer
    elif model_type == "mixtral":
        from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer

        transformer_layer = MixtralDecoderLayer
    elif "qwen_v2" in model_type:
        from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer

        transformer_layer = Qwen2DecoderLayer
    elif "deepseek_r1" in model_type:
        from hyperpod_nemo_adapter.collections.model.nlp.custom_models.modeling_deepseek import (
            DeepseekV3DecoderLayer,
            DeepseekV3MLP,
        )

        if peft_type == "qlora_4bit":
            # For QLoRA fine-tuning we don't need to wrap the MLP layer.
            transformer_layer = DeepseekV3DecoderLayer
        else:
            # For all other fine-tuning strategies, also wrap the MLP layers.
            transformer_layer = (DeepseekV3DecoderLayer, DeepseekV3MLP)

    if transformer_layer is None:
        raise ValueError(f"transformer_layer for model type {model_type} is not defined.")

    return transformer_layer
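
A minimal usage sketch (an illustrative assumption, not taken from this file): the returned class, or tuple of classes, is typically passed to PyTorch FSDP's transformer_auto_wrap_policy so FSDP knows which modules to wrap as units. The normalization to a set below accounts for the deepseek_r1 case, where a tuple of classes is returned.

    # Sketch only: assumes get_transformer_layer is importable from
    # hyperpod_nemo_adapter.utils.fsdp_utils and that PyTorch FSDP is available.
    import functools

    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    layer_cls = get_transformer_layer(model_type="llama_v3", use_smp_model=False)
    # Most model types return a single class; deepseek_r1 returns a tuple
    # (decoder layer + MLP), so normalize to a tuple before building the set.
    layer_cls = layer_cls if isinstance(layer_cls, tuple) else (layer_cls,)

    auto_wrap_policy = functools.partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls=set(layer_cls),
    )
    # auto_wrap_policy can then be passed to
    # FullyShardedDataParallel(model, auto_wrap_policy=auto_wrap_policy, ...)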