def apply_activation_checkpointing()

in optimum/neuron/accelerate/utils/misc.py [0:0]


def apply_activation_checkpointing(model: Union["PreTrainedModel", "NxDPPModel", "NeuronPeftModel"]):
    """Enable activation (gradient) checkpointing on the repeated blocks of ``model``.

    Candidate blocks are the direct children of every ``torch.nn.ModuleList`` found while walking the
    model whose class name contains ``"Layer"`` or ``"Block"``. If at least one candidate is found, the
    model is passed to ``neuronx_distributed``'s ``apply_activation_checkpointing`` with a ``check_fn``
    that selects exactly those blocks; otherwise nothing is applied.

    Args:
        model: The model to modify. For an ``NxDPPModel``, only the modules reachable from
            ``model.local_stage_modules`` are scanned. For a ``NeuronPeftModel``, the PEFT model's
            ``_prepare_model_for_gradient_checkpointing`` hook is first invoked on the base model.
    """
    from neuronx_distributed.pipeline import NxDPPModel
    from neuronx_distributed.utils.activation_checkpoint import (
        apply_activation_checkpointing as nxd_apply_activation_checkpointing,
    )

    from ...peft.peft_model import NeuronPeftModel

    if isinstance(model, NeuronPeftModel):
        model._prepare_model_for_gradient_checkpointing(model.get_base_model())

    if isinstance(model, NxDPPModel):
        # `chain.from_iterable` flattens the per-stage `modules()` iterators into one stream of modules.
        # `itertools.chain(<generator>)` would instead yield the `modules()` generator objects themselves,
        # so no `ModuleList` would ever match below and checkpointing would be silently skipped.
        modules = itertools.chain.from_iterable(module.modules() for module in model.local_stage_modules)
    else:
        modules = model.modules()

    gradient_checkpointing_modules = set()
    for module in modules:
        if not isinstance(module, torch.nn.ModuleList):
            continue
        for mod in module:
            # TODO: @michaelbenayoun. Need to find a better way to identify the blocks to apply gradient
            # checkpointing to.
            # NOTE(review): the "Layer" substring also matches names like "LayerNorm" if such modules are
            # stored in a ModuleList; confirm whether that is intended.
            class_name = mod.__class__.__name__
            if "Layer" in class_name or "Block" in class_name:
                gradient_checkpointing_modules.add(mod)

    def check_fn(m: torch.nn.Module) -> bool:
        # `torch.nn.Module` keeps default object hashing, so this is an identity-based membership test.
        return m in gradient_checkpointing_modules

    if gradient_checkpointing_modules:
        nxd_apply_activation_checkpointing(model, check_fn=check_fn)