in src/hyperpod_nemo_adapter/utils/train_utils.py [0:0]
def apply_activation_checkpoint_moe(model=None, checkpoint_attn=True, checkpoint_moe=True):
    """
    Experimental activation checkpointing with multiple checkpoint wrappers.

    Use the TE (Transformer Engine) checkpoint for attention layers, and the
    Megatron/native checkpoint for MoE layers.
    """
    import functools  # needed for functools.partial; typically imported at module level in the source file

    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointImpl,
        apply_activation_checkpointing,
        checkpoint_wrapper,
    )

    checkpoint_impl = CheckpointImpl.NO_REENTRANT

    if checkpoint_attn:
        import torch.sagemaker as tsm
        import transformer_engine
        from transformer_engine.pytorch.attention import MultiheadAttention

        # Wrap only TE MultiheadAttention submodules.
        check_fn_attn = lambda submodule: isinstance(  # pylint: disable=unnecessary-lambda-assignment
            submodule, MultiheadAttention
        )
        # Use the TE checkpoint function so the RNG state tracker and TP process group are handled.
        checkpoint_fn_attn = functools.partial(
            transformer_engine.pytorch.checkpoint,
            distribute_saved_activations=False,
            get_rng_state_tracker=tsm.state.get_rng_state_tracker,
            tp_group=tsm.state.tp_process_group,
            use_reentrant=False,
        )
        # flash attn v2 does not work with no_reentrant
        # our activation offloading for 2.0 also does not work with no_reentrant
        entrant_wrapper_attn = functools.partial(
            checkpoint_wrapper, checkpoint_impl=checkpoint_impl, checkpoint_fn=checkpoint_fn_attn
        )
        apply_activation_checkpointing(model, checkpoint_wrapper_fn=entrant_wrapper_attn, check_fn=check_fn_attn)

    if checkpoint_moe:
        from torch.sagemaker.moe.moe_layer import MoELayer

        # Wrap only MoE submodules; checkpoint_fn=None falls back to the default
        # (native PyTorch) checkpoint function inside checkpoint_wrapper.
        check_fn_moe = lambda submodule: isinstance(  # pylint: disable=unnecessary-lambda-assignment
            submodule, MoELayer
        )
        checkpoint_fn_moe = None
        entrant_wrapper_moe = functools.partial(
            checkpoint_wrapper, checkpoint_impl=checkpoint_impl, checkpoint_fn=checkpoint_fn_moe
        )
        apply_activation_checkpointing(model, checkpoint_wrapper_fn=entrant_wrapper_moe, check_fn=check_fn_moe)
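For context, a minimal sketch of how this helper might be invoked from a training setup. The `cfg` flags and the `model` object are assumptions for illustration, not the adapter's actual call site; `model` stands for an already-built model whose submodules include TE MultiheadAttention and MoELayer instances.

    # Hypothetical call site; `cfg.activation_checkpointing`, `cfg.moe`, and `model` are assumed names.
    from hyperpod_nemo_adapter.utils.train_utils import apply_activation_checkpoint_moe

    if cfg.activation_checkpointing and cfg.moe:
        # Attention blocks get the TE checkpoint, MoE layers the native PyTorch
        # checkpoint, as implemented in the function above.
        apply_activation_checkpoint_moe(
            model=model,
            checkpoint_attn=True,
            checkpoint_moe=True,
        )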