def apply_activation_checkpoint_moe()

in src/hyperpod_nemo_adapter/utils/train_utils.py [0:0]


import functools  # required for the functools.partial calls below (module-level import in train_utils.py)


def apply_activation_checkpoint_moe(model=None, checkpoint_attn=True, checkpoint_moe=True):
    """
    Experimental checkpointing with multiple checkpoint wrappers.
    Use TE checkpoint for attention, and megatron/native checkpoint for MoE layer.
    """
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        CheckpointImpl,
        apply_activation_checkpointing,
        checkpoint_wrapper,
    )

    checkpoint_impl = CheckpointImpl.NO_REENTRANT

    if checkpoint_attn:
        import torch.sagemaker as tsm
        import transformer_engine
        from transformer_engine.pytorch.attention import MultiheadAttention

        check_fn_attn = lambda submodule: isinstance(  # pylint: disable=unnecessary-lambda-assignment
            submodule, MultiheadAttention
        )
        checkpoint_fn_attn = functools.partial(
            transformer_engine.pytorch.checkpoint,
            distribute_saved_activations=False,
            get_rng_state_tracker=tsm.state.get_rng_state_tracker,
            tp_group=tsm.state.tp_process_group,
            use_reentrant=False,
        )
        # flash attn v2 does not work with no_reentrant
        # our activation offloading for 2.0 also does not work with no_reentrant
        entrant_wrapper_attn = functools.partial(
            checkpoint_wrapper, checkpoint_impl=checkpoint_impl, checkpoint_fn=checkpoint_fn_attn
        )
        apply_activation_checkpointing(model, checkpoint_wrapper_fn=entrant_wrapper_attn, check_fn=check_fn_attn)

    if checkpoint_moe:
        from torch.sagemaker.moe.moe_layer import MoELayer

        check_fn_moe = lambda submodule: isinstance(  # pylint: disable=unnecessary-lambda-assignment
            submodule, MoELayer
        )
        # With checkpoint_fn=None, checkpoint_wrapper falls back to its default,
        # torch.utils.checkpoint.checkpoint (native non-reentrant checkpointing).
        checkpoint_fn_moe = None
        entrant_wrapper_moe = functools.partial(
            checkpoint_wrapper, checkpoint_impl=checkpoint_impl, checkpoint_fn=checkpoint_fn_moe
        )
        apply_activation_checkpointing(model, checkpoint_wrapper_fn=entrant_wrapper_moe, check_fn=check_fn_moe)
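A minimal usage sketch, not part of train_utils.py: `build_model` is a hypothetical placeholder for whatever constructs the model, and passing `checkpoint_attn=True` additionally assumes the torch.sagemaker state (TP process group and RNG state tracker) has already been initialized. The function mutates the model in place by wrapping the matching submodules, so there is no return value to capture.

# Illustrative sketch only; build_model() is a hypothetical placeholder for a
# model that contains TE MultiheadAttention and/or MoELayer submodules.
model = build_model()

# Wrap attention blocks with the TE checkpoint function and MoE blocks with
# the native non-reentrant checkpoint, as implemented above.
apply_activation_checkpoint_moe(model=model, checkpoint_attn=True, checkpoint_moe=True)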