in src/hyperpod_nemo_adapter/utils/fsdp_utils.py [0:0]
def get_transformer_layer(model_type="gpt2", use_smp_model=False, moe=False, peft_type=None):
"""Get transformer layer."""
if use_smp_model:
        # For pt-2.1-tsm-2.1 releases and below, we can't activation-checkpoint our
        # transformer.TransformerLayer class because it takes a tuple as input, so we
        # checkpoint te.TETransformerLayer directly instead. In later versions, our
        # containers patch TransformerEngine's activation checkpointing logic with the
        # missing native PyTorch checkpoint logic and bug fixes to resolve this.
# PT ref: https://github.com/pytorch/pytorch/blob/v2.2.0/torch/utils/checkpoint.py#L307-L319
# TE ref: https://github.com/NVIDIA/TransformerEngine/blob/v1.2.1/transformer_engine/pytorch/distributed.py#L272
if LooseVersion(torch.__version__) >= LooseVersion("2.2.0"):
from torch.sagemaker.tensor_parallel.transformer import TransformerLayer
transformer_layer = TransformerLayer
else:
from torch.sagemaker.tensor_parallel.transformer import TETransformerLayer
transformer_layer = TETransformerLayer
elif "llama_v2" in model_type or "llama_v3" in model_type:
from transformers.models.llama.modeling_llama import LlamaDecoderLayer
transformer_layer = LlamaDecoderLayer
elif "llama_v4" in model_type:
from transformers.models.llama4.modeling_llama4 import Llama4TextDecoderLayer
transformer_layer = Llama4TextDecoderLayer
elif model_type == "mistral":
from transformers.models.mistral.modeling_mistral import MistralDecoderLayer
transformer_layer = MistralDecoderLayer
elif model_type == "mixtral":
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
transformer_layer = MixtralDecoderLayer
elif "qwen_v2" in model_type:
from transformers.models.qwen2.modeling_qwen2 import Qwen2DecoderLayer
transformer_layer = Qwen2DecoderLayer
elif "deepseek_r1" in model_type:
from hyperpod_nemo_adapter.collections.model.nlp.custom_models.modeling_deepseek import (
DeepseekV3DecoderLayer,
DeepseekV3MLP,
)
        if peft_type == "qlora_4bit":
            # For QLoRA fine-tuning we don't need to wrap the MLP layer.
            transformer_layer = DeepseekV3DecoderLayer
        else:
            # For all other fine-tuning strategies, also wrap the MLP layers.
            transformer_layer = (DeepseekV3DecoderLayer, DeepseekV3MLP)
    if transformer_layer is None:
        raise ValueError(f"transformer_layer for model type {model_type} is not defined.")
return transformer_layer
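

# --- Usage sketch (illustrative, not part of the original file) ---------------------
# A minimal example of how the class(es) returned by get_transformer_layer() would
# typically be consumed: building an FSDP auto-wrap policy keyed on those layer
# classes. The helper name and the model_type value below are assumptions for
# illustration only.
def _example_auto_wrap_policy(model_type="llama_v3", use_smp_model=False):
    import functools

    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy

    layer_cls = get_transformer_layer(model_type=model_type, use_smp_model=use_smp_model)
    # Normalize to a set: the function may return a single class or a tuple of classes.
    layer_set = set(layer_cls) if isinstance(layer_cls, tuple) else {layer_cls}
    # functools.partial fixes transformer_layer_cls so FSDP can invoke the policy with
    # (module, recurse, nonwrapped_numel) during auto wrapping. The same layer set is
    # also a natural check_fn target for activation checkpointing, e.g.
    # apply_activation_checkpointing(model, check_fn=lambda m: isinstance(m, tuple(layer_set))).
    return functools.partial(transformer_auto_wrap_policy, transformer_layer_cls=layer_set)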